From dc228f23c32549e0fb0b2f9030e67be58d14b75f Mon Sep 17 00:00:00 2001 From: Dennis Fetterly Date: Wed, 17 Jul 2013 13:30:36 -0700 Subject: [PATCH] Initial commit --- CommonCode/AzureUtils.cs | 104 + CommonCode/Constants.cs | 169 + CommonCode/DiscLocalMonitor.cs | 163 + CommonCode/DryadTracing.cs | 693 + .../DryadVertexServiceAuthorizationManager.cs | 118 + CommonCode/ExecutionHelper.cs | 297 + CommonCode/IDryadVertexCallback.cs | 80 + CommonCode/IDryadVertexService.cs | 204 + CommonCode/NativeMethods.cs | 453 + CommonCode/NetShareWrapper.cs | 519 + CommonCode/ProcessPathHelper.cs | 172 + CommonCode/ProcessState.cs | 38 + CommonCode/QueryUtility.cs | 85 + CommonCode/RetryFramework.cs | 358 + CommonCode/SchedulerHelper.cs | 753 ++ CommonCode/SoftAffinity.cs | 79 + Dryad.sln | 359 + Dryad.v11.suo | Bin 0 -> 115712 bytes .../VertexHost/system/channel/channel.vcxproj | 186 + .../system/channel/include/channelbuffer.h | 275 + .../system/channel/include/channelinterface.h | 731 ++ .../system/channel/include/channelitem.h | 294 + .../system/channel/include/channelmarshaler.h | 161 + .../channel/include/channelmemorybuffers.h | 141 + .../system/channel/include/channelparser.h | 560 + .../system/channel/include/concreterchannel.h | 114 + .../system/channel/include/recordarray.h | 485 + .../system/channel/include/recorditem.h | 239 + .../system/channel/include/recordparser.h | 201 + .../system/channel/src/channelbuffer.cpp | 196 + .../system/channel/src/channelbufferhdfs.cpp | 1325 ++ .../system/channel/src/channelbufferhdfs.h | 157 + .../channel/src/channelbuffernativereader.cpp | 1755 +++ .../channel/src/channelbuffernativereader.h | 241 + .../channel/src/channelbuffernativewriter.cpp | 3228 +++++ .../channel/src/channelbuffernativewriter.h | 336 + .../system/channel/src/channelbufferqueue.cpp | 777 ++ .../system/channel/src/channelbufferqueue.h | 116 + .../system/channel/src/channelfifo.cpp | 1421 +++ .../system/channel/src/channelfifo.h | 241 + 
.../system/channel/src/channelhelpers.cpp | 316 + .../system/channel/src/channelhelpers.h | 201 + .../system/channel/src/channelitem.cpp | 343 + .../system/channel/src/channelmarshaler.cpp | 115 + .../system/channel/src/channelparser.cpp | 879 ++ .../system/channel/src/channelreader.cpp | 1485 +++ .../system/channel/src/channelreader.h | 385 + .../system/channel/src/channelwriter.cpp | 1524 +++ .../system/channel/src/channelwriter.h | 318 + .../system/channel/src/concreterchannel.cpp | 1645 +++ .../channel/src/concreterchannelhelpers.h | 194 + .../system/channel/src/memorybuffers.cpp | 415 + .../system/channel/src/recorditem.cpp | 1140 ++ .../system/classlib/classlib.vcxproj | 205 + .../system/classlib/include/DrBList.h | 411 + .../system/classlib/include/DrCommon.h | 47 + .../classlib/include/DrCriticalSection.h | 520 + .../system/classlib/include/DrError.h | 557 + .../system/classlib/include/DrErrorDef.h | 77 + .../system/classlib/include/DrExecution.h | 103 + .../system/classlib/include/DrExitCodes.h | 72 + .../system/classlib/include/DrExitCodesDef.h | 64 + .../system/classlib/include/DrFPrint.h | 193 + .../classlib/include/DrFPrint_polynomials.h | 190 + .../system/classlib/include/DrFunctions.h | 119 + .../system/classlib/include/DrGuid.h | 155 + .../system/classlib/include/DrHash.h | 231 + .../system/classlib/include/DrHeap.h | 104 + .../system/classlib/include/DrList.h | 606 + .../system/classlib/include/DrLogging.h | 105 + .../system/classlib/include/DrMemory.h | 362 + .../system/classlib/include/DrMemoryStream.h | 2302 ++++ .../system/classlib/include/DrNodeAddress.h | 897 ++ .../system/classlib/include/DrProperties.h | 79 + .../system/classlib/include/DrPropertiesDef.h | 60 + .../classlib/include/DrPropertyDumper.h | 712 ++ .../system/classlib/include/DrPropertyType.h | 63 + .../system/classlib/include/DrRefCounter.h | 839 ++ .../system/classlib/include/DrStringUtil.h | 1011 ++ .../system/classlib/include/DrTags.h | 410 + 
.../system/classlib/include/DrTagsDef.h | 33 + .../system/classlib/include/DrThread.h | 1036 ++ .../system/classlib/include/DrTypes.h | 118 + .../system/classlib/include/Dryad.h | 2185 ++++ .../system/classlib/include/DryadTags.h | 57 + .../system/classlib/include/DryadTagsDef.h | 33 + .../system/classlib/include/Interlocked.h | 150 + .../system/classlib/include/LogIds.h | 846 ++ .../classlib/include/LogIdsCustomized.h | 91 + .../system/classlib/include/LogTagIds.h | 612 + .../system/classlib/include/MSMutex.h | 101 + .../system/classlib/include/PropertyIds.h | 49 + .../system/classlib/include/RefCount.h | 119 + .../system/classlib/include/XCompute.h | 1783 +++ .../system/classlib/include/XComputeTypes.h | 1236 ++ .../system/classlib/include/basic_types.h | 321 + .../system/classlib/include/fingerprint.h | 54 + .../system/classlib/include/ms_fprint.h | 57 + .../system/classlib/src/DrCriticalSection.cpp | 196 + .../system/classlib/src/DrError.cpp | 480 + .../system/classlib/src/DrExecution.cpp | 50 + .../system/classlib/src/DrExitCodes.cpp | 209 + .../system/classlib/src/DrFPrint.cpp | 283 + .../system/classlib/src/DrFunctions.cpp | 1680 +++ .../VertexHost/system/classlib/src/DrGuid.cpp | 359 + .../VertexHost/system/classlib/src/DrHash.cpp | 528 + .../VertexHost/system/classlib/src/DrHeap.cpp | 182 + .../system/classlib/src/DrLogging.cpp | 195 + .../system/classlib/src/DrMemory.cpp | 386 + .../system/classlib/src/DrMemoryStream.cpp | 1218 ++ .../system/classlib/src/DrNodeAddress.cpp | 973 ++ .../system/classlib/src/DrRefCounter.cpp | 48 + .../system/classlib/src/DrStringUtil.cpp | 2490 ++++ .../system/classlib/src/DrThread.cpp | 769 ++ .../system/classlib/src/fingerprint.cpp | 82 + .../system/classlib/src/ms_fprint.cpp | 162 + .../VertexHost/system/common/common.vcxproj | 186 + .../system/common/include/CsEnhancedTimer.h | 241 + .../system/common/include/DObjPool.h | 260 + .../common/include/cosmospropertyblock.h | 59 + .../include/cosmosstreampropertyupdater.h | 30 
+ .../common/include/dryadcosmosresources.h | 53 + .../system/common/include/dryaderror.h | 59 + .../system/common/include/dryaderrordef.h | 35 + .../system/common/include/dryadeventcache.h | 50 + .../system/common/include/dryadlisthelper.h | 72 + .../system/common/include/dryadmetadata.h | 288 + .../system/common/include/dryadmetadatatag.h | 84 + .../common/include/dryadmetadatatagtypes.h | 559 + .../system/common/include/dryadnativeport.h | 153 + .../common/include/dryadopaqueresources.h | 46 + .../system/common/include/dryadproperties.h | 90 + .../common/include/dryadpropertiesdef.h | 64 + .../common/include/dryadpropertydumper.h | 26 + .../system/common/include/dryadpropertytype.h | 25 + .../common/include/dryadstandaloneini.h | 40 + .../system/common/include/dryadtags.h | 52 + .../system/common/include/dryadtagsdef.h | 32 + .../common/include/dryadxcomputeresources.h | 55 + .../system/common/include/dvertexcommand.h | 227 + .../system/common/include/errorreporter.h | 49 + .../system/common/include/orderedsendlatch.h | 150 + .../system/common/include/portmemorybuffers.h | 97 + .../system/common/include/workqueue.h | 73 + .../common/include/xcomputepropertyblock.h | 57 + .../system/common/include/yarnpropertyblock.h | 53 + .../VertexHost/system/common/src/DObjPool.cpp | 444 + .../system/common/src/dryadeventcache.cpp | 81 + .../system/common/src/dryadmetadata.cpp | 2456 ++++ .../system/common/src/dryadmetadatatag.cpp | 52 + .../common/src/dryadmetadatatagtypes.cpp | 1138 ++ .../system/common/src/dryadnativeport.cpp | 923 ++ .../common/src/dryadopaqueresources.cpp | 32 + .../system/common/src/dryadpropertydumper.cpp | 187 + .../system/common/src/dryadstandaloneini.cpp | 607 + .../common/src/dryadxcomputeresources.cpp | 150 + .../system/common/src/dvertexcommand.cpp | 997 ++ .../system/common/src/errorreporter.cpp | 88 + .../system/common/src/portmemorybuffers.cpp | 298 + .../system/common/src/workqueue.cpp | 373 + .../common/src/xcomputepropertyblock.cpp | 84 + 
.../system/common/src/yarnpropertyblock.cpp | 90 + .../system/dprocess/dprocess.vcxproj | 170 + .../system/dprocess/include/dryadvertex.h | 590 + .../include/dvertexcosmosenvironment.h | 38 + .../dprocess/include/dvertexenvironment.h | 47 + .../system/dprocess/include/dvertexmain.h | 40 + .../include/dvertexxcomputeenvironment.h | 29 + .../system/dprocess/include/vertexfactory.h | 183 + .../system/dprocess/src/dryadvertex.cpp | 2009 +++ .../dprocess/src/dvertexcmdlinecontrol.cpp | 1073 ++ .../dprocess/src/dvertexcmdlinecontrol.h | 53 + .../dprocess/src/dvertexenvironment.cpp | 84 + .../system/dprocess/src/dvertexmain.cpp | 146 + .../system/dprocess/src/dvertexpncontrol.cpp | 1192 ++ .../system/dprocess/src/dvertexpncontrol.h | 111 + .../src/dvertexxcomputeenvironment.cpp | 62 + .../dprocess/src/dvertexxcomputepncontrol.cpp | 596 + .../dprocess/src/dvertexxcomputepncontrol.h | 59 + .../dprocess/src/dvertexyarnpncontrol.h | 58 + .../system/dprocess/src/subgraphvertex.cpp | 1408 ++ .../system/dprocess/src/subgraphvertex.h | 202 + .../system/dprocess/src/vertexfactory.cpp | 403 + .../vertex/WrapperNativeInfo/FifoChannel.cpp | 255 + .../WrapperNativeInfo/FifoInputChannel.cpp | 143 + .../WrapperNativeInfo/FifoOutputChannel.cpp | 101 + .../GzipCompressionChannelTransform.cpp | 296 + .../GzipDecompressionChannelTransform.cpp | 391 + .../vertex/WrapperNativeInfo/InputChannel.cpp | 164 + .../NullChannelTransform.cpp | 64 + .../WrapperNativeInfo/OutputChannel.cpp | 140 + .../WrapperNativeInfo/WrapperNativeInfo.cpp | 332 + .../WrapperNativeInfo.vcxproj | 151 + .../WrapperNativeInfo.vcxproj.filters | 51 + .../vertex/WrapperNativeInfo/stdafx.h | 25 + .../DryadLINQNativeChannels.def | 19 + .../WrapperNativeInfoDll.vcxproj | 182 + .../WrapperNativeInfoDll.vcxproj.filters | 32 + .../vertex/WrapperNativeInfoDll/stdafx.h | 25 + .../wrappernativeinfostubs.cpp | 168 + .../vertex/include/ChannelTransform.h | 39 + .../vertex/include/CompressionVertex.h | 46 + 
.../VertexHost/vertex/include/DataBlockItem.h | 73 + .../VertexHost/vertex/include/FifoChannel.h | 69 + .../vertex/include/FifoInputChannel.h | 47 + .../vertex/include/FifoOutputChannel.h | 42 + .../include/GzipCompressionChannelTransform.h | 62 + .../GzipDecompressionChannelTransform.h | 62 + .../VertexHost/vertex/include/InputChannel.h | 52 + .../vertex/include/ManagedWrapper.h | 49 + .../vertex/include/NullChannelTransform.h | 36 + .../VertexHost/vertex/include/OutputChannel.h | 51 + .../vertex/include/wrappernativeinfo.h | 150 + .../managedwrappervertex/DataBlockItem.cpp | 122 + .../ManagedWrapperVertex.cpp | 357 + .../ManagedWrapperVertex.vcxproj | 150 + .../ManagedWrapperVertex.vcxproj.filters | 35 + .../vertex/managedwrappervertex/stdafx.h | 25 + .../vertex/vertexHost/VertexHost.vcxproj | 200 + .../vertexHost/VertexHost.vcxproj.filters | 6 + .../vertex/vertexHost/vertexHost.cpp | 526 + DryadVertex/service/DryadVertexService.csproj | 150 + DryadVertex/service/ReplyDispatcher.cs | 416 + .../service/VertexCallbackServiceClient.cs | 51 + DryadVertex/service/VertexProcess.cs | 1089 ++ DryadVertex/service/VertexService.cs | 507 + DryadVertex/service/app.config | 3 + DryadVertex/service/program.cs | 217 + DryadYarnBridge/DryadYarnBridge.cpp | 52 + DryadYarnBridge/DryadYarnBridge.def | 9 + DryadYarnBridge/DryadYarnBridge.h | 32 + DryadYarnBridge/DryadYarnBridge.vcxproj | 198 + .../DryadYarnBridge.vcxproj.filters | 56 + DryadYarnBridge/YarnAppMasterManaged.cpp | 128 + DryadYarnBridge/YarnAppMasterManaged.h | 54 + DryadYarnBridge/YarnAppMasterNative.cpp | 317 + DryadYarnBridge/YarnAppMasterNative.h | 85 + DryadYarnBridge/YarnDryadBridge.h | 36 + DryadYarnBridge/dllmain.cpp | 42 + DryadYarnBridge/stdafx.h | 33 + DryadYarnBridge/yarndryadbridge.cpp | 60 + GraphManager/GraphManager.vcxproj | 296 + GraphManager/GraphManager.vcxproj.filters | 348 + GraphManager/filesystem/DrFileSystems.h | 32 + GraphManager/filesystem/DrHdfsClient.cpp | 468 + 
GraphManager/filesystem/DrHdfsClient.h | 98 + GraphManager/filesystem/DrPartitionFile.cpp | 518 + GraphManager/filesystem/DrPartitionFile.h | 73 + GraphManager/gang/DrGangHeaders.h | 28 + GraphManager/gang/DrMetaData.cpp | 150 + GraphManager/gang/DrMetaData.h | 41 + GraphManager/gang/DrMetaDataTag.cpp | 129 + GraphManager/gang/DrMetaDataTag.h | 201 + GraphManager/gang/DrProperties.h | 105 + GraphManager/gang/DrProperty.cpp | 556 + GraphManager/gang/DrProperty.h | 132 + GraphManager/graph/DrDefaultParameters.h | 30 + GraphManager/graph/DrFileSystem.cpp | 206 + GraphManager/graph/DrFileSystem.h | 114 + GraphManager/graph/DrGraphExecutor.cpp | 91 + GraphManager/graph/DrGraphExecutor.h | 41 + GraphManager/graph/DrGraphHeaders.h | 27 + GraphManager/graph/DrGraphParameters.cpp | 76 + GraphManager/jobmanager/DrHeaders.h | 31 + GraphManager/jobmanager/targetver.h | 28 + GraphManager/jobmanager/version.cpp | 23 + GraphManager/kernel/DrCluster.cpp | 338 + GraphManager/kernel/DrCluster.h | 148 + GraphManager/kernel/DrKernel.h | 29 + GraphManager/kernel/DrMessagePump.cpp | 674 + GraphManager/kernel/DrMessagePump.h | 317 + GraphManager/kernel/DrProcess.cpp | 507 + GraphManager/kernel/DrProcess.h | 274 + GraphManager/kernel/DrXCompute.h | 89 + GraphManager/kernel/drxcompute.cpp | 1107 ++ GraphManager/kernel/drxcomputeinternal.h | 164 + .../reporting/DrArtemisLegacyReporting.cpp | 350 + .../reporting/DrArtemisLegacyReporting.h | 41 + GraphManager/reporting/DrReporting.h | 25 + GraphManager/shared/DrArray.h | 170 + GraphManager/shared/DrArrayList.h | 469 + GraphManager/shared/DrAssert.h | 23 + GraphManager/shared/DrCritSec.h | 379 + GraphManager/shared/DrDictionary.h | 451 + GraphManager/shared/DrError.cpp | 72 + GraphManager/shared/DrError.h | 104 + GraphManager/shared/DrErrorInternal.h | 35 + GraphManager/shared/DrFileWriter.cpp | 208 + GraphManager/shared/DrFileWriter.h | 62 + GraphManager/shared/DrLogging.cpp | 763 ++ GraphManager/shared/DrLogging.h | 99 + 
GraphManager/shared/DrMultiMap.h | 88 + GraphManager/shared/DrRef.cpp | 126 + GraphManager/shared/DrRef.h | 544 + GraphManager/shared/DrSet.h | 220 + GraphManager/shared/DrShared.h | 56 + GraphManager/shared/DrSort.h | 46 + GraphManager/shared/DrString.cpp | 539 + GraphManager/shared/DrString.h | 161 + GraphManager/shared/DrStringUtil.cpp | 278 + GraphManager/shared/DrStringUtil.h | 32 + GraphManager/shared/DrTypes.h | 50 + .../stagemanager/DrDefaultManager.cpp | 973 ++ GraphManager/stagemanager/DrDefaultManager.h | 375 + .../DrDynamicAggregateManager.cpp | 1594 +++ .../stagemanager/DrDynamicAggregateManager.h | 267 + .../stagemanager/DrDynamicBroadcast.cpp | 242 + .../stagemanager/DrDynamicBroadcast.h | 62 + .../stagemanager/DrDynamicDistributor.cpp | 388 + .../stagemanager/DrDynamicDistributor.h | 124 + .../DrDynamicRangeDistributor.cpp | 156 + .../stagemanager/DrDynamicRangeDistributor.h | 79 + .../stagemanager/DrPipelineSplitManager.cpp | 339 + .../stagemanager/DrPipelineSplitManager.h | 84 + GraphManager/stagemanager/DrStageHeaders.h | 32 + .../stagemanager/DrStageStatistics.cpp | 652 + GraphManager/stagemanager/DrStageStatistics.h | 155 + GraphManager/vertex/DrClique.cpp | 180 + GraphManager/vertex/DrClique.h | 64 + GraphManager/vertex/DrCohort.cpp | 1001 ++ GraphManager/vertex/DrCohort.h | 170 + GraphManager/vertex/DrGraph.cpp | 463 + GraphManager/vertex/DrGraph.h | 125 + GraphManager/vertex/DrOutputGenerator.cpp | 349 + GraphManager/vertex/DrOutputGenerator.h | 157 + GraphManager/vertex/DrStageManager.h | 176 + GraphManager/vertex/DrVertex.cpp | 1744 +++ GraphManager/vertex/DrVertex.h | 436 + GraphManager/vertex/DrVertexCommand.cpp | 871 ++ GraphManager/vertex/DrVertexCommand.h | 197 + GraphManager/vertex/DrVertexHeaders.h | 35 + GraphManager/vertex/DrVertexRecord.cpp | 873 ++ GraphManager/vertex/DrVertexRecord.h | 248 + Hdfs/HdfsBridgeManaged/HdfsBridgeManaged.cpp | 319 + Hdfs/HdfsBridgeManaged/HdfsBridgeManaged.h | 92 + .../HdfsBridgeManaged.vcxproj | 155 
+ Hdfs/HdfsBridgeNative/HdfsBridgeNative.cpp | 1105 ++ Hdfs/HdfsBridgeNative/HdfsBridgeNative.h | 183 + .../HdfsBridgeNative/HdfsBridgeNative.vcxproj | 157 + Java/DryadAppMaster.java | 478 + Java/DryadLinqYarnApp.java | 284 + Java/HdfsBridge.java | 604 + Java/build.bat | 14 + LinqToDryad/Attributes.cs | 405 + LinqToDryad/BitVector.cs | 114 + LinqToDryad/CodeGenHelper.cs | 152 + LinqToDryad/Constants.cs | 163 + LinqToDryad/DataPath.cs | 239 + LinqToDryad/DataProvider.cs | 234 + LinqToDryad/DataSetInfo.cs | 800 ++ LinqToDryad/DryadBinaryReader.cs | 742 ++ LinqToDryad/DryadBinaryWriter.cs | 646 + LinqToDryad/DryadCodeGen.cs | 2376 ++++ LinqToDryad/DryadFactory.cs | 455 + LinqToDryad/DryadLinqCollection.cs | 1470 +++ LinqToDryad/DryadLinqDecomposition.cs | 917 ++ LinqToDryad/DryadLinqException.cs | 96 + LinqToDryad/DryadLinqExpression.cs | 1774 +++ LinqToDryad/DryadLinqExtension.cs | 247 + LinqToDryad/DryadLinqFaultCodes.cs | 335 + LinqToDryad/DryadLinqFileStream.cs | 249 + LinqToDryad/DryadLinqGlobals.cs | 205 + LinqToDryad/DryadLinqHelper.cs | 416 + LinqToDryad/DryadLinqIEnumerable.cs | 1346 ++ LinqToDryad/DryadLinqIQueryable.cs | 3528 +++++ LinqToDryad/DryadLinqJobSubmission.cs | 505 + LinqToDryad/DryadLinqLog.cs | 120 + LinqToDryad/DryadLinqMetaData.cs | 210 + LinqToDryad/DryadLinqNative.cs | 148 + LinqToDryad/DryadLinqObjectStore.cs | 171 + LinqToDryad/DryadLinqQuery.cs | 889 ++ LinqToDryad/DryadLinqSampler.cs | 249 + LinqToDryad/DryadLinqSerialization.cs | 755 ++ LinqToDryad/DryadLinqStream.cs | 151 + LinqToDryad/DryadLinqUtil.cs | 636 + LinqToDryad/DryadLinqVertex.cs | 10607 ++++++++++++++++ LinqToDryad/DryadLinqVertexParams.cs | 90 + LinqToDryad/DryadQueryDoc.cs | 54 + LinqToDryad/DryadQueryExplain.cs | 785 ++ LinqToDryad/DryadQueryGen.cs | 4658 +++++++ LinqToDryad/DryadQueryNode.cs | 4567 +++++++ LinqToDryad/DryadRecordReader.cs | 637 + LinqToDryad/DryadRecordWriter.cs | 564 + LinqToDryad/DryadRuntime.cs | 110 + LinqToDryad/DryadTextReader.cs | 233 + 
LinqToDryad/DryadTextWriter.cs | 247 + LinqToDryad/DryadVertexEnv.cs | 315 + LinqToDryad/DryadVertexReader.cs | 243 + LinqToDryad/DryadVertexWriter.cs | 180 + LinqToDryad/DscClientHelper.cs | 548 + LinqToDryad/DscStubs.cs | 178 + LinqToDryad/DynamicManager.cs | 205 + LinqToDryad/ExpressionMatcher.cs | 458 + LinqToDryad/ExpressionSimplifier.cs | 70 + LinqToDryad/ExpressionVisitor.cs | 756 ++ LinqToDryad/ForkTuple.cs | 182 + LinqToDryad/Hash64.cs | 355 + LinqToDryad/HpcJobSubmission.cs | 336 + LinqToDryad/HpcLinqCache.cs | 205 + LinqToDryad/HpcLinqConfiguration.cs | 567 + LinqToDryad/HpcLinqContext.cs | 304 + LinqToDryad/HpcLinqStringDictionary.cs | 170 + LinqToDryad/HpcLinqStringList.cs | 141 + LinqToDryad/IAssociative.cs | 49 + LinqToDryad/IDecomposable.cs | 70 + LinqToDryad/IDryadLinqJobSubmission.cs | 70 + LinqToDryad/LineRecord.cs | 122 + LinqToDryad/LinqToDryad.csproj | 161 + LinqToDryad/MultiBlockStream.cs | 229 + LinqToDryad/MultiEnumerable.cs | 169 + LinqToDryad/MultiQueryable.cs | 330 + LinqToDryad/NativeBlockStream.cs | 239 + LinqToDryad/QueryTraceLevel.cs | 39 + LinqToDryad/SR.Designer.cs | 1953 +++ LinqToDryad/SR.resx | 768 ++ LinqToDryad/SimpleRewriter.cs | 287 + LinqToDryad/TypeSystem.cs | 1393 ++ LinqToDryad/VertexCodeGen.cs | 356 + LinqToDryad/WebHdfsClient.cs | 93 + LinqToDryad/YarnJobSubmission.cs | 516 + LinqToDryad/YarnScheduler.cs | 48 + LinqToDryad/sr.txt | 253 + README.txt | 47 + .../DryadLinqApplication.cs | 410 + linqtodryadjm_managed_yarn/GraphBuilder.cs | 748 ++ .../LinqToDryadException.cs | 41 + linqtodryadjm_managed_yarn/LinqToDryadJM.cs | 297 + linqtodryadjm_managed_yarn/Program.cs | 191 + .../Properties/AssemblyInfo.cs | 54 + linqtodryadjm_managed_yarn/Query.cs | 162 + linqtodryadjm_managed_yarn/QueryParser.cs | 428 + linqtodryadjm_managed_yarn/app.config | 3 + .../linqtodryadjm_managed.csproj | 127 + xcompute_managed/Dispatcher.cs | 772 ++ xcompute_managed/DispatcherPool.cs | 150 + xcompute_managed/JobStatus.cs | 128 + 
...osoft.Research.Dryad.ClusterAdapter.csproj | 169 + xcompute_managed/ProcessTable.cs | 109 + xcompute_managed/RequestPool.cs | 97 + xcompute_managed/ScheduleProcessRequest.cs | 110 + xcompute_managed/VertexCallbackService.cs | 80 + xcompute_managed/VertexCallbackServiceHost.cs | 135 + xcompute_managed/VertexScheduler.cs | 1251 ++ xcompute_managed/VertexServiceClient.cs | 201 + xcompute_managed/XComputeProcess.cs | 554 + .../YarnQueryNativeClusterAdapter.vcxproj | 211 + xcompute_native/async.cpp | 133 + xcompute_native/context.cpp | 147 + xcompute_native/dllmain.cpp | 40 + xcompute_native/file.cpp | 257 + xcompute_native/inc/XCompute.h | 1741 +++ xcompute_native/inc/XComputeTypes.h | 1311 ++ xcompute_native/inc/auto_any.h | 331 + xcompute_native/inc/scoped_any.h | 273 + xcompute_native/inc/smart_any_fwd.h | 1079 ++ xcompute_native/locality.cpp | 116 + xcompute_native/node.cpp | 435 + xcompute_native/path.cpp | 199 + xcompute_native/process.cpp | 710 ++ xcompute_native/property.cpp | 423 + xcompute_native/scheduler.cpp | 258 + xcompute_native/session.cpp | 152 + xcompute_native/status.cpp | 111 + xcompute_native/stdafx.cpp | 28 + xcompute_native/stdafx.h | 47 + xcompute_native/targetver.h | 44 + xcompute_native/xcimpl.h | 127 + xcompute_native/xcompute.cpp | 111 + 468 files changed, 183623 insertions(+) create mode 100644 CommonCode/AzureUtils.cs create mode 100644 CommonCode/Constants.cs create mode 100644 CommonCode/DiscLocalMonitor.cs create mode 100644 CommonCode/DryadTracing.cs create mode 100644 CommonCode/DryadVertexServiceAuthorizationManager.cs create mode 100644 CommonCode/ExecutionHelper.cs create mode 100644 CommonCode/IDryadVertexCallback.cs create mode 100644 CommonCode/IDryadVertexService.cs create mode 100644 CommonCode/NativeMethods.cs create mode 100644 CommonCode/NetShareWrapper.cs create mode 100644 CommonCode/ProcessPathHelper.cs create mode 100644 CommonCode/ProcessState.cs create mode 100644 CommonCode/QueryUtility.cs create mode 100644 
CommonCode/RetryFramework.cs create mode 100644 CommonCode/SchedulerHelper.cs create mode 100644 CommonCode/SoftAffinity.cs create mode 100644 Dryad.sln create mode 100644 Dryad.v11.suo create mode 100644 DryadVertex/VertexHost/system/channel/channel.vcxproj create mode 100644 DryadVertex/VertexHost/system/channel/include/channelbuffer.h create mode 100644 DryadVertex/VertexHost/system/channel/include/channelinterface.h create mode 100644 DryadVertex/VertexHost/system/channel/include/channelitem.h create mode 100644 DryadVertex/VertexHost/system/channel/include/channelmarshaler.h create mode 100644 DryadVertex/VertexHost/system/channel/include/channelmemorybuffers.h create mode 100644 DryadVertex/VertexHost/system/channel/include/channelparser.h create mode 100644 DryadVertex/VertexHost/system/channel/include/concreterchannel.h create mode 100644 DryadVertex/VertexHost/system/channel/include/recordarray.h create mode 100644 DryadVertex/VertexHost/system/channel/include/recorditem.h create mode 100644 DryadVertex/VertexHost/system/channel/include/recordparser.h create mode 100644 DryadVertex/VertexHost/system/channel/src/channelbuffer.cpp create mode 100644 DryadVertex/VertexHost/system/channel/src/channelbufferhdfs.cpp create mode 100644 DryadVertex/VertexHost/system/channel/src/channelbufferhdfs.h create mode 100644 DryadVertex/VertexHost/system/channel/src/channelbuffernativereader.cpp create mode 100644 DryadVertex/VertexHost/system/channel/src/channelbuffernativereader.h create mode 100644 DryadVertex/VertexHost/system/channel/src/channelbuffernativewriter.cpp create mode 100644 DryadVertex/VertexHost/system/channel/src/channelbuffernativewriter.h create mode 100644 DryadVertex/VertexHost/system/channel/src/channelbufferqueue.cpp create mode 100644 DryadVertex/VertexHost/system/channel/src/channelbufferqueue.h create mode 100644 DryadVertex/VertexHost/system/channel/src/channelfifo.cpp create mode 100644 DryadVertex/VertexHost/system/channel/src/channelfifo.h 
create mode 100644 DryadVertex/VertexHost/system/channel/src/channelhelpers.cpp create mode 100644 DryadVertex/VertexHost/system/channel/src/channelhelpers.h create mode 100644 DryadVertex/VertexHost/system/channel/src/channelitem.cpp create mode 100644 DryadVertex/VertexHost/system/channel/src/channelmarshaler.cpp create mode 100644 DryadVertex/VertexHost/system/channel/src/channelparser.cpp create mode 100644 DryadVertex/VertexHost/system/channel/src/channelreader.cpp create mode 100644 DryadVertex/VertexHost/system/channel/src/channelreader.h create mode 100644 DryadVertex/VertexHost/system/channel/src/channelwriter.cpp create mode 100644 DryadVertex/VertexHost/system/channel/src/channelwriter.h create mode 100644 DryadVertex/VertexHost/system/channel/src/concreterchannel.cpp create mode 100644 DryadVertex/VertexHost/system/channel/src/concreterchannelhelpers.h create mode 100644 DryadVertex/VertexHost/system/channel/src/memorybuffers.cpp create mode 100644 DryadVertex/VertexHost/system/channel/src/recorditem.cpp create mode 100644 DryadVertex/VertexHost/system/classlib/classlib.vcxproj create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrBList.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrCommon.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrCriticalSection.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrError.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrErrorDef.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrExecution.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrExitCodes.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrExitCodesDef.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrFPrint.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrFPrint_polynomials.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrFunctions.h create 
mode 100644 DryadVertex/VertexHost/system/classlib/include/DrGuid.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrHash.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrHeap.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrList.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrLogging.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrMemory.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrMemoryStream.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrNodeAddress.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrProperties.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrPropertiesDef.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrPropertyDumper.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrPropertyType.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrRefCounter.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrStringUtil.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrTags.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrTagsDef.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrThread.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DrTypes.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/Dryad.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DryadTags.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/DryadTagsDef.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/Interlocked.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/LogIds.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/LogIdsCustomized.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/LogTagIds.h create mode 100644 
DryadVertex/VertexHost/system/classlib/include/MSMutex.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/PropertyIds.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/RefCount.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/XCompute.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/XComputeTypes.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/basic_types.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/fingerprint.h create mode 100644 DryadVertex/VertexHost/system/classlib/include/ms_fprint.h create mode 100644 DryadVertex/VertexHost/system/classlib/src/DrCriticalSection.cpp create mode 100644 DryadVertex/VertexHost/system/classlib/src/DrError.cpp create mode 100644 DryadVertex/VertexHost/system/classlib/src/DrExecution.cpp create mode 100644 DryadVertex/VertexHost/system/classlib/src/DrExitCodes.cpp create mode 100644 DryadVertex/VertexHost/system/classlib/src/DrFPrint.cpp create mode 100644 DryadVertex/VertexHost/system/classlib/src/DrFunctions.cpp create mode 100644 DryadVertex/VertexHost/system/classlib/src/DrGuid.cpp create mode 100644 DryadVertex/VertexHost/system/classlib/src/DrHash.cpp create mode 100644 DryadVertex/VertexHost/system/classlib/src/DrHeap.cpp create mode 100644 DryadVertex/VertexHost/system/classlib/src/DrLogging.cpp create mode 100644 DryadVertex/VertexHost/system/classlib/src/DrMemory.cpp create mode 100644 DryadVertex/VertexHost/system/classlib/src/DrMemoryStream.cpp create mode 100644 DryadVertex/VertexHost/system/classlib/src/DrNodeAddress.cpp create mode 100644 DryadVertex/VertexHost/system/classlib/src/DrRefCounter.cpp create mode 100644 DryadVertex/VertexHost/system/classlib/src/DrStringUtil.cpp create mode 100644 DryadVertex/VertexHost/system/classlib/src/DrThread.cpp create mode 100644 DryadVertex/VertexHost/system/classlib/src/fingerprint.cpp create mode 100644 DryadVertex/VertexHost/system/classlib/src/ms_fprint.cpp 
create mode 100644 DryadVertex/VertexHost/system/common/common.vcxproj create mode 100644 DryadVertex/VertexHost/system/common/include/CsEnhancedTimer.h create mode 100644 DryadVertex/VertexHost/system/common/include/DObjPool.h create mode 100644 DryadVertex/VertexHost/system/common/include/cosmospropertyblock.h create mode 100644 DryadVertex/VertexHost/system/common/include/cosmosstreampropertyupdater.h create mode 100644 DryadVertex/VertexHost/system/common/include/dryadcosmosresources.h create mode 100644 DryadVertex/VertexHost/system/common/include/dryaderror.h create mode 100644 DryadVertex/VertexHost/system/common/include/dryaderrordef.h create mode 100644 DryadVertex/VertexHost/system/common/include/dryadeventcache.h create mode 100644 DryadVertex/VertexHost/system/common/include/dryadlisthelper.h create mode 100644 DryadVertex/VertexHost/system/common/include/dryadmetadata.h create mode 100644 DryadVertex/VertexHost/system/common/include/dryadmetadatatag.h create mode 100644 DryadVertex/VertexHost/system/common/include/dryadmetadatatagtypes.h create mode 100644 DryadVertex/VertexHost/system/common/include/dryadnativeport.h create mode 100644 DryadVertex/VertexHost/system/common/include/dryadopaqueresources.h create mode 100644 DryadVertex/VertexHost/system/common/include/dryadproperties.h create mode 100644 DryadVertex/VertexHost/system/common/include/dryadpropertiesdef.h create mode 100644 DryadVertex/VertexHost/system/common/include/dryadpropertydumper.h create mode 100644 DryadVertex/VertexHost/system/common/include/dryadpropertytype.h create mode 100644 DryadVertex/VertexHost/system/common/include/dryadstandaloneini.h create mode 100644 DryadVertex/VertexHost/system/common/include/dryadtags.h create mode 100644 DryadVertex/VertexHost/system/common/include/dryadtagsdef.h create mode 100644 DryadVertex/VertexHost/system/common/include/dryadxcomputeresources.h create mode 100644 DryadVertex/VertexHost/system/common/include/dvertexcommand.h create mode 
100644 DryadVertex/VertexHost/system/common/include/errorreporter.h create mode 100644 DryadVertex/VertexHost/system/common/include/orderedsendlatch.h create mode 100644 DryadVertex/VertexHost/system/common/include/portmemorybuffers.h create mode 100644 DryadVertex/VertexHost/system/common/include/workqueue.h create mode 100644 DryadVertex/VertexHost/system/common/include/xcomputepropertyblock.h create mode 100644 DryadVertex/VertexHost/system/common/include/yarnpropertyblock.h create mode 100644 DryadVertex/VertexHost/system/common/src/DObjPool.cpp create mode 100644 DryadVertex/VertexHost/system/common/src/dryadeventcache.cpp create mode 100644 DryadVertex/VertexHost/system/common/src/dryadmetadata.cpp create mode 100644 DryadVertex/VertexHost/system/common/src/dryadmetadatatag.cpp create mode 100644 DryadVertex/VertexHost/system/common/src/dryadmetadatatagtypes.cpp create mode 100644 DryadVertex/VertexHost/system/common/src/dryadnativeport.cpp create mode 100644 DryadVertex/VertexHost/system/common/src/dryadopaqueresources.cpp create mode 100644 DryadVertex/VertexHost/system/common/src/dryadpropertydumper.cpp create mode 100644 DryadVertex/VertexHost/system/common/src/dryadstandaloneini.cpp create mode 100644 DryadVertex/VertexHost/system/common/src/dryadxcomputeresources.cpp create mode 100644 DryadVertex/VertexHost/system/common/src/dvertexcommand.cpp create mode 100644 DryadVertex/VertexHost/system/common/src/errorreporter.cpp create mode 100644 DryadVertex/VertexHost/system/common/src/portmemorybuffers.cpp create mode 100644 DryadVertex/VertexHost/system/common/src/workqueue.cpp create mode 100644 DryadVertex/VertexHost/system/common/src/xcomputepropertyblock.cpp create mode 100644 DryadVertex/VertexHost/system/common/src/yarnpropertyblock.cpp create mode 100644 DryadVertex/VertexHost/system/dprocess/dprocess.vcxproj create mode 100644 DryadVertex/VertexHost/system/dprocess/include/dryadvertex.h create mode 100644 
DryadVertex/VertexHost/system/dprocess/include/dvertexcosmosenvironment.h create mode 100644 DryadVertex/VertexHost/system/dprocess/include/dvertexenvironment.h create mode 100644 DryadVertex/VertexHost/system/dprocess/include/dvertexmain.h create mode 100644 DryadVertex/VertexHost/system/dprocess/include/dvertexxcomputeenvironment.h create mode 100644 DryadVertex/VertexHost/system/dprocess/include/vertexfactory.h create mode 100644 DryadVertex/VertexHost/system/dprocess/src/dryadvertex.cpp create mode 100644 DryadVertex/VertexHost/system/dprocess/src/dvertexcmdlinecontrol.cpp create mode 100644 DryadVertex/VertexHost/system/dprocess/src/dvertexcmdlinecontrol.h create mode 100644 DryadVertex/VertexHost/system/dprocess/src/dvertexenvironment.cpp create mode 100644 DryadVertex/VertexHost/system/dprocess/src/dvertexmain.cpp create mode 100644 DryadVertex/VertexHost/system/dprocess/src/dvertexpncontrol.cpp create mode 100644 DryadVertex/VertexHost/system/dprocess/src/dvertexpncontrol.h create mode 100644 DryadVertex/VertexHost/system/dprocess/src/dvertexxcomputeenvironment.cpp create mode 100644 DryadVertex/VertexHost/system/dprocess/src/dvertexxcomputepncontrol.cpp create mode 100644 DryadVertex/VertexHost/system/dprocess/src/dvertexxcomputepncontrol.h create mode 100644 DryadVertex/VertexHost/system/dprocess/src/dvertexyarnpncontrol.h create mode 100644 DryadVertex/VertexHost/system/dprocess/src/subgraphvertex.cpp create mode 100644 DryadVertex/VertexHost/system/dprocess/src/subgraphvertex.h create mode 100644 DryadVertex/VertexHost/system/dprocess/src/vertexfactory.cpp create mode 100644 DryadVertex/VertexHost/vertex/WrapperNativeInfo/FifoChannel.cpp create mode 100644 DryadVertex/VertexHost/vertex/WrapperNativeInfo/FifoInputChannel.cpp create mode 100644 DryadVertex/VertexHost/vertex/WrapperNativeInfo/FifoOutputChannel.cpp create mode 100644 DryadVertex/VertexHost/vertex/WrapperNativeInfo/GzipCompressionChannelTransform.cpp create mode 100644 
DryadVertex/VertexHost/vertex/WrapperNativeInfo/GzipDecompressionChannelTransform.cpp create mode 100644 DryadVertex/VertexHost/vertex/WrapperNativeInfo/InputChannel.cpp create mode 100644 DryadVertex/VertexHost/vertex/WrapperNativeInfo/NullChannelTransform.cpp create mode 100644 DryadVertex/VertexHost/vertex/WrapperNativeInfo/OutputChannel.cpp create mode 100644 DryadVertex/VertexHost/vertex/WrapperNativeInfo/WrapperNativeInfo.cpp create mode 100644 DryadVertex/VertexHost/vertex/WrapperNativeInfo/WrapperNativeInfo.vcxproj create mode 100644 DryadVertex/VertexHost/vertex/WrapperNativeInfo/WrapperNativeInfo.vcxproj.filters create mode 100644 DryadVertex/VertexHost/vertex/WrapperNativeInfo/stdafx.h create mode 100644 DryadVertex/VertexHost/vertex/WrapperNativeInfoDll/DryadLINQNativeChannels.def create mode 100644 DryadVertex/VertexHost/vertex/WrapperNativeInfoDll/WrapperNativeInfoDll.vcxproj create mode 100644 DryadVertex/VertexHost/vertex/WrapperNativeInfoDll/WrapperNativeInfoDll.vcxproj.filters create mode 100644 DryadVertex/VertexHost/vertex/WrapperNativeInfoDll/stdafx.h create mode 100644 DryadVertex/VertexHost/vertex/WrapperNativeInfoDll/wrappernativeinfostubs.cpp create mode 100644 DryadVertex/VertexHost/vertex/include/ChannelTransform.h create mode 100644 DryadVertex/VertexHost/vertex/include/CompressionVertex.h create mode 100644 DryadVertex/VertexHost/vertex/include/DataBlockItem.h create mode 100644 DryadVertex/VertexHost/vertex/include/FifoChannel.h create mode 100644 DryadVertex/VertexHost/vertex/include/FifoInputChannel.h create mode 100644 DryadVertex/VertexHost/vertex/include/FifoOutputChannel.h create mode 100644 DryadVertex/VertexHost/vertex/include/GzipCompressionChannelTransform.h create mode 100644 DryadVertex/VertexHost/vertex/include/GzipDecompressionChannelTransform.h create mode 100644 DryadVertex/VertexHost/vertex/include/InputChannel.h create mode 100644 DryadVertex/VertexHost/vertex/include/ManagedWrapper.h create mode 100644 
DryadVertex/VertexHost/vertex/include/NullChannelTransform.h create mode 100644 DryadVertex/VertexHost/vertex/include/OutputChannel.h create mode 100644 DryadVertex/VertexHost/vertex/include/wrappernativeinfo.h create mode 100644 DryadVertex/VertexHost/vertex/managedwrappervertex/DataBlockItem.cpp create mode 100644 DryadVertex/VertexHost/vertex/managedwrappervertex/ManagedWrapperVertex.cpp create mode 100644 DryadVertex/VertexHost/vertex/managedwrappervertex/ManagedWrapperVertex.vcxproj create mode 100644 DryadVertex/VertexHost/vertex/managedwrappervertex/ManagedWrapperVertex.vcxproj.filters create mode 100644 DryadVertex/VertexHost/vertex/managedwrappervertex/stdafx.h create mode 100644 DryadVertex/VertexHost/vertex/vertexHost/VertexHost.vcxproj create mode 100644 DryadVertex/VertexHost/vertex/vertexHost/VertexHost.vcxproj.filters create mode 100644 DryadVertex/VertexHost/vertex/vertexHost/vertexHost.cpp create mode 100644 DryadVertex/service/DryadVertexService.csproj create mode 100644 DryadVertex/service/ReplyDispatcher.cs create mode 100644 DryadVertex/service/VertexCallbackServiceClient.cs create mode 100644 DryadVertex/service/VertexProcess.cs create mode 100644 DryadVertex/service/VertexService.cs create mode 100644 DryadVertex/service/app.config create mode 100644 DryadVertex/service/program.cs create mode 100644 DryadYarnBridge/DryadYarnBridge.cpp create mode 100644 DryadYarnBridge/DryadYarnBridge.def create mode 100644 DryadYarnBridge/DryadYarnBridge.h create mode 100644 DryadYarnBridge/DryadYarnBridge.vcxproj create mode 100644 DryadYarnBridge/DryadYarnBridge.vcxproj.filters create mode 100644 DryadYarnBridge/YarnAppMasterManaged.cpp create mode 100644 DryadYarnBridge/YarnAppMasterManaged.h create mode 100644 DryadYarnBridge/YarnAppMasterNative.cpp create mode 100644 DryadYarnBridge/YarnAppMasterNative.h create mode 100644 DryadYarnBridge/YarnDryadBridge.h create mode 100644 DryadYarnBridge/dllmain.cpp create mode 100644 DryadYarnBridge/stdafx.h create 
mode 100644 DryadYarnBridge/yarndryadbridge.cpp create mode 100644 GraphManager/GraphManager.vcxproj create mode 100644 GraphManager/GraphManager.vcxproj.filters create mode 100644 GraphManager/filesystem/DrFileSystems.h create mode 100644 GraphManager/filesystem/DrHdfsClient.cpp create mode 100644 GraphManager/filesystem/DrHdfsClient.h create mode 100644 GraphManager/filesystem/DrPartitionFile.cpp create mode 100644 GraphManager/filesystem/DrPartitionFile.h create mode 100644 GraphManager/gang/DrGangHeaders.h create mode 100644 GraphManager/gang/DrMetaData.cpp create mode 100644 GraphManager/gang/DrMetaData.h create mode 100644 GraphManager/gang/DrMetaDataTag.cpp create mode 100644 GraphManager/gang/DrMetaDataTag.h create mode 100644 GraphManager/gang/DrProperties.h create mode 100644 GraphManager/gang/DrProperty.cpp create mode 100644 GraphManager/gang/DrProperty.h create mode 100644 GraphManager/graph/DrDefaultParameters.h create mode 100644 GraphManager/graph/DrFileSystem.cpp create mode 100644 GraphManager/graph/DrFileSystem.h create mode 100644 GraphManager/graph/DrGraphExecutor.cpp create mode 100644 GraphManager/graph/DrGraphExecutor.h create mode 100644 GraphManager/graph/DrGraphHeaders.h create mode 100644 GraphManager/graph/DrGraphParameters.cpp create mode 100644 GraphManager/jobmanager/DrHeaders.h create mode 100644 GraphManager/jobmanager/targetver.h create mode 100644 GraphManager/jobmanager/version.cpp create mode 100644 GraphManager/kernel/DrCluster.cpp create mode 100644 GraphManager/kernel/DrCluster.h create mode 100644 GraphManager/kernel/DrKernel.h create mode 100644 GraphManager/kernel/DrMessagePump.cpp create mode 100644 GraphManager/kernel/DrMessagePump.h create mode 100644 GraphManager/kernel/DrProcess.cpp create mode 100644 GraphManager/kernel/DrProcess.h create mode 100644 GraphManager/kernel/DrXCompute.h create mode 100644 GraphManager/kernel/drxcompute.cpp create mode 100644 GraphManager/kernel/drxcomputeinternal.h create mode 100644 
GraphManager/reporting/DrArtemisLegacyReporting.cpp create mode 100644 GraphManager/reporting/DrArtemisLegacyReporting.h create mode 100644 GraphManager/reporting/DrReporting.h create mode 100644 GraphManager/shared/DrArray.h create mode 100644 GraphManager/shared/DrArrayList.h create mode 100644 GraphManager/shared/DrAssert.h create mode 100644 GraphManager/shared/DrCritSec.h create mode 100644 GraphManager/shared/DrDictionary.h create mode 100644 GraphManager/shared/DrError.cpp create mode 100644 GraphManager/shared/DrError.h create mode 100644 GraphManager/shared/DrErrorInternal.h create mode 100644 GraphManager/shared/DrFileWriter.cpp create mode 100644 GraphManager/shared/DrFileWriter.h create mode 100644 GraphManager/shared/DrLogging.cpp create mode 100644 GraphManager/shared/DrLogging.h create mode 100644 GraphManager/shared/DrMultiMap.h create mode 100644 GraphManager/shared/DrRef.cpp create mode 100644 GraphManager/shared/DrRef.h create mode 100644 GraphManager/shared/DrSet.h create mode 100644 GraphManager/shared/DrShared.h create mode 100644 GraphManager/shared/DrSort.h create mode 100644 GraphManager/shared/DrString.cpp create mode 100644 GraphManager/shared/DrString.h create mode 100644 GraphManager/shared/DrStringUtil.cpp create mode 100644 GraphManager/shared/DrStringUtil.h create mode 100644 GraphManager/shared/DrTypes.h create mode 100644 GraphManager/stagemanager/DrDefaultManager.cpp create mode 100644 GraphManager/stagemanager/DrDefaultManager.h create mode 100644 GraphManager/stagemanager/DrDynamicAggregateManager.cpp create mode 100644 GraphManager/stagemanager/DrDynamicAggregateManager.h create mode 100644 GraphManager/stagemanager/DrDynamicBroadcast.cpp create mode 100644 GraphManager/stagemanager/DrDynamicBroadcast.h create mode 100644 GraphManager/stagemanager/DrDynamicDistributor.cpp create mode 100644 GraphManager/stagemanager/DrDynamicDistributor.h create mode 100644 GraphManager/stagemanager/DrDynamicRangeDistributor.cpp create mode 
100644 GraphManager/stagemanager/DrDynamicRangeDistributor.h create mode 100644 GraphManager/stagemanager/DrPipelineSplitManager.cpp create mode 100644 GraphManager/stagemanager/DrPipelineSplitManager.h create mode 100644 GraphManager/stagemanager/DrStageHeaders.h create mode 100644 GraphManager/stagemanager/DrStageStatistics.cpp create mode 100644 GraphManager/stagemanager/DrStageStatistics.h create mode 100644 GraphManager/vertex/DrClique.cpp create mode 100644 GraphManager/vertex/DrClique.h create mode 100644 GraphManager/vertex/DrCohort.cpp create mode 100644 GraphManager/vertex/DrCohort.h create mode 100644 GraphManager/vertex/DrGraph.cpp create mode 100644 GraphManager/vertex/DrGraph.h create mode 100644 GraphManager/vertex/DrOutputGenerator.cpp create mode 100644 GraphManager/vertex/DrOutputGenerator.h create mode 100644 GraphManager/vertex/DrStageManager.h create mode 100644 GraphManager/vertex/DrVertex.cpp create mode 100644 GraphManager/vertex/DrVertex.h create mode 100644 GraphManager/vertex/DrVertexCommand.cpp create mode 100644 GraphManager/vertex/DrVertexCommand.h create mode 100644 GraphManager/vertex/DrVertexHeaders.h create mode 100644 GraphManager/vertex/DrVertexRecord.cpp create mode 100644 GraphManager/vertex/DrVertexRecord.h create mode 100644 Hdfs/HdfsBridgeManaged/HdfsBridgeManaged.cpp create mode 100644 Hdfs/HdfsBridgeManaged/HdfsBridgeManaged.h create mode 100644 Hdfs/HdfsBridgeManaged/HdfsBridgeManaged.vcxproj create mode 100644 Hdfs/HdfsBridgeNative/HdfsBridgeNative.cpp create mode 100644 Hdfs/HdfsBridgeNative/HdfsBridgeNative.h create mode 100644 Hdfs/HdfsBridgeNative/HdfsBridgeNative.vcxproj create mode 100644 Java/DryadAppMaster.java create mode 100644 Java/DryadLinqYarnApp.java create mode 100644 Java/HdfsBridge.java create mode 100644 Java/build.bat create mode 100644 LinqToDryad/Attributes.cs create mode 100644 LinqToDryad/BitVector.cs create mode 100644 LinqToDryad/CodeGenHelper.cs create mode 100644 LinqToDryad/Constants.cs create 
mode 100644 LinqToDryad/DataPath.cs create mode 100644 LinqToDryad/DataProvider.cs create mode 100644 LinqToDryad/DataSetInfo.cs create mode 100644 LinqToDryad/DryadBinaryReader.cs create mode 100644 LinqToDryad/DryadBinaryWriter.cs create mode 100644 LinqToDryad/DryadCodeGen.cs create mode 100644 LinqToDryad/DryadFactory.cs create mode 100644 LinqToDryad/DryadLinqCollection.cs create mode 100644 LinqToDryad/DryadLinqDecomposition.cs create mode 100644 LinqToDryad/DryadLinqException.cs create mode 100644 LinqToDryad/DryadLinqExpression.cs create mode 100644 LinqToDryad/DryadLinqExtension.cs create mode 100644 LinqToDryad/DryadLinqFaultCodes.cs create mode 100644 LinqToDryad/DryadLinqFileStream.cs create mode 100644 LinqToDryad/DryadLinqGlobals.cs create mode 100644 LinqToDryad/DryadLinqHelper.cs create mode 100644 LinqToDryad/DryadLinqIEnumerable.cs create mode 100644 LinqToDryad/DryadLinqIQueryable.cs create mode 100644 LinqToDryad/DryadLinqJobSubmission.cs create mode 100644 LinqToDryad/DryadLinqLog.cs create mode 100644 LinqToDryad/DryadLinqMetaData.cs create mode 100644 LinqToDryad/DryadLinqNative.cs create mode 100644 LinqToDryad/DryadLinqObjectStore.cs create mode 100644 LinqToDryad/DryadLinqQuery.cs create mode 100644 LinqToDryad/DryadLinqSampler.cs create mode 100644 LinqToDryad/DryadLinqSerialization.cs create mode 100644 LinqToDryad/DryadLinqStream.cs create mode 100644 LinqToDryad/DryadLinqUtil.cs create mode 100644 LinqToDryad/DryadLinqVertex.cs create mode 100644 LinqToDryad/DryadLinqVertexParams.cs create mode 100644 LinqToDryad/DryadQueryDoc.cs create mode 100644 LinqToDryad/DryadQueryExplain.cs create mode 100644 LinqToDryad/DryadQueryGen.cs create mode 100644 LinqToDryad/DryadQueryNode.cs create mode 100644 LinqToDryad/DryadRecordReader.cs create mode 100644 LinqToDryad/DryadRecordWriter.cs create mode 100644 LinqToDryad/DryadRuntime.cs create mode 100644 LinqToDryad/DryadTextReader.cs create mode 100644 LinqToDryad/DryadTextWriter.cs create mode 
100644 LinqToDryad/DryadVertexEnv.cs create mode 100644 LinqToDryad/DryadVertexReader.cs create mode 100644 LinqToDryad/DryadVertexWriter.cs create mode 100644 LinqToDryad/DscClientHelper.cs create mode 100644 LinqToDryad/DscStubs.cs create mode 100644 LinqToDryad/DynamicManager.cs create mode 100644 LinqToDryad/ExpressionMatcher.cs create mode 100644 LinqToDryad/ExpressionSimplifier.cs create mode 100644 LinqToDryad/ExpressionVisitor.cs create mode 100644 LinqToDryad/ForkTuple.cs create mode 100644 LinqToDryad/Hash64.cs create mode 100644 LinqToDryad/HpcJobSubmission.cs create mode 100644 LinqToDryad/HpcLinqCache.cs create mode 100644 LinqToDryad/HpcLinqConfiguration.cs create mode 100644 LinqToDryad/HpcLinqContext.cs create mode 100644 LinqToDryad/HpcLinqStringDictionary.cs create mode 100644 LinqToDryad/HpcLinqStringList.cs create mode 100644 LinqToDryad/IAssociative.cs create mode 100644 LinqToDryad/IDecomposable.cs create mode 100644 LinqToDryad/IDryadLinqJobSubmission.cs create mode 100644 LinqToDryad/LineRecord.cs create mode 100644 LinqToDryad/LinqToDryad.csproj create mode 100644 LinqToDryad/MultiBlockStream.cs create mode 100644 LinqToDryad/MultiEnumerable.cs create mode 100644 LinqToDryad/MultiQueryable.cs create mode 100644 LinqToDryad/NativeBlockStream.cs create mode 100644 LinqToDryad/QueryTraceLevel.cs create mode 100644 LinqToDryad/SR.Designer.cs create mode 100644 LinqToDryad/SR.resx create mode 100644 LinqToDryad/SimpleRewriter.cs create mode 100644 LinqToDryad/TypeSystem.cs create mode 100644 LinqToDryad/VertexCodeGen.cs create mode 100644 LinqToDryad/WebHdfsClient.cs create mode 100644 LinqToDryad/YarnJobSubmission.cs create mode 100644 LinqToDryad/YarnScheduler.cs create mode 100644 LinqToDryad/sr.txt create mode 100644 README.txt create mode 100644 linqtodryadjm_managed_yarn/DryadLinqApplication.cs create mode 100644 linqtodryadjm_managed_yarn/GraphBuilder.cs create mode 100644 linqtodryadjm_managed_yarn/LinqToDryadException.cs create mode 
100644 linqtodryadjm_managed_yarn/LinqToDryadJM.cs create mode 100644 linqtodryadjm_managed_yarn/Program.cs create mode 100644 linqtodryadjm_managed_yarn/Properties/AssemblyInfo.cs create mode 100644 linqtodryadjm_managed_yarn/Query.cs create mode 100644 linqtodryadjm_managed_yarn/QueryParser.cs create mode 100644 linqtodryadjm_managed_yarn/app.config create mode 100644 linqtodryadjm_managed_yarn/linqtodryadjm_managed.csproj create mode 100644 xcompute_managed/Dispatcher.cs create mode 100644 xcompute_managed/DispatcherPool.cs create mode 100644 xcompute_managed/JobStatus.cs create mode 100644 xcompute_managed/Microsoft.Research.Dryad.ClusterAdapter.csproj create mode 100644 xcompute_managed/ProcessTable.cs create mode 100644 xcompute_managed/RequestPool.cs create mode 100644 xcompute_managed/ScheduleProcessRequest.cs create mode 100644 xcompute_managed/VertexCallbackService.cs create mode 100644 xcompute_managed/VertexCallbackServiceHost.cs create mode 100644 xcompute_managed/VertexScheduler.cs create mode 100644 xcompute_managed/VertexServiceClient.cs create mode 100644 xcompute_managed/XComputeProcess.cs create mode 100644 xcompute_native/YarnQueryNativeClusterAdapter.vcxproj create mode 100644 xcompute_native/async.cpp create mode 100644 xcompute_native/context.cpp create mode 100644 xcompute_native/dllmain.cpp create mode 100644 xcompute_native/file.cpp create mode 100644 xcompute_native/inc/XCompute.h create mode 100644 xcompute_native/inc/XComputeTypes.h create mode 100644 xcompute_native/inc/auto_any.h create mode 100644 xcompute_native/inc/scoped_any.h create mode 100644 xcompute_native/inc/smart_any_fwd.h create mode 100644 xcompute_native/locality.cpp create mode 100644 xcompute_native/node.cpp create mode 100644 xcompute_native/path.cpp create mode 100644 xcompute_native/process.cpp create mode 100644 xcompute_native/property.cpp create mode 100644 xcompute_native/scheduler.cpp create mode 100644 xcompute_native/session.cpp create mode 100644 
xcompute_native/status.cpp create mode 100644 xcompute_native/stdafx.cpp create mode 100644 xcompute_native/stdafx.h create mode 100644 xcompute_native/targetver.h create mode 100644 xcompute_native/xcimpl.h create mode 100644 xcompute_native/xcompute.cpp diff --git a/CommonCode/AzureUtils.cs b/CommonCode/AzureUtils.cs new file mode 100644 index 0000000..2a2b0f0 --- /dev/null +++ b/CommonCode/AzureUtils.cs @@ -0,0 +1,104 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +//------------------------------------------------------------------------------ +// +// Utils used by L2H for dealing with Azure +// +//------------------------------------------------------------------------------ + +namespace Microsoft.Research.Dryad +{ + using System; + + public class AzureUtils + { + /// + /// Flag to denote whether process is on Azure or not + /// + private static bool? 
isOnAzure = null; + + /// + /// Name of HPC node + /// + private static string hostName = null; + + /// + /// Determine whether current process on Azure or not + /// + public static bool IsOnAzure + { + get + { + if (isOnAzure == null) + { + isOnAzure = Environment.GetEnvironmentVariable("CCP_ONAZURE") != null; + } + + return (bool)isOnAzure; + } + } + + /// + /// Returns name of node or HPC alias if in Azure + /// + /// name to use for current node + public static string CurrentHostName + { + get + { + if(string.IsNullOrEmpty(hostName)) + { + if (Microsoft.Research.Dryad.AzureUtils.IsOnAzure) + { + hostName = Environment.GetEnvironmentVariable(@"HPC_NODE_NAME"); + if (string.IsNullOrEmpty(hostName)) + { + throw new Exception("Unable to get HPC_NODE_NAME environment variable"); + } + } + else + { + hostName = Environment.MachineName; + } + } + + return hostName; + } + } + + /// + /// This needs to be set in the Azure bootstrapper code. + /// + + internal static bool IsDatabaseShared + { + get + { + if (AzureUtils.IsOnAzure) + { + return (Environment.GetEnvironmentVariable(@"HPC_SHARED_DATABASE", EnvironmentVariableTarget.Machine) != null); + } + + return false; + } + } + } +} \ No newline at end of file diff --git a/CommonCode/Constants.cs b/CommonCode/Constants.cs new file mode 100644 index 0000000..5a3063c --- /dev/null +++ b/CommonCode/Constants.cs @@ -0,0 +1,169 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +//------------------------------------------------------------------------------ +// +// Constants used by managed code in Dryad +// +//------------------------------------------------------------------------------ + +namespace Microsoft.Research.Dryad +{ + using System; + + internal class Constants + { + // + // Constants for all WCF nettcp bindings + // + public const int MaxReceivedMessageSize = 16 * 1024 * 1024; + public const int MaxBufferPoolSize = 16 * 1024 * 1024; + public const int MaxConnections = 1024; + public const int ListenBacklog = 256; + public static readonly TimeSpan SendTimeout = new TimeSpan(0, 2, 0); + public static readonly TimeSpan ReceiveTimeout = new TimeSpan(0, 10, 0); + public static readonly TimeSpan VertexSendTimeout = new TimeSpan(0, 1, 0); + + // For Seal and Delete Node, use 6 minutes for the WCF timeout because of the 5 minute SQL timeout for the DB call + // Otherwise, default to 2 minutes + // TODO: Post-SP3, re-examine longest running operations, as they slow down service failure and failover time + public static readonly TimeSpan DscOperationTimeout = new TimeSpan(0, 2, 0); + public static readonly TimeSpan DscExtendedOperationTimeout = new TimeSpan(0, 6, 0); + + public static readonly String DryadConnectionString = String.Empty; + + public const string CommonRegistryPath = @"SOFTWARE\Microsoft\HPC"; + public const string HpcSchedulerNameString = "ClusterName"; + public const string HpcInstallPath = "BinDir"; + public const string DscServerName = "DscServiceNodeName"; + public const string DscConnectionFormat = @"net.tcp://{0}:{1}/HpcDsc/Service/DscService"; + public const string DscServiceDefaultScheme = "hpcdsc"; + public const UInt32 DscServiceDefaultPort = 6498; + + + public const Int32 HdfsServiceDefaultHttpPort = 50070; + + + public const string ServiceLocationString = @"ServiceLocation"; + + 
public const String jobManager = "XCJOBMANAGER"; + + public const string vertexAddrFormat = "net.tcp://{0}:8050/{1}/"; // net.tcp://:8050// + public const string vertexCallbackAddrFormat = "net.tcp://{0}:8051/{1}/"; // net.tcp://:8051// + public const string vertexCallbackServiceName = "DryadVertexCallback"; + public const string vertexServiceName = "DryadVertexService"; + public const string vertexFileServiceName = "DryadVertexFileService"; + public const int vertexFileChunkSize = 1024 * 16; + + public const string vertexCountEnvVar = "HPC_VERTEXCOUNT"; + public const string vertexEnvVarFormat = "HPC_VERTEX{0}"; + public const string vertexSvcInstanceEnvVar = "HPC_VERTEXSVCINST"; + public const string vertexSvcLocalAddrEnvVar = "CCP_DRYADVERTEXLOCALADDRESS"; + + public const string schedulerTypeEnvVar = "CCP_SCHEDULERTYPE"; + public const string schedulerTypeLocal = "LOCAL"; + public const string schedulerTypeCluster = "CLUSTER"; + public const string schedulerTypeAzure = "AZURE"; + public const string debugAzure = "DEBUG_AZURE"; + public const string schedulerTypeYarn = "YARN"; + + + // Recognized values are: OFF, CRITICAL, ERROR, WARN, INFO, VERBOSE + public const string traceLevelEnvVar = "CCP_DRYADTRACELEVEL"; + public const string traceOff = "OFF"; + public const string traceCritical = "CRITICAL"; + public const string traceError = "ERROR"; + public const string traceWarning = "WARN"; + public const string traceInfo = "INFO"; + public const string traceVerbose = "VERBOSE"; + + public const int traceOffNum = 0; + public const int traceCriticalNum = 1; + public const int traceErrorNum = 3; + public const int traceWarningNum = 7; + public const int traceInfoNum = 15; + public const int traceVerboseNum = 31; + + public const string VertexSecurityEnvVar = "HPC_VERTEX_SECURITY"; + + // SchedulerHelper environment variables + public const string clusterNameEnvVar = "CCP_CLUSTER_NAME"; + public const string jobIdEnvVar = "CCP_JOBID"; + public const string 
taskIdEnvVar = "CCP_TASKID"; + public const string nodesEnvVar = "CCP_NODES"; + public const string jobNameEnvVar = "CCP_JOBNAME"; + public const string requiredNodesEnvVar = "CCP_REQUIREDNODES"; + public const string localProcessComputeNodesEnvVar = "CCP_LOCALPROCESSCOMPUTENODES"; + + // DrError.h values used in managed code + // need to keep this section in sync with drerror.h changes... + public const uint DrError_VertexReceivedTermination = 0x830A0003; + public const uint DrError_VertexCompleted = 0x830A0016; + public const uint DrError_VertexError = 0x830A0017; + public const uint DrError_VertexInitialization = 0x830A0019; + public const uint DrError_ProcessingInterrupted = 0x830A001A; + public const uint DrError_VertexHostLostCommunication = 0x830A0FFF; + + // DSC Share Names + public const string DscTempShare = "HpcTemp"; + public const string DscDataShare = "HpcData"; + public const string RuntimeShareConfig = "HPC_RUNTIMESHARE"; + + // Cluster name + public const string ClusterNameConfig = "CCP_CLUSTER_NAME"; + + // NodeAdmin constants + // Retain time set to one day + // todo: this should be configurable + public static readonly TimeSpan RetainTime = new TimeSpan(1, 0, 0, 0); + public static readonly TimeSpan FileTimeStampMarginForGC = new TimeSpan(0, 0, 5, 0); + public const string runningJobEnvVar = "CCP_RUNNING_JOBS"; + public const string replicaPathFormat = @"\\{0}\HpcData\{1}.data"; + public const string nodeAdminMutexName = "A19A8AC1-4129-46e2-BB81-ED7EE3265B05"; + public const string nodeAdminUsage = "Syntax:\n\t" + + "HpcDscNodeAdmin [/r] [/g] [/wd] [/e] [/v] [/u]\n\n" + + "Parameters:\n\t" + + "/? 
\t- Display this help message.\n\t" + + "/g \t- Delete files not managed by DSC from the HpcData share.\n\t" + + "/wd\t- Delete old job working directories from the HpcTemp share.\n\t" + + "/r \t- Replicate DSC files onto this node.\n\t" + + "/e \t- Print full error traces.\n\t" + + "/u \t- Resets HpcReplication account password.\n\t" + + "/v \t- Print verbose activity traces.\n"; + + // HpcReplication user account + internal const string HpcReplicationUserName = "HpcReplication"; + + // Client retry period is 1 second for first retry, increasing up to 12 seconds for a total of 30 seconds + // These timeouts are intended to ride through transient network failures + internal const int StartRetryPeriod = 1000; + internal const int MaxRetryPeriod = 12000; + internal const int TotalRetryPeriod = 30000; + internal const int ClientRetryCount = 4; + + // Runtime retry period is 10 seconds for first retry, increasing up to 60 seconds for a total of 6 minutes + // Runtime timeouts intended to ride through a failover and more severe network disruptions with the goal + // of keeping running jobs alive + internal const int RuntimeStartRetryPeriod = 10000; + internal const int RuntimeMaxRetryPeriod = 60000; + internal const int RuntimeTotalRetryPeriod = 360000; + internal const int RuntimeClientRetryCount = 7; + } +} diff --git a/CommonCode/DiscLocalMonitor.cs b/CommonCode/DiscLocalMonitor.cs new file mode 100644 index 0000000..7b6bbf7 --- /dev/null +++ b/CommonCode/DiscLocalMonitor.cs @@ -0,0 +1,163 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License 
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+namespace Microsoft.Research.Dryad
+{
+    using System;
+    using System.Collections;
+    using System.Collections.Generic;
+    using System.Text;
+    using System.IO;
+    using System.Threading;
+    using System.Diagnostics;
+    using System.ServiceModel;
+    using System.ServiceModel.Channels;
+    using System.ServiceModel.Description;
+    using System.Net.Security;
+    using System.Runtime.Serialization;
+
+    /// <summary>
+    /// WCF contract used to push job progress and job state to a local monitor service.
+    /// </summary>
+    [ServiceContract(SessionMode = SessionMode.Allowed)]
+    public interface IDiscLocalJobMonitor
+    {
+        [OperationContract]
+        void UpdateJobProgress(string JobId, string JobMessage, double JobProgress);
+
+        [OperationContract]
+        void UpdateJobState(string JobId, string JobState);
+    }
+
+    [System.CodeDom.Compiler.GeneratedCodeAttribute("System.ServiceModel", "3.0.0.0")]
+    public interface IDiscLocalJobMonitorServiceChannel : IDiscLocalJobMonitor, System.ServiceModel.IClientChannel
+    {
+    }
+
+    /// <summary>
+    /// Client proxy for <see cref="IDiscLocalJobMonitor"/>; every call is forwarded to the WCF channel.
+    /// </summary>
+    // NOTE(review): the type argument of ClientBase was missing from the text this block was
+    // reconstructed from; IDiscLocalJobMonitor is the only argument that makes base.Channel
+    // expose the two operations called below - confirm against the original file.
+    [System.Diagnostics.DebuggerStepThroughAttribute()]
+    [System.CodeDom.Compiler.GeneratedCodeAttribute("System.ServiceModel", "3.0.0.0")]
+    public partial class DiscLocalJobMonitorClient : System.ServiceModel.ClientBase<IDiscLocalJobMonitor>, IDiscLocalJobMonitor
+    {
+
+        public DiscLocalJobMonitorClient()
+        {
+        }
+
+        public DiscLocalJobMonitorClient(System.ServiceModel.Channels.Binding binding, System.ServiceModel.EndpointAddress remoteAddress) :
+            base(binding, remoteAddress)
+        {
+        }
+
+        public void UpdateJobProgress(string JobId, string JobMessage, double JobProgress)
+        {
+            base.Channel.UpdateJobProgress(JobId, JobMessage, JobProgress);
+        }
+
+        public void UpdateJobState(string JobId, string JobState)
+        {
+            base.Channel.UpdateJobState(JobId, JobState);
+        }
+    }
+
+    /// <summary>
+    /// Best-effort reporter: lazily creates one <see cref="DiscLocalJobMonitorClient"/> and, after
+    /// the first failed call, stops reporting for the lifetime of this helper.
+    /// </summary>
+    public class DiscLocalMonitorHelper
+    {
+        /// <summary>Host name used to build <see cref="LocalMonitorEpr"/>.</summary>
+        public string DiscLocalMonitorMachine = @"localhost";
+
+        private DiscLocalJobMonitorClient m_client = null;
+
+        // Set on the first failed call; from then on Client returns null and updates are dropped.
+        private bool faultedClient = false;
+
+        public DiscLocalMonitorHelper()
+        {
+        }
+
+        /// <summary>
+        /// A new net.tcp binding per call: transport security with Windows client credentials,
+        /// reliable sessions disabled, and ProtectionLevel.None on the transport.
+        /// </summary>
+        public NetTcpBinding LocalMonitorBinding
+        {
+            get
+            {
+                NetTcpBinding binding = new NetTcpBinding(SecurityMode.Transport, false);
+                binding.Security.Transport.ClientCredentialType = TcpClientCredentialType.Windows;
+                binding.Security.Transport.ProtectionLevel = ProtectionLevel.None;
+                return binding;
+            }
+        }
+
+        /// <summary>Endpoint address of the monitor service on <see cref="DiscLocalMonitorMachine"/>, port 8042.</summary>
+        public string LocalMonitorEpr
+        {
+            get
+            {
+                return String.Format("net.tcp://{0}:8042/Service/DiscLocalJobMonitor", DiscLocalMonitorMachine);
+            }
+        }
+
+        /// <summary>
+        /// The cached client, created on first use; null once a call has failed.
+        /// </summary>
+        public DiscLocalJobMonitorClient Client
+        {
+            get
+            {
+                if (faultedClient)
+                {
+                    return null;
+                }
+                if (m_client == null)
+                {
+                    m_client = new DiscLocalJobMonitorClient(LocalMonitorBinding, new EndpointAddress(LocalMonitorEpr));
+                }
+                return m_client;
+            }
+        }
+
+        /// <summary>Reports progress for a job; failures are written to the console and swallowed.</summary>
+        public void UpdateProgress(string JobId, string JobMessage, double JobProgress)
+        {
+            Send(delegate(DiscLocalJobMonitorClient client) { client.UpdateJobProgress(JobId, JobMessage, JobProgress); });
+        }
+
+        /// <summary>Reports a job state change; failures are written to the console and swallowed.</summary>
+        public void UpdateJobState(string JobId, string JobState)
+        {
+            Send(delegate(DiscLocalJobMonitorClient client) { client.UpdateJobState(JobId, JobState); });
+        }
+
+        /// <summary>
+        /// Runs one call against the client. On any exception the helper is marked faulted, the
+        /// client is aborted so its channel is released, and the error is printed.
+        /// </summary>
+        private void Send(Action<DiscLocalJobMonitorClient> call)
+        {
+            DiscLocalJobMonitorClient client = null;
+            try
+            {
+                client = this.Client;
+                if (client != null)
+                {
+                    call(client);
+                }
+            }
+            catch (Exception e)
+            {
+                faultedClient = true;
+                m_client = null;
+                AbortQuietly(client);
+                Console.WriteLine("ERROR: DiscLocalMonitorHelper '{0}'", e);
+            }
+        }
+
+        /// <summary>Aborts a client without letting a secondary failure escape.</summary>
+        private static void AbortQuietly(DiscLocalJobMonitorClient client)
+        {
+            if (client == null)
+            {
+                return;
+            }
+            try
+            {
+                client.Abort();
+            }
+            catch
+            {
+                // Nothing more can be done for a client that is already being discarded.
+            }
+        }
+
+    }
+
+}
diff --git a/CommonCode/DryadTracing.cs b/CommonCode/DryadTracing.cs
new file mode 100644
index 0000000..39ffe97
--- /dev/null
+++ b/CommonCode/DryadTracing.cs
@@ -0,0 +1,693 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.  You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+using System.Diagnostics;
+using System.Diagnostics.Eventing;
+using System.IO;
+using Microsoft.Win32;
+using System.Runtime.InteropServices;
+using System.Security.Principal;
+using System.Security.Permissions;
+using System.Threading;
+
+namespace Microsoft.Research.Dryad
+{
+    /// <summary>
+    /// Pairs a trace level with the label written as the first column of a log record.
+    /// </summary>
+    internal struct DryadEventDescriptor
+    {
+        private int m_level;
+        private string m_message;
+
+        public int Level
+        {
+            get { return m_level; }
+        }
+
+        public string Message
+        {
+            get { return m_message; }
+        }
+
+        public DryadEventDescriptor(int level, string message)
+        {
+            m_level = level;
+            m_message = message;
+        }
+    }
+
+    /// <summary>
+    /// Appends quoted, comma-separated trace records to a text file. The writer is flushed by a
+    /// one-second timer, and the file is archived and reopened once the stream position passes
+    /// 10 MB. Every failure is swallowed: tracing never throws into the caller.
+    /// </summary>
+    internal sealed class DISCTextProvider : IDisposable
+    {
+        private static readonly long m_maxFileBytes = 10 * 1024 * 1024;
+        private static readonly string m_logFileFormat = "{0}{1:D6}{2}";
+        private static readonly int m_flushTimeout = 100;
+
+        // Read once: creating a Process object per record allocates a component that was never disposed.
+        private static readonly int s_processId = ReadProcessId();
+
+        private bool m_disposed = false;
+        private uint m_rolloverCount = 0;
+        private object m_lock = new object();
+        private string m_baseFilePath = String.Empty;
+        private StreamWriter m_provider = null;
+        private Timer m_flushTimer = null;
+
+        public DISCTextProvider(string path)
+        {
+            m_baseFilePath = path;
+            OpenLogFile();
+            m_flushTimer = new Timer(new TimerCallback(this.FlushLogTimer), null, 1000, 1000);
+        }
+
+        /// <summary>Returns the current process id, or 0 if it cannot be read.</summary>
+        private static int ReadProcessId()
+        {
+            try
+            {
+                using (Process current = Process.GetCurrentProcess())
+                {
+                    return current.Id;
+                }
+            }
+            catch
+            {
+                return 0;
+            }
+        }
+
+        private void FlushLogTimer(Object state)
+        {
+            try
+            {
+                Flush();
+            }
+            catch
+            {
+            }
+        }
+
+        // Must be called while holding m_lock
+        // Renames the current file to <name><rolloverCount:D6><ext>, trying up to three times.
+        private bool ArchiveLogFile()
+        {
+            string stem = Path.Combine(Path.GetDirectoryName(m_baseFilePath), Path.GetFileNameWithoutExtension(m_baseFilePath));
+            string archivePath = String.Format(m_logFileFormat, stem, m_rolloverCount, Path.GetExtension(m_baseFilePath));
+
+            for (int attempt = 0; attempt < 3; attempt++)
+            {
+                try
+                {
+                    File.Move(m_baseFilePath, archivePath);
+                    return true;
+                }
+                catch
+                {
+                }
+            }
+
+            return false;
+        }
+
+        // Must be called while holding m_lock
+        private void CloseLogFile()
+        {
+            if (m_provider == null)
+            {
+                return;
+            }
+
+            try
+            {
+                Flush();
+                m_provider.Dispose();
+                m_provider = null;
+            }
+            catch
+            {
+                return;
+            }
+        }
+
+        private void OpenLogFile()
+        {
+            try
+            {
+                // Append so that a restart continues the existing file
+                m_provider = new StreamWriter(new FileStream(m_baseFilePath, FileMode.Append, FileAccess.Write, FileShare.ReadWrite));
+            }
+            catch
+            {
+            }
+        }
+
+        // Called from the write path while m_lock is held.
+        private void ReopenLogFile()
+        {
+            if (m_provider != null)
+            {
+                CloseLogFile();
+                ArchiveLogFile();
+            }
+
+            OpenLogFile();
+        }
+
+        public void Dispose()
+        {
+            Dispose(true);
+            GC.SuppressFinalize(this);
+        }
+
+        private void Dispose(bool disposing)
+        {
+            if (m_disposed)
+            {
+                return;
+            }
+
+            if (disposing)
+            {
+                lock (m_lock)
+                {
+                    if (m_flushTimer != null)
+                    {
+                        m_flushTimer.Dispose();
+                        m_flushTimer = null;
+                    }
+
+                    if (m_provider != null)
+                    {
+                        CloseLogFile();
+                        m_provider = null;
+                    }
+                }
+            }
+            m_disposed = true;
+        }
+
+        /// <summary>Writes a record whose message is prefixed with "OperationContext: " when a context is given.</summary>
+        public void WriteEvent(ref DryadEventDescriptor dryadEventDescriptor, string ModuleName, string Source, DateTime TimeStamp, string OperationContext, string Message)
+        {
+            try
+            {
+                WriteEvent(ref dryadEventDescriptor, ModuleName, Source, TimeStamp, WithContext(OperationContext, Message));
+            }
+            catch
+            {
+            }
+        }
+
+        /// <summary>Same as the context overload, with the message built by String.Format.</summary>
+        public void WriteEvent(ref DryadEventDescriptor dryadEventDescriptor, string ModuleName, string Source, DateTime TimeStamp, string OperationContext, string MessageFormat, params object[] MessageParameters)
+        {
+            try
+            {
+                string message = String.Format(MessageFormat, MessageParameters);
+                WriteEvent(ref dryadEventDescriptor, ModuleName, Source, TimeStamp, WithContext(OperationContext, message));
+            }
+            catch
+            {
+            }
+        }
+
+        /// <summary>Writes a seven-column record: label, time, process id, thread id, module, source, message.</summary>
+        public void WriteEvent(ref DryadEventDescriptor dryadEventDescriptor, string ModuleName, string Source, DateTime TimeStamp, string Message)
+        {
+            try
+            {
+                WriteRecord(String.Format("\"{0}\",\"{1}\",\"{2}\",\"{3}\",\"{4}\",\"{5}\",\"{6}\"",
+                    dryadEventDescriptor.Message,
+                    FormatTimeStamp(TimeStamp),
+                    s_processId,
+                    System.Threading.Thread.CurrentThread.ManagedThreadId,
+                    ModuleName,
+                    Source,
+                    String.IsNullOrEmpty(Message) ? String.Empty : Message));
+            }
+            catch
+            {
+            }
+        }
+
+        /// <summary>Writes an eight-column record: the seven standard columns plus Exception.ToString().</summary>
+        public void WriteEvent(ref DryadEventDescriptor dryadEventDescriptor, string ModuleName, string Source, DateTime TimeStamp, Exception Exception, string Message)
+        {
+            try
+            {
+                WriteRecord(String.Format("\"{0}\",\"{1}\",\"{2}\",\"{3}\",\"{4}\",\"{5}\",\"{6}\",\"{7}\"",
+                    dryadEventDescriptor.Message,
+                    FormatTimeStamp(TimeStamp),
+                    s_processId,
+                    System.Threading.Thread.CurrentThread.ManagedThreadId,
+                    ModuleName,
+                    Source,
+                    String.IsNullOrEmpty(Message) ? String.Empty : Message,
+                    (Exception == null) ? String.Empty : Exception.ToString()));
+            }
+            catch
+            {
+            }
+        }
+
+        public void WriteEvent(ref DryadEventDescriptor dryadEventDescriptor, string ModuleName, string Source, DateTime TimeStamp, Exception Exception, string MessageFormat, params object[] MessageParameters)
+        {
+            try
+            {
+                WriteEvent(ref dryadEventDescriptor, ModuleName, Source, TimeStamp, Exception, String.Format(MessageFormat, MessageParameters));
+            }
+            catch
+            {
+            }
+        }
+
+        /// <summary>
+        /// Flushes the writer if the lock can be taken within 100 ms; otherwise does nothing.
+        /// </summary>
+        public void Flush()
+        {
+            try
+            {
+                if (Monitor.TryEnter(m_lock, m_flushTimeout))
+                {
+                    try
+                    {
+                        if (m_provider != null)
+                        {
+                            m_provider.Flush();
+                        }
+                    }
+                    catch
+                    {
+                    }
+                    finally
+                    {
+                        Monitor.Exit(m_lock);
+                    }
+                }
+            }
+            catch
+            {
+            }
+        }
+
+        private static string FormatTimeStamp(DateTime timeStamp)
+        {
+            return timeStamp.ToString("yyyy/MM/dd HH:mm:ss.fff", System.Globalization.CultureInfo.InvariantCulture);
+        }
+
+        private static string WithContext(string operationContext, string message)
+        {
+            return String.IsNullOrEmpty(operationContext) ? message : operationContext + ": " + message;
+        }
+
+        // Writes one line under m_lock and rolls the file over once it has grown past the limit.
+        private void WriteRecord(string record)
+        {
+            lock (m_lock)
+            {
+                if (m_provider == null)
+                {
+                    return;
+                }
+
+                m_provider.WriteLine(record);
+
+                if (m_provider.BaseStream.Position > m_maxFileBytes)
+                {
+                    m_rolloverCount++;
+                    ReopenLogFile();
+                }
+            }
+        }
+
+    }
+
+    /// <summary>
+    /// Process-wide trace facade. Start opens the log file, the Log* methods write one record each
+    /// when the requested level is enabled by TraceLevel, and Stop flushes and closes the file.
+    /// </summary>
+    public sealed class DryadLogger
+    {
+        private static DISCTextProvider s_discTracer;
+        // TODO: Change the default once there is better support for setting per job
+        private static int s_traceLevel = Constants.traceVerboseNum;
+        private static object s_syncRoot = new object();
+        // volatile: read outside the lock by the double-checked tests in Start and Stop
+        private static volatile bool s_initialized = false;
+        // Read once instead of creating an undisposed Process object for every record
+        private static readonly string s_processName = ReadProcessName();
+
+        private static DryadEventDescriptor DryadMethodEntry = new DryadEventDescriptor(Constants.traceVerboseNum, "MethodEntry");
+        private static DryadEventDescriptor DryadMethodExit = new DryadEventDescriptor(Constants.traceVerboseNum, "MethodExit");
+        private static DryadEventDescriptor DryadError = new DryadEventDescriptor(Constants.traceErrorNum, "Error");
+        private static DryadEventDescriptor DryadCritical = new DryadEventDescriptor(Constants.traceErrorNum, "Critical");
+        private static DryadEventDescriptor DryadWarning = new DryadEventDescriptor(Constants.traceWarningNum, "Warning");
+        private static DryadEventDescriptor DryadInformational = new DryadEventDescriptor(Constants.traceInfoNum, "Info");
+        private static DryadEventDescriptor DryadVerbose = new DryadEventDescriptor(Constants.traceVerboseNum, "Verbose");
+
+
+        private DryadLogger()
+        {
+        }
+
+        // Takes the initial trace level from the environment variable named by Constants.traceLevelEnvVar.
+        static DryadLogger()
+        {
+            try
+            {
+                string debugLevel = Environment.GetEnvironmentVariable(Constants.traceLevelEnvVar);
+
+                if (!String.IsNullOrEmpty(debugLevel))
+                {
+                    s_traceLevel = HpcQueryUtility.GetTraceLevelFromString(debugLevel);
+                }
+
+                Console.Out.WriteLine("Trace level set to {0}", HpcQueryUtility.ConvertTraceLevelToString(s_traceLevel));
+            }
+
+            catch (Exception e)
+            {
+                Console.Error.WriteLine("Failed to get tracing level: {0}", e);
+            }
+        }
+
+        public static int TraceLevel
+        {
+            get { return s_traceLevel; }
+            set { s_traceLevel = value; }
+        }
+
+        /// <summary>
+        /// Opens the trace file at <paramref name="path"/> if tracing is not already started.
+        /// </summary>
+        /// <returns>true when tracing is initialized on return</returns>
+        public static bool Start(string path)
+        {
+            try
+            {
+                if (!s_initialized)
+                {
+                    lock (s_syncRoot)
+                    {
+                        if (!s_initialized)
+                        {
+                            s_discTracer = new DISCTextProvider(path);
+                            s_initialized = true;
+                        }
+                    }
+                }
+            }
+            catch (Exception e)
+            {
+                Console.Error.WriteLine("Tracing initialization failed: {0}", e);
+            }
+
+            return s_initialized;
+        }
+
+        /// <summary>Flushes and closes the trace file; later Log calls are ignored until Start is called again.</summary>
+        public static void Stop()
+        {
+            if (s_initialized)
+            {
+                lock (s_syncRoot)
+                {
+                    if (s_initialized)
+                    {
+                        s_initialized = false;
+                        s_discTracer.Flush();
+                        s_discTracer.Dispose();
+                        s_discTracer = null;
+                    }
+                }
+            }
+        }
+
+        private static string ReadProcessName()
+        {
+            try
+            {
+                using (Process current = Process.GetCurrentProcess())
+                {
+                    return current.ProcessName;
+                }
+            }
+            catch
+            {
+                return String.Empty;
+            }
+        }
+
+        private static bool IsEnabled(int level)
+        {
+            return (s_initialized) && ((level & s_traceLevel) == level);
+        }
+
+        // Returns the provider to write to, or null when the level is disabled or tracing is
+        // stopped. Callers keep the returned reference so a concurrent Stop cannot null it
+        // between the check and the write; a disposed provider ignores writes.
+        private static DISCTextProvider GetTracer(int level)
+        {
+            return IsEnabled(level) ? s_discTracer : null;
+        }
+
+        // Formats the top frame as "module!type!method"; empty when the frame or its method is unavailable.
+        private static string GetModuleName(StackTrace inputStack)
+        {
+            if (inputStack == null || inputStack.FrameCount <= 0)
+            {
+                return String.Empty;
+            }
+
+            StackFrame frame = inputStack.GetFrame(0);
+            System.Reflection.MethodBase method = (frame == null) ? null : frame.GetMethod();
+            if (method == null)
+            {
+                return String.Empty;
+            }
+
+            Type owner = method.ReflectedType;
+            return String.Format("{0}!{1}!{2}", method.Module, (owner == null) ? String.Empty : owner.Name, method.Name);
+        }
+
+        // Joins the parameters with ", "; null entries contribute an empty string.
+        private static string JoinParameters(object[] methodParameters)
+        {
+            StringBuilder joined = new StringBuilder();
+
+            if (methodParameters != null)
+            {
+                for (int i = 0; i < methodParameters.Length; i++)
+                {
+                    if (i > 0)
+                    {
+                        joined.Append(", ");
+                    }
+                    joined.Append(methodParameters[i]);
+                }
+            }
+
+            return joined.ToString();
+        }
+
+        public static void LogMethodEntry(params object[] methodParameters)
+        {
+            DISCTextProvider tracer = GetTracer(Constants.traceVerboseNum);
+            if (tracer == null) return;
+
+            tracer.WriteEvent(ref DryadMethodEntry, s_processName, GetModuleName(new StackTrace(1)), DateTime.Now, JoinParameters(methodParameters));
+        }
+
+        public static void LogMethodExit(params object[] methodParameters)
+        {
+            DISCTextProvider tracer = GetTracer(Constants.traceVerboseNum);
+            if (tracer == null) return;
+
+            tracer.WriteEvent(ref DryadMethodExit, s_processName, GetModuleName(new StackTrace(1)), DateTime.Now, JoinParameters(methodParameters));
+        }
+
+        // errorCode is accepted for interface compatibility; it is not written to the record.
+        public static void LogError(Int32 errorCode, Exception discException)
+        {
+            DISCTextProvider tracer = GetTracer(Constants.traceErrorNum);
+            if (tracer == null) return;
+
+            // Source column: where the exception was thrown, else the caller of this method
+            StackTrace currentStack = (discException != null) ? new StackTrace(discException) : new StackTrace(1);
+
+            tracer.WriteEvent(ref DryadError, s_processName, GetModuleName(currentStack), DateTime.Now, discException, String.Empty);
+        }
+
+        public static void LogError(Int32 errorCode, Exception discException, string messageFormat, params object[] parameterValues)
+        {
+            DISCTextProvider tracer = GetTracer(Constants.traceErrorNum);
+            if (tracer == null) return;
+
+            StackTrace currentStack = (discException != null) ? new StackTrace(discException) : new StackTrace(1);
+
+            tracer.WriteEvent(ref DryadError, s_processName, GetModuleName(currentStack), DateTime.Now, discException, messageFormat, parameterValues);
+        }
+
+        private static void LogCriticalToConsole(Int32 errorCode, Exception discException, string messageFormat, params object[] parameterValues)
+        {
+            if (discException == null && String.IsNullOrEmpty(messageFormat))
+            {
+                // Sadly, nothing to log
+                return;
+            }
+
+            StringBuilder message = new StringBuilder();
+
+            if (discException != null)
+            {
+                message.Append("Critical Exception occurred: ");
+                int hr = System.Runtime.InteropServices.Marshal.GetHRForException(discException);
+                if (hr != 0)
+                {
+                    message.Append("0x");
+                    message.Append(hr.ToString("X8"));
+                }
+                Console.Error.WriteLine(message.ToString());
+                Console.Error.WriteLine(discException.ToString());
+            }
+            else if (errorCode != 0)
+            {
+                message.Append("Critical error occured: code = ");
+                message.Append(errorCode);
+                Console.Error.WriteLine(message.ToString());
+            }
+
+            if (!String.IsNullOrEmpty(messageFormat))
+            {
+                try
+                {
+                    Console.Error.WriteLine(messageFormat, parameterValues);
+                }
+                catch
+                {
+                }
+            }
+            Console.Error.Flush();
+        }
+
+        public static void LogCritical(Int32 errorCode, Exception discException)
+        {
+            //
+            // For LogCritical only, write message to Console.Error so that it shows up in task's output
+            //
+            LogCriticalToConsole(errorCode, discException, String.Empty);
+
+            DISCTextProvider tracer = GetTracer(Constants.traceCriticalNum);
+            if (tracer == null) return;
+
+            StackTrace currentStack = (discException != null) ? new StackTrace(discException) : new StackTrace(1);
+
+            tracer.WriteEvent(ref DryadCritical, s_processName, GetModuleName(currentStack), DateTime.Now, discException, String.Empty);
+            tracer.Flush();
+        }
+
+        public static void LogCritical(Int32 errorCode, Exception discException, string messageFormat, params object[] parameterValues)
+        {
+            //
+            // For LogCritical only, write message to Console.Error so that it shows up in task's output
+            //
+            LogCriticalToConsole(errorCode, discException, messageFormat, parameterValues);
+
+            DISCTextProvider tracer = GetTracer(Constants.traceCriticalNum);
+            if (tracer == null) return;
+
+            StackTrace currentStack = (discException != null) ? new StackTrace(discException) : new StackTrace(1);
+
+            tracer.WriteEvent(ref DryadCritical, s_processName, GetModuleName(currentStack), DateTime.Now, discException, messageFormat, parameterValues);
+            tracer.Flush();
+        }
+
+        public static void LogWarning(string operationContext, string warningMessage)
+        {
+            DISCTextProvider tracer = GetTracer(Constants.traceWarningNum);
+            if (tracer == null) return;
+
+            tracer.WriteEvent(ref DryadWarning, s_processName, GetModuleName(new StackTrace(1)), DateTime.Now, operationContext, warningMessage);
+        }
+
+        public static void LogWarning(string operationContext, string warningMessageFormat, params object[] parameterValues)
+        {
+            DISCTextProvider tracer = GetTracer(Constants.traceWarningNum);
+            if (tracer == null) return;
+
+            tracer.WriteEvent(ref DryadWarning, s_processName, GetModuleName(new StackTrace(1)), DateTime.Now, operationContext, warningMessageFormat, parameterValues);
+        }
+
+        public static void LogInformation(string operationContext, string operationalMessage)
+        {
+            DISCTextProvider tracer = GetTracer(Constants.traceInfoNum);
+            if (tracer == null) return;
+
+            tracer.WriteEvent(ref DryadInformational, s_processName, GetModuleName(new StackTrace(1)), DateTime.Now, operationContext, operationalMessage);
+        }
+
+        public static void LogInformation(string operationContext, string operationalMessageFormat, params object[] parameterValues)
+        {
+            DISCTextProvider tracer = GetTracer(Constants.traceInfoNum);
+            if (tracer == null) return;
+
+            tracer.WriteEvent(ref DryadInformational, s_processName, GetModuleName(new StackTrace(1)), DateTime.Now, operationContext, operationalMessageFormat, parameterValues);
+        }
+
+        public static void LogDebug(string operationContext, string debugMessage)
+        {
+            DISCTextProvider tracer = GetTracer(Constants.traceVerboseNum);
+            if (tracer == null) return;
+
+            tracer.WriteEvent(ref DryadVerbose, s_processName, GetModuleName(new StackTrace(1)), DateTime.Now, operationContext, debugMessage);
+        }
+
+        public static void LogDebug(string operationContext, string debugMessageFormat, params object[] parameterValues)
+        {
+            DISCTextProvider tracer = GetTracer(Constants.traceVerboseNum);
+            if (tracer == null) return;
+
+            tracer.WriteEvent(ref DryadVerbose, s_processName, GetModuleName(new StackTrace(1)), DateTime.Now, operationContext, debugMessageFormat, parameterValues);
+        }
+    }
+}
\ No newline at end of file
diff --git a/CommonCode/DryadVertexServiceAuthorizationManager.cs b/CommonCode/DryadVertexServiceAuthorizationManager.cs
new file mode 100644
index 0000000..7fd4cbf
--- /dev/null
+++ b/CommonCode/DryadVertexServiceAuthorizationManager.cs
@@ -0,0 +1,118 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.  You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//------------------------------------------------------------------------------
+//
+// Auth manager used for checking if the caller identity matches with the
+// current identity.
+//
+//------------------------------------------------------------------------------
+
+namespace Microsoft.Research.Dryad
+{
+    using System;
+    using System.Collections.Generic;
+    using System.Security.Principal;
+    using System.ServiceModel;
+    using System.ServiceModel.Channels;
+    using System.Net;
+    using System.Text;
+
+    /// <summary>
+    /// Auth manager used for checking if the caller identity matches with the current identity.
+    /// </summary>
+    public class DryadVertexServiceAuthorizationManager : ServiceAuthorizationManager
+    {
+        /// <summary>
+        /// Reference identity: the Windows identity this manager was constructed under.
+        /// </summary>
+        readonly WindowsIdentity currentIdentity;
+
+        /// <summary>
+        /// Creates an instance of the DryadVertexServiceAuthorizationManager class.
+        /// </summary>
+        public DryadVertexServiceAuthorizationManager()
+        {
+            this.currentIdentity = WindowsIdentity.GetCurrent();
+        }
+
+        /// <summary>
+        /// Check whether current operation context should be allowed access.
+        /// Access is granted only to an authenticated, non-anonymous Windows caller whose
+        /// security identifier equals that of the reference identity; every refusal is logged.
+        /// </summary>
+        /// <param name="operationContext">Current operation context</param>
+        /// <returns>true = allowed</returns>
+        protected override bool CheckAccessCore(OperationContext operationContext)
+        {
+            //TODO: Put logging information to appropriate channels when available.
+
+            //
+            // Fail closed if there is no security context at all (for example an unsecured
+            // message) instead of failing with a null reference
+            //
+            ServiceSecurityContext securityContext =
+                (operationContext == null) ? null : operationContext.ServiceSecurityContext;
+            if (securityContext == null)
+            {
+                DryadLogger.LogError(0, null, "Vertex authentication failed : Service security context is missing.");
+                return false;
+            }
+
+            //
+            // Fail if context is anonymous
+            //
+            if (securityContext.IsAnonymous)
+            {
+                DryadLogger.LogError(0, null, "Vertex authentication failed : Service security context is anonymous.");
+                return false;
+            }
+
+            //
+            // Get identity used in current context
+            //
+            WindowsIdentity callerIdentity = securityContext.WindowsIdentity;
+            if (callerIdentity == null)
+            {
+                //
+                // Fail if identity is not set
+                //
+                DryadLogger.LogError(0, null, "Vertex authentication failed : Caller identity is null.");
+                return false;
+            }
+
+            if (callerIdentity.IsAnonymous)
+            {
+                //
+                // Fail if identity is anonymous
+                //
+                DryadLogger.LogError(0, null, "Vertex authentication failed : Caller identity is anonymous.");
+                return false;
+            }
+
+            if (!callerIdentity.IsAuthenticated)
+            {
+                //
+                // Fail if identity is not authenticated
+                //
+                DryadLogger.LogError(0, null, "Vertex authentication failed : Caller identity is not authenticated.");
+                return false;
+            }
+
+            //
+            // If operation context has same user as vertex service, then allow, otherwise fail.
+            //
+            if (this.currentIdentity.User == callerIdentity.User)
+            {
+                return true;
+            }
+
+            DryadLogger.LogError(0, null, "Vertex authentication failed : Current identity is {0}, caller identity is {1}", this.currentIdentity.Name, callerIdentity.Name);
+            return false;
+        }
+    }
+}
diff --git a/CommonCode/ExecutionHelper.cs b/CommonCode/ExecutionHelper.cs
new file mode 100644
index 0000000..5c6c59b
--- /dev/null
+++ b/CommonCode/ExecutionHelper.cs
@@ -0,0 +1,297 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. 
You may obtain a copy of the License 
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+namespace Microsoft.Research.Dryad
+{
+    using System;
+    using System.Collections.Generic;
+    using System.IO;
+    using System.Linq;
+    using System.Net;
+    using System.Security.Principal;
+    using System.Security.AccessControl;
+    using Microsoft.Research.Dryad;
+
+    /// <summary>
+    /// Prepares the job and process working directories: creates them and copies the installed
+    /// binaries and the staged job resources into the job directory.
+    /// </summary>
+    internal static class ExecutionHelper
+    {
+        /// <summary>
+        /// List of known files in bin directory.
+        /// Built once by the type initializer, so no thread can observe a partly filled list.
+        /// Each name appears exactly once; IsLocalResource relies on that no longer, but a
+        /// duplicate entry previously made it answer false for that file.
+        /// </summary>
+        private static readonly List<string> binFileList = new List<string>
+        {
+            "VertexHost.exe",
+            "Microsoft.Research.Dryad.dll",
+            "DryadLINQNativeChannels.dll",
+            "YarnQueryNativeClusterAdapter.dll",
+            "Microsoft.Research.DryadLINQ.dll",
+            "Microsoft.Research.Dryad.Hdfs.dll",
+        };
+
+        /// <summary>
+        /// %DRYAD_HOME% bin directory where installed binaries can be found
+        /// </summary>
+        private static string dryadHome = Environment.GetEnvironmentVariable("DRYAD_HOME");
+
+        /// <summary>
+        /// Check for azure execution
+        /// </summary>
+        private static bool AzureExecution
+        {
+            get
+            {
+                string debugAzure = Environment.GetEnvironmentVariable(Constants.debugAzure);
+                if (!String.IsNullOrEmpty(debugAzure))
+                {
+                    return true;
+                }
+
+                string schedulerType = Environment.GetEnvironmentVariable(Constants.schedulerTypeEnvVar);
+                return (!String.IsNullOrEmpty(schedulerType) && schedulerType == Constants.schedulerTypeAzure);
+            }
+        }
+
+        /// <summary>
+        /// Copy local resources from %ccp_home%bin to working directory
+        /// todo: Post SP2, we should run directly from %ccp_home%bin rather than performing a local copy.
+        /// </summary>
+        /// <returns>success = true</returns>
+        private static bool CopyLocalBinaries()
+        {
+            //Console.Error.WriteLine("Copying source files from {0}", dryadHome); //DEBUG
+            foreach (string localFile in binFileList)
+            {
+                // Get path to job working directory in \\hpctemp
+                string jobFilePath = Path.Combine(ProcessPathHelper.JobPath, localFile);
+
+                // Only copy files that do not already exist
+                // Avoids overwriting files when a vertex service task fails or finishes on one node
+                // and a new vertex service task is scheduled on the same node within the same job
+                if (File.Exists(jobFilePath))
+                {
+                    continue;
+                }
+
+                // Get path to file source in CCP_HOME\bin
+                string sourceFilePath = Path.Combine(dryadHome, localFile);
+
+                try
+                {
+                    File.Copy(sourceFilePath, jobFilePath, true);
+                }
+                catch (Exception e)
+                {
+                    Console.Error.WriteLine("[ExecutionHelper.CopyLocalResources] Exception copying '{0}' to '{1}': {2}", sourceFilePath, jobFilePath, e.Message);
+                    return false;
+                }
+            }
+
+            return true;
+        }
+
+        /// <summary>
+        /// Downloads one file from an HDFS directory through the WebHDFS "OPEN" operation.
+        /// </summary>
+        /// <param name="hdfsDir">absolute URI of the HDFS directory</param>
+        /// <param name="fileName">file name relative to the directory</param>
+        /// <param name="destFileName">local path to write</param>
+        private static void GetHdfsFile(string hdfsDir, string fileName, string destFileName)
+        {
+            // A trailing slash makes the relative Uri below resolve inside the directory
+            if (!hdfsDir.EndsWith("/", StringComparison.Ordinal))
+            {
+                hdfsDir = hdfsDir + "/";
+            }
+            var hdfsDirUri = new Uri(hdfsDir, UriKind.Absolute);
+            var hdfsFileUri = new Uri(hdfsDirUri, fileName);
+            var builder = new UriBuilder();
+            builder.Host = hdfsFileUri.DnsSafeHost;
+            builder.Port = Constants.HdfsServiceDefaultHttpPort;
+            builder.Path = "webhdfs/v1/" + hdfsFileUri.AbsolutePath.TrimStart('/');
+            builder.Query = "op=OPEN";
+            Console.WriteLine(builder.Uri);
+            using (var wc = new WebClient())
+            {
+                wc.DownloadFile(builder.Uri, destFileName);
+            }
+        }
+
+        /// <summary>
+        /// Copy the resources from staging dir to working dir
+        /// </summary>
+        /// <param name="resources">
+        /// list of resources supplied by dryadlinq: a comma-separated list whose first entry is the
+        /// staging location and whose remaining entries are file names, or "@path" naming a file
+        /// that contains such a list
+        /// </param>
+        /// <returns>success = true</returns>
+        private static bool CopyStagedJobResources(string resources)
+        {
+            if (resources == null)
+            {
+                Console.Error.WriteLine("[ExecutionHelper.CopyJobResources] resources = null");
+                return false;
+            }
+
+            // Length check first: indexing an empty string would throw
+            if (resources.Length > 0 && resources[0] == '@')
+            {
+                resources = File.ReadAllText(resources.Substring(1));
+            }
+
+            if (resources.EndsWith(","))
+            {
+                resources = resources.Substring(0, resources.Length - 1);
+            }
+            string[] files = resources.Split(',');
+            DryadLogger.LogInformation("CopyStagedJobResources", string.Format("Will copy {0} resource files.", files.Length));
+
+            if (files.Length <= 1)
+            {
+                Console.Error.WriteLine("[ExecutionHelper.CopyJobResources] invalid XC_RESOURCEFILES length = {0}", files.Length);
+                return false;
+            }
+
+            string source = files[0];
+            for (int i = 1; i < files.Length; i++)
+            {
+                string jobFilePath = Path.Combine(ProcessPathHelper.JobPath, files[i]);
+
+                //
+                // File may already exist due to local resource copying
+                //
+                if (File.Exists(jobFilePath))
+                {
+                    continue;
+                }
+
+                //
+                // If file doesn't exist today, get it from staging location
+                //
+                if (source.StartsWith("hdfs://", StringComparison.InvariantCultureIgnoreCase))
+                {
+                    // copy from HDFS
+                    DryadLogger.LogDebug("CopyStagedJobResources", string.Format(
+                        "[ExecutionHelper.CopyJobResources] Copying '{0}' to '{1}' from HDFS dir {2}",
+                        files[i], jobFilePath, source));
+                    GetHdfsFile(source, files[i], jobFilePath);
+                    continue;
+                }
+
+                string sourceFile = Path.Combine(source, files[i]);
+                try
+                {
+                    DryadLogger.LogDebug("CopyStagedJobResources", string.Format(
+                        "[ExecutionHelper.CopyJobResources] Copying '{0}' to '{1}'",
+                        sourceFile, jobFilePath));
+                    File.Copy(sourceFile, jobFilePath);
+                }
+                catch (Exception e)
+                {
+                    DryadLogger.LogInformation("CopyStagedJobResources", string.Format(
+                        "[ExecutionHelper.CopyJobResources] Exception copying '{0}' to '{1}': {2}",
+                        sourceFile, jobFilePath, e.Message));
+                    return false;
+                }
+            }
+
+            return true;
+        }
+
+        /// <summary>
+        /// Create working directory for vertex
+        /// </summary>
+        /// <param name="id">process id used to derive the directory</param>
+        /// <param name="resources">not used by this method</param>
+        /// <returns>true when the directory was created, false on any exception</returns>
+        public static bool InitializeForProcessExecution(int id, string resources)
+        {
+            try
+            {
+                Directory.CreateDirectory(ProcessPathHelper.ProcessPath(id));
+                Console.Error.WriteLine("Created directory: " + ProcessPathHelper.ProcessPath(id));
+                return true;
+            }
+            catch (Exception e)
+            {
+                Console.Error.WriteLine("[ExecutionHelper.InitializeForProcessExecution] Exception: {0}", e.Message);
+                Console.Error.WriteLine(e.StackTrace);
+                return false;
+            }
+        }
+
+        /// <summary>
+        /// Initialize the job directory for vertex execution
+        /// </summary>
+        /// <param name="resources">list of DryadLINQ-requested resources</param>
+        /// <returns>success/failure</returns>
+        public static bool InitializeForJobExecution(string resources)
+        {
+            try
+            {
+                ProcessPathHelper.CreateUserWorkingDirectory();
+
+                Directory.CreateDirectory(ProcessPathHelper.JobPath);
+
+                //
+                // copy any files that already live locally and may be needed for the job
+                //
+                bool success = CopyLocalBinaries();
+
+                //
+                // copy any user-specified files that haven't already been copied
+                // (attempted even when the step above failed)
+                //
+                success &= CopyStagedJobResources(resources);
+                return success;
+            }
+            catch (Exception e)
+            {
+                //
+                // Write out any errors and return false on exception
+                //
+                Console.Error.WriteLine("[ExecutionHelper.InitializeForJobExecution] Exception: {0}", e.Message);
+                Console.Error.WriteLine(e.StackTrace);
+                return false;
+            }
+        }
+
+        /// <summary>
+        /// Checks if resource is one of the binaries already on the compute node
+        /// </summary>
+        /// <param name="resourceName">name of resource to check (compared ignoring case)</param>
+        /// <returns>true when the name is in the known bin file list</returns>
+        public static bool IsLocalResource(string resourceName)
+        {
+            return binFileList.Any(known => string.Equals(resourceName, known, StringComparison.OrdinalIgnoreCase));
+        }
+    }
+}
\ No newline at end of file
diff --git a/CommonCode/IDryadVertexCallback.cs b/CommonCode/IDryadVertexCallback.cs
new file mode 100644
index 0000000..8cfa917
--- /dev/null
+++ b/CommonCode/IDryadVertexCallback.cs
@@ -0,0 +1,80 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.  You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+namespace Microsoft.Research.Dryad
+{
+    using System;
+    using System.Collections.Generic;
+    using System.ServiceModel;
+    using System.Runtime.Serialization;
+
+    // Serializable data contract holding numeric counters for one process; it is carried inside
+    // ProcessInfo. The fields are written by code outside this file.
+    // NOTE(review): units (ticks, bytes, ...) are not stated anywhere in this file - confirm
+    // against the code that fills these fields before relying on them.
+    [DataContract]
+    public class ProcessStatistics
+    {
+        [DataMember]
+        public uint flags;
+        [DataMember]
+        public long processUserTime;
+        [DataMember]
+        public long processKernelTime;
+        [DataMember]
+        public int pageFaults;
+        [DataMember]
+        public int totalProcessesCreated;
+        [DataMember]
+        public ulong peakVMUsage;
+        [DataMember]
+        public ulong peakMemUsage;
+        [DataMember]
+        public ulong memUsageSeconds;
+        [DataMember]
+        public ulong totalIo;
+    };
+
+    // Serializable snapshot of one process: its state, status and exit codes, an array of
+    // property values and a ProcessStatistics block. Passed to
+    // IDryadVertexCallback.SetGetPropsComplete.
+    [DataContract]
+    public class ProcessInfo
+    {
+        [DataMember]
+        public uint flags;
+        [DataMember]
+        public ProcessState processState;
+        [DataMember]
+        public uint processStatus;
+        [DataMember]
+        public uint exitCode;
+        [DataMember]
+        public ProcessPropertyInfo[] propertyInfos;
+        [DataMember]
+        public ProcessStatistics processStatistics;
+    };
+
+    // WCF service contract with three notifications, each keyed by an integer process id.
+    // NOTE(review): presumably implemented by the component that schedules vertex processes and
+    // called back by the vertex service - the direction is not visible in this file.
+    [ServiceContract(SessionMode = SessionMode.Allowed)]
+    public interface IDryadVertexCallback
+    {
+        // Reports that the process identified by processId is now in newState.
+        [OperationContract]
+        void FireStateChange(int processId, ProcessState newState);
+
+        // Delivers process information together with parallel arrays of property labels and versions.
+        [OperationContract]
+        void SetGetPropsComplete(int processId, ProcessInfo info, string[] propertyLabels, ulong[] propertyVersions);
+
+        // Reports that the process identified by processId has exited with exitCode.
+        [OperationContract]
+        void ProcessExited(int processId, int exitCode);
+    }
+}
diff --git
a/CommonCode/IDryadVertexService.cs b/CommonCode/IDryadVertexService.cs
new file mode 100644
index 0000000..e49b557
--- /dev/null
+++ b/CommonCode/IDryadVertexService.cs
@@ -0,0 +1,204 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//------------------------------------------------------------------------------
+//
+// Vertex Service contracts
+//
+//------------------------------------------------------------------------------
+
+namespace Microsoft.Research.Dryad
+{
+    using System;
+    using System.Collections.Generic;
+    using System.Collections.Specialized;
+    using System.Runtime.Serialization;
+    using System.ServiceModel;
+    using System.Diagnostics;
+
+    /// <summary>
+    /// Holds property information: a label, a version number, a string value and a binary block.
+    /// </summary>
+    [DataContract]
+    public class ProcessPropertyInfo
+    {
+        [DataMember]
+        public string propertyLabel;
+        [DataMember]
+        public ulong propertyVersion;
+        [DataMember]
+        public string propertyString;
+        [DataMember]
+        public byte[] propertyBlock;
+    }
+
+    /// <summary>
+    /// Keeps track of vertex host process information: its id, command line and state.
+    /// </summary>
+    [DataContract]
+    public class VertexProcessInfo
+    {
+        [DataMember]
+        public int DryadId;
+        [DataMember]
+        public string commandLine;
+        [DataMember]
+        public ProcessState State;
+    }
+
+    /// <summary>
+    /// Status snapshot returned by <see cref="IDryadVertexService.CheckStatus"/>: a liveness flag,
+    /// free disk space, free virtual and physical memory, the running process count and the
+    /// <c>vps</c> list. Every field has an initializer, so a new instance reports "not alive"
+    /// with zero counters and empty collections.
+    /// </summary>
+    [DataContract]
+    public class VertexStatus
+    {
+        [DataMember]
+        public bool serviceIsAlive = false;
+        // NOTE(review): the generic type arguments of Dictionary seem to be missing from this text;
+        // System.Collections.Generic has no non-generic Dictionary - restore them from the original source.
+        [DataMember]
+        public Dictionary freeDiskSpaces = new Dictionary();
+        [DataMember]
+        public ulong freeVirtualMemory = 0;
+        [DataMember]
+        public ulong freePhysicalMemory = 0;
+        [DataMember]
+        public uint runningProcessCount = 0;
+        // NOTE(review): same as above for List - presumably a list of VertexProcessInfo; confirm.
+        [DataMember]
+        public List vps = new List();
+    }
+
+    /// <summary>
+    /// Fault detail naming the operation that failed and the reason. It is declared as a
+    /// FaultContract on ScheduleProcess and SetGetProps in <see cref="IDryadVertexService"/>.
+    /// </summary>
+    [DataContract(Namespace = "http://hpc.microsoft.com/dryadvertex/")]
+    [Serializable]
+    public class VertexServiceError
+    {
+        /// <summary>
+        /// Action URI used in the FaultContract attributes for this fault type
+        /// </summary>
+        public const string Action = "http://hpc.microsoft.com/dryadvertex/VertexServiceError";
+
+        /// <summary>
+        /// Stores the Operation
+        /// </summary>
+        [DataMember]
+        private string operation;
+
+        /// <summary>
+        /// Stores the Reason
+        /// </summary>
+        [DataMember]
+        private string reason;
+
+        /// <summary>
+        /// Initializes a new instance of the VertexServiceError class
+        /// </summary>
+        /// <param name="operation">The operation that failed</param>
+        /// <param name="reason">The detailed reason for the failure (exception.ToString())</param>
+        public VertexServiceError(string operation, string reason)
+        {
+            this.operation = operation;
+            this.reason = reason;
+        }
+
+        /// <summary>
+        /// The detailed reason for the failure
+        /// </summary>
+        public string Reason
+        {
+            get
+            {
+                return this.reason;
+            }
+        }
+
+        /// <summary>
+        /// The operation that failed
+        /// </summary>
+        public string Operation
+        {
+            get
+            {
+                return this.operation;
+            }
+        }
+    }
+
+    /// <summary>
+    /// Fault detail carrying the id of a process that was not found. It is declared as a
+    /// FaultContract on SetGetProps in <see cref="IDryadVertexService"/>.
+    /// </summary>
+    [DataContract(Namespace = "http://hpc.microsoft.com/dryadvertex/")]
+    [Serializable]
+    public class UnknownProcessError
+    {
+        /// <summary>
+        /// Action URI used in the FaultContract attribute for this fault type
+        /// </summary>
+        public const string Action = "http://hpc.microsoft.com/dryadvertex/UnknownProcessError";
+
+        /// <summary>
+        /// Stores the ProcessId
+        /// </summary>
+        [DataMember]
+        private int processId;
+
+        /// <summary>
+        /// Initializes a new instance of the UnknownProcessError class
+        /// </summary>
+        /// <param name="id">Id of the unknown process</param>
+        public UnknownProcessError(int id)
+        {
+            this.processId = id;
+        }
+
+        /// <summary>
+        /// The process id which was not found
+        /// </summary>
+        public int Processid
+        {
+            get
+            {
+                return this.processId;
+            }
+        }
+
+    }
+
+    /// <summary>
+    /// Dryad Vertex Service Contract - allows GM to schedule vertices and VH to report status.
+    /// Operations marked IsOneWay return nothing to the caller; the others return a value and,
+    /// where a FaultContract is declared, may fault with the listed detail types.
+    /// </summary>
+    [ServiceContract(Name = "IDryadVertexService", Namespace = "http://hpc.microsoft.com/dryadvertex/", SessionMode = SessionMode.Allowed)]
+    public partial interface IDryadVertexService
+    {
+        /// <summary>
+        /// One-way operation taking the id of a process.
+        /// NOTE(review): by its name it cancels a previously scheduled process - confirm in the implementation.
+        /// </summary>
+        /// <param name="processId">Id of the process the call refers to</param>
+        [OperationContract(IsOneWay=true, Action = "http://hpc.microsoft.com/dryadvertex/cancelscheduleprocess")]
+        void CancelScheduleProcess(int processId);
+
+        /// <summary>
+        /// Returns a <see cref="VertexStatus"/> snapshot.
+        /// </summary>
+        // TODO: Deprecated.
+        [OperationContract(Action = "http://hpc.microsoft.com/dryadvertex/checkstatus")]
+        VertexStatus CheckStatus();
+
+        /// <summary>
+        /// One-way operation taking a string dictionary of vertex endpoint addresses.
+        /// NOTE(review): what the keys and values hold is not visible here - confirm with the caller.
+        /// </summary>
+        /// <param name="vertexEndpointAddresses">Endpoint addresses passed to the service</param>
+        [OperationContract(IsOneWay = true, Action = "http://hpc.microsoft.com/dryadvertex/initialize")]
+        void Initialize(StringDictionary vertexEndpointAddresses);
+
+        /// <summary>
+        /// One-way operation taking the id of a process.
+        /// </summary>
+        /// <param name="processId">Id of the process the call refers to</param>
+        [OperationContract(IsOneWay=true, Action = "http://hpc.microsoft.com/dryadvertex/releaseprocess")]
+        void ReleaseProcess(int processId);
+
+        /// <summary>
+        /// Request/reply operation that may fault with <see cref="VertexServiceError"/>.
+        /// </summary>
+        /// <param name="replyUri">URI supplied by the caller (presumably where callbacks are sent - confirm)</param>
+        /// <param name="processId">Id assigned to the process</param>
+        /// <param name="commandLine">Command line for the process</param>
+        /// <param name="environment">Environment entries for the process</param>
+        /// <returns>A boolean result; its exact meaning is defined by the implementation</returns>
+        [OperationContract(Action = "http://hpc.microsoft.com/dryadvertex/scheduleprocess")]
+        [FaultContract(typeof(VertexServiceError), Action = VertexServiceError.Action)]
+        bool ScheduleProcess(string replyUri, int processId, string commandLine, StringDictionary environment);
+
+        /// <summary>
+        /// Request/reply operation that may fault with <see cref="VertexServiceError"/> or
+        /// <see cref="UnknownProcessError"/>.
+        /// NOTE(review): the parameter names suggest "set these properties, optionally block until a
+        /// label reaches a version, then fetch a property" - confirm semantics and the unit of
+        /// maxBlockTime in the implementation.
+        /// </summary>
+        /// <param name="replyUri">URI supplied by the caller</param>
+        /// <param name="processId">Id of the process the call refers to</param>
+        /// <param name="infos">Property values passed in</param>
+        /// <param name="blockOnLabel">Label used for blocking</param>
+        /// <param name="blockOnVersion">Version used for blocking</param>
+        /// <param name="maxBlockTime">Upper bound on blocking time (unit not stated here)</param>
+        /// <param name="getPropLabel">Label of the property requested</param>
+        /// <param name="ProcessStatistics">Flag; by its name, whether process statistics are requested</param>
+        /// <returns>A boolean result; its exact meaning is defined by the implementation</returns>
+        [OperationContract(Action = "http://hpc.microsoft.com/dryadvertex/setgetprops")]
+        [FaultContract(typeof(VertexServiceError), Action = VertexServiceError.Action)]
+        [FaultContract(typeof(UnknownProcessError), Action = UnknownProcessError.Action)]
+        bool SetGetProps(string replyUri, int processId, ProcessPropertyInfo[] infos, string blockOnLabel, ulong blockOnVersion, long maxBlockTime, string getPropLabel, bool ProcessStatistics);
+
+        /// <summary>
+        /// One-way operation taking a shutdown code.
+        /// </summary>
+        /// <param name="ShutdownCode">Code passed to the service with the shutdown request</param>
+        [OperationContract(IsOneWay=true, Action = "http://hpc.microsoft.com/dryadvertex/shutdown")]
+        void Shutdown(uint ShutdownCode);
+    }
+
+}
diff --git a/CommonCode/NativeMethods.cs b/CommonCode/NativeMethods.cs
new file mode 100644
index 0000000..71207be
--- /dev/null
+++ b/CommonCode/NativeMethods.cs
@@ -0,0 +1,453 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//------------------------------------------------------------------------------
+//
+// Wrapped native methods
+//
+//------------------------------------------------------------------------------
+namespace Microsoft.Research.Dryad
+{
+    using System;
+    using System.Collections.Generic;
+    using System.Text;
+    using System.Runtime.InteropServices;
+    using System.Runtime.ConstrainedExecution;
+    using System.Security;
+    using System.Security.Permissions;
+    using Microsoft.Win32.SafeHandles;
+
+    /// <summary>
+    /// SafeHandle wrapper for a native thread handle; zero and -1 are treated as invalid.
+    /// </summary>
+    [SecurityPermission(SecurityAction.InheritanceDemand, UnmanagedCode = true)]
+    [SecurityPermission(SecurityAction.Demand, UnmanagedCode = true)]
+    public sealed class SafeThreadHandle : SafeHandleZeroOrMinusOneIsInvalid
+    {
+        // Parameterless constructor for the interop marshaller; a handle created this way is owned.
+        private SafeThreadHandle() : base(true) { }
+
+        /// <summary>
+        /// Wraps an existing native handle.
+        /// NOTE(review): ownsHandle is false here, so per SafeHandle semantics this wrapper does not
+        /// close the handle itself - confirm that callers close it, otherwise the handle leaks.
+        /// </summary>
+        /// <param name="handle">Native thread handle to wrap</param>
+        public SafeThreadHandle(IntPtr handle)
+            : base(false)
+        {
+            this.SetHandle(handle);
+        }
+
+        /// <summary>
+        /// Closes the wrapped handle with CloseHandle.
+        /// </summary>
+        /// <returns>The result of CloseHandle</returns>
+        [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)]
+        protected override bool ReleaseHandle()
+        {
+            return NativeMethods.CloseHandle(this.handle);
+        }
+    }
+
+
+    /// <summary>
+    /// SafeHandle wrapper for a native process handle; zero and -1 are treated as invalid.
+    /// </summary>
+    [SecurityPermission(SecurityAction.InheritanceDemand, UnmanagedCode = true)]
+    [SecurityPermission(SecurityAction.Demand, UnmanagedCode = true)]
+    public sealed class SafeProcessHandle : SafeHandleZeroOrMinusOneIsInvalid
+    {
+        // Parameterless constructor for the interop marshaller; a handle created this way is owned.
+        private SafeProcessHandle() : base(true) { }
+
+        /// <summary>
+        /// Wraps an existing native handle.
+        /// NOTE(review): ownsHandle is false here, so per SafeHandle semantics this wrapper does not
+        /// close the handle itself - confirm that callers close it, otherwise the handle leaks.
+        /// </summary>
+        /// <param name="handle">Native process handle to wrap</param>
+        public SafeProcessHandle(IntPtr handle)
+            : base(false)
+        {
+            this.SetHandle(handle);
+        }
+
+        /// <summary>
+        /// Closes the wrapped handle with CloseHandle.
+        /// </summary>
+        /// <returns>The result of CloseHandle</returns>
+        [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)]
+        protected override bool ReleaseHandle()
+        {
+            return NativeMethods.CloseHandle(this.handle);
+        }
+    }
+
+    /// <summary>
+    /// SafeHandle wrapper for a logon/impersonation token; zero and -1 are treated as invalid.
+    /// LogonUser below returns its token through this type.
+    /// </summary>
+    [SecurityPermission(SecurityAction.InheritanceDemand, UnmanagedCode = true)]
+    [SecurityPermission(SecurityAction.Demand, UnmanagedCode = true)]
+    public sealed class SafeImpersonationToken : SafeHandleZeroOrMinusOneIsInvalid
+    {
+        // Parameterless constructor for the interop marshaller (used for the LogonUser out
+        // parameter); a handle created this way is owned and closed on release.
+        private SafeImpersonationToken() : base(true) { }
+
+        /// <summary>
+        /// Wraps an existing native token.
+        /// NOTE(review): ownsHandle is false here, so per SafeHandle semantics this wrapper does not
+        /// close the token itself - confirm that callers close it, otherwise the handle leaks.
+        /// </summary>
+        /// <param name="token">Native token handle to wrap</param>
+        public SafeImpersonationToken(IntPtr token)
+            : base(false)
+        {
+            this.SetHandle(token);
+        }
+
+        /// <summary>
+        /// Closes the wrapped token with CloseHandle.
+        /// </summary>
+        /// <returns>The result of CloseHandle</returns>
+        [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)]
+        protected override bool ReleaseHandle()
+        {
+            return NativeMethods.CloseHandle(this.handle);
+        }
+
+    }
+
+    /// <summary>
+    /// Wrapped native methods
+    /// </summary>
+    [SuppressUnmanagedCodeSecurity]
+    public static class NativeMethods
+    {
+        /// <summary>
+        /// Handle value -1, which SafeCloseValidHandle treats as "nothing to close"
+        /// </summary>
+        public static readonly IntPtr INVALID_HANDLE_VALUE = new IntPtr(-1);
+
+        // Create process (dwCreationFlags values for CreateProcess / CreateProcessAsUser)
+        public const uint CREATE_UNICODE_ENVIRONMENT = 0x00000400;
+        public const uint CREATE_SUSPENDED = 0x00000004;
+        public const uint CREATE_BREAKAWAY_FROM_JOB = 0x01000000;
+        public const uint CREATE_NO_WINDOW = 0x08000000;
+
+        // LogonUser (dwLogonType / dwLogonProvider values)
+        public const uint LOGON32_LOGON_INTERACTIVE = 0x00000002;
+        public const uint LOGON32_LOGON_NETWORK = 0x00000003;
+        public const uint LOGON32_LOGON_BATCH = 0x00000004;
+        public const uint LOGON32_LOGON_NETWORK_CLEARTEXT = 0x00000008;
+        public const uint LOGON32_PROVIDER_DEFAULT = 0x00000000;
+
+
+
+        /// <summary>
+        /// Error flag for "no error"
+        /// </summary>
+        public const int ERROR_OK = 0;
+
+        /// <summary>
+        /// Error flag for insufficient buffer
+        /// </summary>
+        public const int ERROR_INSUFFICIENT_BUFFER = 122;
+
+
+        // Job object
+        public const int JobObjectExtendedLimitInformationClass = 9;
+        public const uint JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE = 0x00002000;
+
+        /// <summary>
+        /// Managed layout of the Win32 PROCESS_INFORMATION structure filled in by CreateProcess.
+        /// The two handles are raw IntPtr values; nothing in this struct closes them.
+        /// </summary>
+        [StructLayout(LayoutKind.Sequential)]
+        public struct PROCESS_INFORMATION
+        {
+            public IntPtr hProcess;
+            public IntPtr hThread;
+            public int dwProcessId;
+            public int dwThreadId;
+        }
+
+        /// <summary>
+        /// Managed layout of the Win32 STARTUPINFO structure (Unicode strings). Field order and
+        /// types mirror the native structure and must not be rearranged.
+        /// </summary>
+        [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
+        public struct STARTUPINFO
+        {
+            public Int32 cb;
+            public string lpReserved;
+            public string lpDesktop;
+            public string lpTitle;
+            public Int32 dwX;
+            public Int32 dwY;
+            public Int32 dwXSize;
+            public Int32 dwYSize;
+            public Int32 dwXCountChars;
+            public Int32 dwYCountChars;
+            public Int32 dwFillAttribute;
+            public Int32 dwFlags;
+            public Int16 wShowWindow;
+            public Int16 cbReserved2;
+            public IntPtr lpReserved2;
+            public IntPtr hStdInput;
+            public IntPtr hStdOutput;
+            public IntPtr hStdError;
+        }
+
+        /// <summary>
+        /// Managed layout of the Win32 JOBOBJECT_EXTENDED_LIMIT_INFORMATION structure, flattened:
+        /// the basic limit fields, then the six I/O counters, then the four memory fields.
+        /// Field order and types mirror the native structure and must not be rearranged.
+        /// </summary>
+        [StructLayout(LayoutKind.Sequential)]
+        public struct JobObjectExtendedLimitInformation
+        {
+            public Int64 PerProcessUserTimeLimit;
+            public Int64 PerJobUserTimeLimit;
+            public UInt32 LimitFlags;
+            public UIntPtr MinimumWorkingSetSize;
+            public UIntPtr MaximumWorkingSetSize;
+            public UInt32 ActiveProcessLimit;
+            public IntPtr Affinity;
+            public UInt32 PriorityClass;
+            public UInt32 SchedulingClass;
+            public UInt64 ReadOperationCount;
+            public UInt64 WriteOperationCount;
+            public UInt64 OtherOperationCount;
+            public UInt64 ReadTransferCount;
+            public UInt64 WriteTransferCount;
+            public UInt64 OtherTransferCount;
+            public UIntPtr ProcessMemoryLimit;
+            public UIntPtr JobMemoryLimit;
+            public UIntPtr PeakProcessMemoryUsed;
+            public UIntPtr PeakJobMemoryUsed;
+        }
+
+        /// <summary>
+        /// Managed layout of the Win32 SECURITY_ATTRIBUTES structure:
+        /// DWORD nLength, LPVOID lpSecurityDescriptor, BOOL bInheritHandle.
+        /// The second field keeps its existing name (lpSecurityAttributes) so current callers
+        /// still compile; it corresponds to the native lpSecurityDescriptor member.
+        /// </summary>
+        [StructLayout(LayoutKind.Sequential)]
+        public struct SECURITY_ATTRIBUTES
+        {
+            public UInt32 nLength;
+            public UIntPtr lpSecurityAttributes;
+            // Third member of the native structure. Without it the marshalled size is too small
+            // and native code reading bInheritHandle would read past the end of the struct.
+            [MarshalAs(UnmanagedType.Bool)]
+            public bool bInheritHandle;
+        }
+
+        /// <summary>
+        /// Win32 CreateProcess (Unicode). Creates a new process and its primary thread.
+        /// </summary>
+        /// <returns>true on success; on failure the Win32 error is available through Marshal.GetLastWin32Error</returns>
+        [DllImport("Kernel32.dll", CallingConvention = CallingConvention.Winapi, SetLastError = true, CharSet = CharSet.Unicode)]
+        [return: MarshalAs(UnmanagedType.Bool)]
+        public static extern bool CreateProcess([MarshalAs(UnmanagedType.LPTStr)]string lpApplicationName,
+           StringBuilder lpCommandLine, IntPtr lpProcessAttributes,
+           IntPtr lpThreadAttributes, bool bInheritHandles,
+           uint dwCreationFlags, IntPtr lpEnvironment, string lpCurrentDirectory,
+           [In] ref STARTUPINFO lpStartupInfo,
+           out PROCESS_INFORMATION lpProcessInformation);
+
+        /// <summary>
+        /// Win32 CreateProcessAsUser (Unicode). Same as CreateProcess, but the new process runs
+        /// in the security context of the user represented by <paramref name="hToken"/>.
+        /// </summary>
+        /// <returns>true on success; on failure the Win32 error is available through Marshal.GetLastWin32Error</returns>
+        [DllImport("Advapi32.dll", CallingConvention = CallingConvention.Winapi, SetLastError = true, CharSet = CharSet.Unicode)]
+        [return: MarshalAs(UnmanagedType.Bool)]
+        public static extern bool CreateProcessAsUser(SafeImpersonationToken hToken, [MarshalAs(UnmanagedType.LPTStr)]string lpApplicationName,
+           StringBuilder lpCommandLine, IntPtr lpProcessAttributes,
+           IntPtr lpThreadAttributes, bool bInheritHandles,
+           uint dwCreationFlags, IntPtr lpEnvironment, string lpCurrentDirectory,
+           [In] ref STARTUPINFO lpStartupInfo,
+           out PROCESS_INFORMATION lpProcessInformation);
+
+        /// <summary>
+        /// Win32 LogonUser. Logs a user on and returns a token for that user.
+        /// NOTE(review): no CharSet is specified, so the strings are marshalled as ANSI by default
+        /// while the other string-taking imports here use Unicode - confirm this is intended.
+        /// </summary>
+        /// <returns>true on success; on failure the Win32 error is available through Marshal.GetLastWin32Error</returns>
+        [DllImport("Advapi32.dll", SetLastError = true)]
+        [return: MarshalAs(UnmanagedType.Bool)]
+        public extern static bool LogonUser(string lpszUserName, string lpszDomain, string lpszPassword,
+            uint dwLogonType, uint dwLogonProvider, out SafeImpersonationToken phToken
+            );
+
+        /// <summary>
+        /// Win32 GetExitCodeProcess. Retrieves the termination status of the given process.
+        /// </summary>
+        [DllImport("Kernel32.dll", CallingConvention = CallingConvention.Winapi, SetLastError = true)]
+        [return: MarshalAs(UnmanagedType.Bool)]
+        public static extern bool GetExitCodeProcess(SafeProcessHandle hProcess, out uint lpExitCode);
+
+        /// <summary>
+        /// Win32 TerminateProcess. Terminates the given process with the given exit code.
+        /// </summary>
+        [DllImport("Kernel32.dll", CallingConvention = CallingConvention.Winapi, SetLastError = true)]
+        [return: MarshalAs(UnmanagedType.Bool)]
+        public static extern bool TerminateProcess(SafeProcessHandle hProcess, int uExitCode);
+
+        /// <summary>
+        /// Win32 ResumeThread. Decrements the suspend count of the given thread.
+        /// </summary>
+        /// <returns>The previous suspend count, or (uint)-1 on failure (per the Win32 documentation)</returns>
+        [DllImport("Kernel32.dll", CallingConvention = CallingConvention.Winapi, SetLastError = true)]
+        public static extern uint ResumeThread(SafeThreadHandle hThread);
+
+        /// <summary>
+        /// Win32 CreateJobObject (Unicode). Creates or opens a job object.
+        /// </summary>
+        /// <returns>A raw job handle, or IntPtr.Zero on failure (per the Win32 documentation)</returns>
+        [DllImport("kernel32.dll", CallingConvention = CallingConvention.Winapi, SetLastError = true, CharSet = CharSet.Unicode)]
+        public static extern IntPtr CreateJobObject(IntPtr lpJobAttributes, string lpName);
+
+        // Both have the same value (9) as JobObjectExtendedLimitInformationClass above.
+        public const int JobObjectExtendedLimitInformationQuery = 9;
+        public const int JobObjectExtendedLimitInformationSet = 9;
+
+        public const int QueryJobObjectBasicProcessIdList = 3;
+        /// <summary>
+        /// The two leading counters of the Win32 JOBOBJECT_BASIC_PROCESS_ID_LIST structure
+        /// (the variable-length id array that follows them natively is not declared here).
+        /// </summary>
+        [StructLayout(LayoutKind.Sequential)]
+        public struct JobObjectBasicProcessIdListHeader
+        {
+            public UInt32 NumberOfAssignedProcesses;
+            public UInt32 NumberOfProcessIdsInList;
+        }
+
+        /// <summary>
+        /// Win32 QueryInformationJobObject, typed for the extended limit information class.
+        /// </summary>
+        [DllImport("kernel32.dll", CharSet = System.Runtime.InteropServices.CharSet.Auto, SetLastError = true)]
+        [return: MarshalAs(UnmanagedType.Bool)]
+        public extern static bool QueryInformationJobObject(
+            IntPtr hJob,
+            int query,
+            out JobObjectExtendedLimitInformation info,
+            int size,
+            out int returnedSize
+          );
+
+        /// <summary>
+        /// Win32 SetInformationJobObject, typed for the extended limit information class.
+        /// </summary>
+        [DllImport("kernel32.dll", CallingConvention = CallingConvention.Winapi, SetLastError = true)]
+        [return: MarshalAs(UnmanagedType.Bool)]
+        public extern static bool SetInformationJobObject(IntPtr hJob, int informationClass, [In] ref JobObjectExtendedLimitInformation info, int size);
+
+        /// <summary>
+        /// Win32 AssignProcessToJobObject. Associates the given process with the given job object.
+        /// </summary>
+        [DllImport("kernel32.dll", CallingConvention = CallingConvention.Winapi, SetLastError = true)]
+        [return: MarshalAs(UnmanagedType.Bool)]
+        public extern static bool AssignProcessToJobObject(IntPtr hJob, SafeProcessHandle hProcess);
+
+        /// <summary>
+        /// Win32 CloseHandle for a raw handle value.
+        /// </summary>
+        [DllImport("kernel32.dll", SetLastError = true)]
+        [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)]
+        public extern static bool CloseHandle(IntPtr handle);
+
+        /// <summary>
+        /// Win32 CloseHandle for a handle wrapped in a HandleRef.
+        /// </summary>
+        [DllImport("kernel32.dll", SetLastError = true)]
+        [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)]
+        public extern static bool CloseHandle(HandleRef handleRef);
+
+        /// <summary>
+        /// Closes the handle unless it is zero or INVALID_HANDLE_VALUE. Best effort: the result of
+        /// CloseHandle is ignored and any exception it raises is swallowed.
+        /// </summary>
+        /// <param name="handleRef">Handle to close</param>
+        [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)]
+        public static void SafeCloseValidHandle(HandleRef handleRef)
+        {
+            if (handleRef.Handle != IntPtr.Zero && handleRef.Handle != INVALID_HANDLE_VALUE)
+            {
+                try
+                {
+                    CloseHandle(handleRef);
+                }
+                catch
+                {
+                    // Swallow exception
+                }
+            }
+        }
+
+
+        /// <summary>
+        /// contains information about the current state of both physical and virtual memory, including extended memory
+        /// </summary>
+        [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Auto)]
+        public class MEMORYSTATUSEX
+        {
+            /// <summary>
+            /// Size of the structure, in bytes. You must set this member before calling GlobalMemoryStatusEx.
+            /// </summary>
+            public uint dwLength;
+
+            /// <summary>
+            /// Number between 0 and 100 that specifies the approximate percentage of physical memory that is in use (0 indicates no memory use and 100 indicates full memory use).
+            /// </summary>
+            public uint dwMemoryLoad;
+
+            /// <summary>
+            /// Total size of physical memory, in bytes.
+            /// </summary>
+            public ulong ullTotalPhys;
+
+            /// <summary>
+            /// Size of physical memory available, in bytes.
+            /// </summary>
+            public ulong ullAvailPhys;
+
+            /// <summary>
+            /// Size of the committed memory limit, in bytes. This is physical memory plus the size of the page file, minus a small overhead.
+            /// </summary>
+            public ulong ullTotalPageFile;
+
+
+
+            /// <summary>
+            /// Size of available memory to commit, in bytes. The limit is ullTotalPageFile.
+            /// </summary>
+            public ulong ullAvailPageFile;
+
+            /// <summary>
+            /// Total size of the user mode portion of the virtual address space of the calling process, in bytes.
+            /// </summary>
+            public ulong ullTotalVirtual;
+
+            /// <summary>
+            /// Size of unreserved and uncommitted memory in the user mode portion of the virtual address space of the calling process, in bytes.
+            /// </summary>
+            public ulong ullAvailVirtual;
+
+            /// <summary>
+            /// Size of unreserved and uncommitted memory in the extended portion of the virtual address space of the calling process, in bytes.
+            /// </summary>
+            public ulong ullAvailExtendedVirtual;
+
+            /// <summary>
+            /// Initializes a new instance of the <see cref="MEMORYSTATUSEX"/> class and sets
+            /// dwLength to the marshalled size of the structure, as GlobalMemoryStatusEx requires.
+            /// </summary>
+            public MEMORYSTATUSEX()
+            {
+                this.dwLength = (uint)Marshal.SizeOf(typeof(NativeMethods.MEMORYSTATUSEX));
+            }
+        }
+
+        /// <summary>
+        /// Retrieves information about the system's current usage of both physical and virtual memory.
+        /// </summary>
+        /// <param name="lpBuffer">A pointer to a MEMORYSTATUSEX structure that receives information about current memory availability</param>
+        /// <returns>true if the function succeeds; otherwise false, with the Win32 error available through Marshal.GetLastWin32Error</returns>
+        [return: MarshalAs(UnmanagedType.Bool)]
+        [DllImport("kernel32.dll", CharSet = CharSet.Auto, SetLastError = true)]
+        public static extern bool GlobalMemoryStatusEx([In, Out] MEMORYSTATUSEX lpBuffer);
+
+        /// <summary>
+        /// Retrieves information about the amount of space that is available on a disk volume, which is the total amount of space,
+        /// the total amount of free space, and the total amount of free space available to the user that is associated with the calling thread.
+        /// </summary>
+        /// <param name="lpDirectoryName">A directory on the disk.</param>
+        /// <param name="lpFreeBytesAvailable">A pointer to a variable that receives the total number of free bytes on a disk that are available to the user who is associated with the calling thread.</param>
+        /// <param name="lpTotalNumberOfBytes">A pointer to a variable that receives the total number of bytes on a disk that are available to the user who is associated with the calling thread.</param>
+        /// <param name="lpTotalNumberOfFreeBytes">A pointer to a variable that receives the total number of free bytes on a disk.</param>
+        /// <returns>true if the function succeeds; otherwise false, with the Win32 error available through Marshal.GetLastWin32Error</returns>
+        [DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Auto)]
+        public static extern bool GetDiskFreeSpaceEx(string lpDirectoryName,
+            out ulong lpFreeBytesAvailable,
+            out ulong lpTotalNumberOfBytes,
+            out ulong lpTotalNumberOfFreeBytes);
+
+        /// <summary>
+        /// SID Usage Enum. Values start at 1 and increase by one in declaration order, so the
+        /// order of the members must not be changed.
+        /// </summary>
+        public enum SID_NAME_USE
+        {
+            SidTypeUser = 1,
+            SidTypeGroup,
+            SidTypeDomain,
+            SidTypeAlias,
+            SidTypeWellKnownGroup,
+            SidTypeDeletedAccount,
+            SidTypeInvalid,
+            SidTypeUnknown,
+            SidTypeComputer
+        }
+
+        /// <summary>
+        /// Get SID for account name
+        /// </summary>
+        /// <param name="lpSystemName">Computer name</param>
+        /// <param name="lpAccountName">Account name</param>
+        /// <param name="Sid">Buffer that receives the security ID</param>
+        /// <param name="cbSid">On input, the size of the Sid buffer in bytes; on return, the number of bytes needed to hold the SID</param>
+        /// <param name="ReferencedDomainName">Buffer that receives the domain name referenced by the SID</param>
+        /// <param name="cchReferencedDomainName">On input, the size of the domain buffer; on return, the size needed to hold the domain name (counted in characters per the Win32 documentation)</param>
+        /// <param name="peUse">Account type</param>
+        /// <returns>true on success; otherwise false, with the Win32 error available through Marshal.GetLastWin32Error</returns>
+        [DllImport("advapi32.dll", CharSet = CharSet.Auto, SetLastError = true)]
+        public static extern bool LookupAccountName(
+            string lpSystemName,
+            string lpAccountName,
+            [MarshalAs(UnmanagedType.LPArray)] byte[] Sid,
+            ref uint cbSid,
+            StringBuilder ReferencedDomainName,
+            ref uint cchReferencedDomainName,
+            out SID_NAME_USE peUse);
+
+        /// <summary>
+        /// Retrieves the name of the account for this SID and the name of the first domain on which this SID is found
+        /// </summary>
+        /// <param name="lpSystemName">string that specifies the target computer</param>
+        /// <param name="Sid">the SID to look up</param>
+        /// <param name="lpName">buffer that receives the account name that corresponds to the Sid parameter</param>
+        /// <param name="cchName">On input, specifies the size of the lpName buffer. If the function fails because
+        /// the buffer is too small or if cchName is zero, cchName receives the required buffer size</param>
+        /// <param name="ReferencedDomainName">buffer that receives the name of the domain where the account name was found.</param>
+        /// <param name="cchReferencedDomainName">Same as cchName, but for the domain string buffer</param>
+        /// <param name="peUse">pointer to a variable that receives a SID_NAME_USE value that indicates the type of the account</param>
+        /// <returns>If the function succeeds, the function returns nonzero. If the function fails, it returns zero</returns>
+        [DllImport("advapi32.dll", CharSet = CharSet.Auto, SetLastError = true)]
+        public static extern bool LookupAccountSid(
+            string lpSystemName,
+            [MarshalAs(UnmanagedType.LPArray)] byte[] Sid,
+            StringBuilder lpName,
+            ref uint cchName,
+            StringBuilder ReferencedDomainName,
+            ref uint cchReferencedDomainName,
+            out SID_NAME_USE peUse);
+
+        /// <summary>
+        /// Converts a security ID to its string form
+        /// </summary>
+        /// <param name="pSID">SID to convert, as a byte array</param>
+        /// <param name="ptrSid">Receives a native pointer to the string value; per the Win32 documentation
+        /// the buffer must be released with LocalFree</param>
+        /// <returns>true on success; otherwise false, with the Win32 error available through Marshal.GetLastWin32Error</returns>
+        [DllImport("advapi32", CharSet = CharSet.Auto, SetLastError = true)]
+        public static extern bool ConvertSidToStringSid(
+            [MarshalAs(UnmanagedType.LPArray)] byte[] pSID,
+            out IntPtr ptrSid);
+
+        /// <summary>
+        /// Frees a pointer
+        /// </summary>
+        /// <param name="hMem">pointer to free</param>
+        /// <returns>IntPtr.Zero on success; otherwise the handle that was passed in (per the Win32 LocalFree documentation)</returns>
+        [DllImport("kernel32.dll")]
+        public static extern IntPtr LocalFree(IntPtr hMem);
+    }
+}
diff --git a/CommonCode/NetShareWrapper.cs b/CommonCode/NetShareWrapper.cs
new file mode 100644
index 0000000..3043b64
--- /dev/null
+++ b/CommonCode/NetShareWrapper.cs
@@ -0,0 +1,519 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//------------------------------------------------------------------------------
+//
+// Provide access to native APIs in Netapi32.dll
+//
+//------------------------------------------------------------------------------
+
+namespace Microsoft.Hpc.Dsc.Internal
+{
+    using System;
+    using System.Collections;
+    using System.Collections.Generic;
+    using System.Globalization;
+    using System.Runtime.InteropServices;
+    using System.Text;
+
+
+    /// <summary>
+    /// Return codes from the Netapi32.dll pinvoke calls. The numeric values are the native
+    /// error codes, so the raw Int32 results of the P/Invoke calls can be cast to this enum.
+    /// </summary>
+    internal enum NetShareError
+    {
+        /// <summary>
+        /// Success
+        /// </summary>
+        NERR_Success = 0,
+
+        /// <summary>
+        /// The user does not have access to the requested information.
+        /// </summary>
+        ERROR_ACCESS_DENIED = 5,
+
+        /// <summary>
+        /// Not enough storage is available to process this command
+        /// </summary>
+        ERROR_NOT_ENOUGH_MEMORY = 8,
+
+        /// <summary>
+        /// The network path was not found
+        /// </summary>
+        ERROR_BAD_NETPATH = 53,
+
+        /// <summary>
+        /// The specified parameter is not valid.
+        /// </summary>
+        ERROR_INVALID_PARAMETER = 87,
+
+        /// <summary>
+        /// The value specified for the level parameter is not valid.
+        /// </summary>
+        ERROR_INVALID_LEVEL = 124,
+
+        /// <summary>
+        /// The provided buffer was too small to hold the entire value
+        /// </summary>
+        ERROR_MORE_DATA = 234,
+
+        /// <summary>
+        /// The filename, directory name, or volume label syntax is incorrect
+        /// </summary>
+        ERROR_INVALID_NAME = 123,
+
+        /// <summary>
+        /// The device or directory does not exist
+        /// </summary>
+        NERR_UnknownDevDir = 2116,
+
+        /// <summary>
+        /// The share name is already in use on this server
+        /// </summary>
+        NERR_DuplicateShare = 2118,
+
+        /// <summary>
+        /// The client request succeeded. More entries are available. The buffer size that is specified by PreferedMaximumLength was too small to fit even a single entry
+        /// </summary>
+        NERR_BufTooSmall = 2123,
+
+        /// <summary>
+        /// The share name does not exist
+        /// </summary>
+        NERR_NetNameNotFound = 2310,
+
+        /// <summary>
+        /// The operation is not valid for a redirected resource. The specified device name is assigned to a shared resource
+        /// </summary>
+        NERR_RedirectedPath = 2117
+    }
+
+    //
+    // Provide access to native APIs in Netapi32.dll
+    //
+    internal static class NetShareWrapper
+    {
+        /// <summary>
+        /// Information level passed to the Netapi32 share calls. Level 2:
+        /// specifies information about the shared resource, including the name of the resource,
+        /// type and permissions, and number of connections.
+        /// The buf parameter points to a SHARE_INFO_2 structure.
+        /// </summary>
+        private static UInt32 _InfoLevel = 2;
+
+        /// <summary>
+        /// Share type value for a disk drive
+        /// </summary>
+        private static UInt32 _STYPE_DISKTREE = 0;
+
+        /// <summary>
+        /// Permission to read a resource and, by default, execute the resource.
+        /// </summary>
+        private static UInt32 _PERM_FILE_READ = 1;
+
+        /// <summary>
+        /// Permission to write to a resource.
+        /// </summary>
+        private static UInt32 _PERM_FILE_WRITE = 2;
+
+        ///
+        /// Permission to create a resource; data can be written when creating the resource.
+        ///
+        private static UInt32 _PERM_FILE_CREATE = 4;
+
+        /// <summary>
+        /// Contains information about the shared resource, including name of the resource, type and permissions, and the number of current connections.
+        /// Managed layout of the native SHARE_INFO_2 structure: field order and marshalling
+        /// attributes mirror the native definition and must not be rearranged.
+        /// </summary>
+        [StructLayout(LayoutKind.Sequential)]
+        private struct _SHARE_INFO_2
+        {
+            /// <summary>
+            /// Pointer to a Unicode string specifying the share name of a resource
+            /// </summary>
+            [MarshalAs(UnmanagedType.LPWStr)]
+            public string shi2_netname;
+
+            /// <summary>
+            /// A bitmask of flags that specify the type of the shared resource
+            /// </summary>
+            [MarshalAs(UnmanagedType.U4)]
+            public UInt32 shi2_type;
+
+            /// <summary>
+            /// Pointer to a Unicode string that contains an optional comment about the shared resource
+            /// </summary>
+            [MarshalAs(UnmanagedType.LPWStr)]
+            public string shi2_remark;
+
+            /// <summary>
+            /// Specifies a DWORD value that indicates the shared resource's permissions for servers running with share-level security
+            /// </summary>
+            [MarshalAs(UnmanagedType.U4)]
+            public UInt32 shi2_permissions;
+
+            /// <summary>
+            /// Specifies a DWORD value that indicates the maximum number of concurrent connections that the shared resource can accommodate.
+            /// The number of connections is unlimited if the value specified in this member is -1
+            /// (UInt32.MaxValue for this unsigned field).
+            /// </summary>
+            [MarshalAs(UnmanagedType.U4)]
+            public UInt32 shi2_max_uses;
+
+            /// <summary>
+            /// Specifies a DWORD value that indicates the number of current connections to the resource
+            /// </summary>
+            [MarshalAs(UnmanagedType.U4)]
+            public UInt32 shi2_current_uses;
+
+            /// <summary>
+            /// Pointer to a Unicode string specifying the local path for the shared resource
+            /// </summary>
+            [MarshalAs(UnmanagedType.LPWStr)]
+            public string shi2_path;
+
+            /// <summary>
+            /// Pointer to a Unicode string that specifies the share's password when the server is running with share-level security
+            /// </summary>
+            [MarshalAs(UnmanagedType.LPWStr)]
+            public string shi2_passwd;
+        }
+
+        /// <summary>
+        /// Shares a server resource.
+        /// </summary>
+        /// Pointer to a string that specifies the DNS or NetBIOS name of the remote server on which the function is to execute.
If this parameter is NULL, the local computer is used.
+        /// <param name="InfoLevel">Specifies the information level of the data</param>
+        /// <param name="ShareInfo">Pointer to the buffer that specifies the data</param>
+        /// <param name="OutputBuffer">Pointer to a value that receives the index of the first member of the share information structure that causes the ERROR_INVALID_PARAMETER error. If this parameter is NULL, the index is not returned on error</param>
+        /// <returns>NERR_Success (0) if the function succeeds; otherwise a native error code (see NetShareError)</returns>
+        // NOTE(review): natively this last parameter is a DWORD index (parm_err), but it is declared
+        // here as ref IntPtr - confirm the marshalling is intended before relying on its value.
+        [DllImport("Netapi32.dll")]
+        private static extern Int32 NetShareAdd(
+            [MarshalAs(UnmanagedType.LPWStr)] string ServerName,
+            [MarshalAs(UnmanagedType.U4)] UInt32 InfoLevel,
+            [MarshalAs(UnmanagedType.Struct)] ref _SHARE_INFO_2 ShareInfo,
+            ref IntPtr OutputBuffer);
+
+        /// <summary>
+        /// Deletes a share name from a server's list of shared resources, disconnecting all connections to the shared resource.
+        /// </summary>
+        /// <param name="ServerName">Pointer to a string that specifies the DNS or NetBIOS name of the remote server on which the function is to execute. If this parameter is NULL, the local computer is used</param>
+        /// <param name="ShareName">Pointer to a string that specifies the name of the share to delete.</param>
+        /// <param name="ParameterReserved">Reserved, must be zero.</param>
+        /// <returns>If the function succeeds, the return value is NERR_Success</returns>
+        [DllImport("Netapi32.dll")]
+        private static extern Int32 NetShareDel(
+            [MarshalAs(UnmanagedType.LPWStr)] string ServerName,
+            [MarshalAs(UnmanagedType.LPWStr)] string ShareName,
+            UInt32 ParameterReserved);
+
+        /// <summary>
+        /// Retrieves information about a particular shared resource on a server.
+        /// </summary>
+        /// <param name="ServerName">Pointer to a string that specifies the DNS or NetBIOS name of the remote server on which the function is to execute. If this parameter is NULL, the local computer is used</param>
+        /// <param name="ShareName">Pointer to a string that specifies the name of the share for which to return information</param>
+        /// <param name="InfoLevel">Specifies the information level of the data</param>
+        /// <param name="OutputBuffer">Pointer to the buffer that receives the data</param>
+        /// <returns>If the function succeeds, the return value is NERR_Success.</returns>
+        [DllImport("Netapi32.dll")]
+        private static extern Int32 NetShareGetInfo(
+            [MarshalAs(UnmanagedType.LPWStr)] string ServerName,
+            [MarshalAs(UnmanagedType.LPWStr)] string ShareName,
+            [MarshalAs(UnmanagedType.U4)] UInt32 InfoLevel,
+            ref IntPtr OutputBuffer);
+
+        /// <summary>
+        /// The NetApiBufferFree function frees the memory that the NetApiBufferAllocate function allocates
+        /// </summary>
+        /// <param name="InputBuffer">A pointer to a buffer returned previously by another network management function or memory allocated by calling the NetApiBufferAllocate function</param>
+        /// <returns>Status code of the native call</returns>
+        [DllImport("Netapi32", CharSet = CharSet.Auto)]
+        private static extern Int32 NetApiBufferFree(IntPtr InputBuffer);
+
+        /// <summary>
+        /// Prints out the error based on the error code
+        /// </summary>
+        /// <param name="ErrorCode">Error code returned by pinvoke call</param>
+        private static void PrintError(Int32 ErrorCode)
+        {
+            string message;
+
+            switch ((NetShareError)ErrorCode)
+            {
+                case NetShareError.ERROR_ACCESS_DENIED:
+                    message = "Access to share denied.";
+                    break;
+
+                case NetShareError.ERROR_NOT_ENOUGH_MEMORY:
+                    message = "Not enough memory available.";
+                    break;
+
+                case NetShareError.ERROR_INVALID_PARAMETER:
+                    message = "Invalid parameter specified.";
+                    break;
+
+                case NetShareError.ERROR_INVALID_LEVEL:
+                    message = "Invalid level specified.";
+                    break;
+
+                case NetShareError.ERROR_MORE_DATA:
+                    message = "More data available and not large enough buffer specified.";
+                    break;
+
+                case NetShareError.ERROR_INVALID_NAME:
+                    message = "The filename, directory name, or volume label syntax is incorrect";
+                    break;
+
+                case NetShareError.NERR_UnknownDevDir:
+                    message = "Unknown device specified.";
+                    break;
+
+                case NetShareError.NERR_DuplicateShare:
+                    message = "Duplicate share specified.";
+                    break;
+
+                case NetShareError.NERR_BufTooSmall:
+                    message = "Not large enough buffer specified.";
+                    break;
+
+                case NetShareError.NERR_NetNameNotFound:
+                    message = "Share name not found.";
+                    break;
+
+                case NetShareError.NERR_RedirectedPath:
+                    message = "The operation is not valid for a redirected resource. The specified device name is assigned to a shared resource.";
+                    break;
+
+                case NetShareError.ERROR_BAD_NETPATH:
+                    message = "The network path was not found.";
+                    break;
+
+                default:
+                    message = String.Format(CultureInfo.CurrentCulture, "Unknown error occurred (Error Code {0}).", ErrorCode);
+                    break;
+            }
+
+            Console.Error.WriteLine(message);
+        }
+
+        /// <summary>
+        /// Ensures a server name starts with the UNC "\\" prefix.
+        /// A null or empty name is returned unchanged so it can be handed to the native
+        /// call as-is (this is the guard GetLocalPath already applied; the other entry
+        /// points used to throw NullReferenceException instead).
+        /// </summary>
+        /// <param name="ServerName">Server name with zero, one or two leading backslashes</param>
+        /// <returns>Server name with exactly the prefix the original code produced</returns>
+        private static string NormalizeServerName(string ServerName)
+        {
+            if (String.IsNullOrEmpty(ServerName) || ServerName.StartsWith(@"\\", StringComparison.Ordinal))
+            {
+                return ServerName;
+            }
+
+            // One leading backslash only needs one more; none needs both.
+            return ServerName.StartsWith(@"\", StringComparison.Ordinal)
+                ? @"\" + ServerName
+                : @"\\" + ServerName;
+        }
+
+        /// <summary>
+        /// Create a net share
+        /// </summary>
+        /// <param name="ServerName">Server hosting share</param>
+        /// <param name="ShareName">name of share</param>
+        /// <param name="SharePath">path to share</param>
+        /// <param name="ShareDescription">description of share</param>
+        /// <returns>Error code of NetShareAdd (NERR_Success on success)</returns>
+        internal static int CreateShare(string ServerName, string ShareName, string SharePath, string ShareDescription)
+        {
+            //
+            // Ensure UNC path formatting at start
+            //
+            ServerName = NormalizeServerName(ServerName);
+
+            //
+            // Build share information
+            //
+            _SHARE_INFO_2 ShareInfo = new _SHARE_INFO_2();
+            ShareInfo.shi2_netname = ShareName;
+            ShareInfo.shi2_type = _STYPE_DISKTREE;
+            ShareInfo.shi2_remark = ShareDescription;
+            ShareInfo.shi2_permissions = _PERM_FILE_READ | _PERM_FILE_WRITE | _PERM_FILE_CREATE;
+            ShareInfo.shi2_max_uses = UInt32.MaxValue;
+            ShareInfo.shi2_current_uses = 0;
+            ShareInfo.shi2_path = SharePath;
+            ShareInfo.shi2_passwd = String.Empty;
+
+            IntPtr OutputBuffer = IntPtr.Zero;
+
+            //
+            // Create the share and report success or failure
+            //
+            Int32 ErrorCode = NetShareAdd(ServerName, _InfoLevel, ref ShareInfo, ref OutputBuffer);
+
+            if (ErrorCode != (Int32)NetShareError.NERR_Success)
+            {
+                PrintError(ErrorCode);
+            }
+
+            return ErrorCode;
+        }
+
+        /// <summary>
+        /// Deletes an existing share
+        /// </summary>
+        /// <param name="ServerName">Server hosting share</param>
+        /// <param name="ShareName">name of share</param>
+        /// <returns>true on success; false (after printing the error) on failure</returns>
+        internal static bool DeleteShare(string ServerName, string ShareName)
+        {
+            //
+            // Ensure UNC path formatting at start
+            //
+            ServerName = NormalizeServerName(ServerName);
+
+            //
+            // Attempt to delete the share and report success/failure
+            //
+            Int32 ErrorCode = NetShareDel(ServerName, ShareName, 0);
+            if (ErrorCode == (Int32)NetShareError.NERR_Success)
+            {
+                return true;
+            }
+
+            PrintError(ErrorCode);
+            return false;
+        }
+
+        /// <summary>
+        /// Get the drive where a share is hosted
+        /// </summary>
+        /// <param name="ServerName">Server hosting share</param>
+        /// <param name="ShareName">Name of share</param>
+        /// <returns>drive letter including the colon (empty string if unsuccessful)</returns>
+        internal static string GetLocalDrive(string ServerName, string ShareName)
+        {
+            //
+            // Attempt to get the share information. Report error if unsuccessful
+            //
+            string SharePath;
+            Int32 ErrorCode = GetLocalPath(ServerName, ShareName, out SharePath);
+            if (ErrorCode != (Int32)NetShareError.NERR_Success)
+            {
+                PrintError(ErrorCode);
+                return String.Empty;
+            }
+
+            if (!String.IsNullOrEmpty(SharePath))
+            {
+                //
+                // If a share path was returned, attempt to parse out the drive letter
+                //
+                int Index = SharePath.IndexOf(':');
+                if (Index > 0)
+                {
+                    return SharePath.Substring(0, Index + 1);
+                }
+            }
+
+            return String.Empty;
+        }
+
+        // TODO: Sync Dryad GM and vertex with new version that returns error code
+        /// <summary>
+        /// Returns the local path to a share, logging to stderr on failure
+        /// </summary>
+        /// <param name="ServerName">Server the path resides on</param>
+        /// <param name="ShareName">Share to find path to</param>
+        /// <returns>Local path to share or empty string if failure</returns>
+        internal static string GetLocalPath(string ServerName, string ShareName)
+        {
+            string SharePath;
+
+            int err = GetLocalPath(ServerName, ShareName, out SharePath);
+            if (err != (int)NetShareError.NERR_Success)
+            {
+                Console.Error.WriteLine("GetLocalPath failed: server {0}, share {1}", ServerName, ShareName);
+                PrintError(err);
+            }
+
+            return SharePath;
+        }
+
+        /// <summary>
+        /// Returns the local path to a share on the specified server
+        /// </summary>
+        /// <param name="ServerName">Server the path resides on</param>
+        /// <param name="ShareName">Share to find path to</param>
+        /// <param name="SharePath">Local path to share or empty string if failure</param>
+        /// <returns>Error code of NetShareGetInfo (NERR_Success on success)</returns>
+        internal static int GetLocalPath(string ServerName, string ShareName, out string SharePath)
+        {
+            //
+            // Ensure UNC path formatting at start
+            //
+            ServerName = NormalizeServerName(ServerName);
+
+            IntPtr OutputBuffer = IntPtr.Zero;
+            SharePath = String.Empty;
+
+            //
+            // Get share info structure
+            //
+            Int32 ErrorCode = NetShareGetInfo(ServerName, ShareName, _InfoLevel, ref OutputBuffer);
+
+            try
+            {
+                if (ErrorCode == (Int32)NetShareError.NERR_Success)
+                {
+                    //
+                    // If successful, get the local path to the resource
+                    //
+                    _SHARE_INFO_2 ShareInfo = (_SHARE_INFO_2)Marshal.PtrToStructure(OutputBuffer, typeof(_SHARE_INFO_2));
+                    SharePath = ShareInfo.shi2_path;
+                }
+            }
+            finally
+            {
+                //
+                // Free the native buffer even if marshalling throws
+                //
+                if (OutputBuffer != IntPtr.Zero)
+                {
+                    NetApiBufferFree(OutputBuffer);
+                }
+            }
+
+            return ErrorCode;
+        }
+    }
+}
\ No newline at end of file
diff --git a/CommonCode/ProcessPathHelper.cs b/CommonCode/ProcessPathHelper.cs
new file mode 100644
index 0000000..7d13e7b
--- /dev/null
+++ b/CommonCode/ProcessPathHelper.cs
@@ -0,0 +1,172 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//------------------------------------------------------------------------------
+//
+// Helper class that builds working directories for dryad jobs
+//
+//------------------------------------------------------------------------------
+
+namespace Microsoft.Research.Dryad
+{
+    using System;
+    using System.IO;
+    using System.Security.Principal;
+    using System.Security.AccessControl;
+    using Microsoft.Hpc.Dsc.Internal;
+
+    /// <summary>
+    /// Helper class that builds working directories for dryad jobs
+    /// </summary>
+    public static class ProcessPathHelper
+    {
+        /// <summary>
+        /// Current job directory path
+        /// </summary>
+        private static string jobPath = string.Empty;
+
+        /// <summary>
+        /// Lockable object used to synchronize job path creation
+        /// </summary>
+        private static object jobPathLock = new object();
+
+        /// <summary>
+        /// Build the current job directory path
+        /// </summary>
+        private static void BuildJobPath()
+        {
+            //
+            // Get the path to the HpcTemp share on the local machine and append the current user name
+            // and job id to build the job path.
+            // If unable to get local path to HpcTemp share, or necessary environment variables, leave jobPath as string.empty.
+            //
+            string rootDir = RootWorkingDirectory;
+
+            if (string.IsNullOrEmpty(rootDir))
+            {
+                Console.Error.WriteLine("ProcessPathHelper.BuildJobPath: Unable to get local path to {0} share.", Constants.DscTempShare);
+                return;
+            }
+
+            try
+            {
+                string userName = Environment.GetEnvironmentVariable("USERNAME");
+                string jobId = Environment.GetEnvironmentVariable("CCP_JOBID");
+
+                //
+                // Both variables are required: a missing job id used to surface as an
+                // ArgumentNullException from Path.Combine, and an empty one silently
+                // produced a path without the job component.
+                //
+                if (string.IsNullOrEmpty(userName) || string.IsNullOrEmpty(jobId))
+                {
+                    Console.Error.WriteLine("ProcessPathHelper.BuildJobPath: USERNAME or CCP_JOBID environment variable is not set.");
+                    return;
+                }
+
+                //
+                // If all goes well, assign the path to jobPath, so it can be accessed.
+                //
+                ProcessPathHelper.jobPath = Path.Combine(Path.Combine(rootDir, userName), jobId);
+            }
+            catch (Exception ex)
+            {
+                Console.Error.WriteLine("ProcessPathHelper.BuildJobPath: Unable to get environment variable: {0}", ex.Message);
+            }
+        }
+
+        /// <summary>
+        /// Returns the path to the specified vertex dir
+        /// </summary>
+        /// <param name="id">Vertex id</param>
+        /// <returns>path to vertex directory</returns>
+        public static string ProcessPath(int id)
+        {
+            return Path.Combine(JobPath, id.ToString());
+        }
+
+        /// <summary>
+        /// Returns path to the specified vertex's working dir
+        /// </summary>
+        /// <param name="id">Vertex ID</param>
+        /// <returns>path to the vertex working directory</returns>
+        public static string ProcessWorkingDirectory(int id)
+        {
+            return Path.Combine(ProcessPath(id), "WD");
+        }
+
+        /// <summary>
+        /// Gets path to directory for the current job.
+        /// If job path cannot be retrieved, string.empty may be returned.
+        /// </summary>
+        public static string JobPath
+        {
+            get
+            {
+                //
+                // If job path hasn't been built, build it.
+                // Lock to allow multiple vertices to access this property without issue.
+                //
+                if (string.IsNullOrEmpty(ProcessPathHelper.jobPath))
+                {
+                    lock (jobPathLock)
+                    {
+                        if (string.IsNullOrEmpty(ProcessPathHelper.jobPath))
+                        {
+                            ProcessPathHelper.BuildJobPath();
+                        }
+                    }
+                }
+
+                return ProcessPathHelper.jobPath;
+            }
+        }
+
+        /// <summary>
+        /// Get the root working directory
+        /// </summary>
+        public static string RootWorkingDirectory
+        {
+            get
+            {
+                return (NetShareWrapper.GetLocalPath("localhost", Constants.DscTempShare));
+            }
+        }
+
+
+        /// <summary>
+        /// Creates the user specific working directory if it does not exist, granting
+        /// full control to the current user and the built-in administrators group only.
+        /// </summary>
+        public static void CreateUserWorkingDirectory()
+        {
+            string rootDir = ProcessPathHelper.RootWorkingDirectory;
+            if (string.IsNullOrEmpty(rootDir))
+            {
+                //
+                // Without a root the path would degrade to "\<user>" on the current drive.
+                //
+                Console.Error.WriteLine("ProcessPathHelper.CreateUserWorkingDirectory: Unable to get local path to {0} share.", Constants.DscTempShare);
+                return;
+            }
+
+            string userWorkingDirectory = Path.Combine(rootDir, Environment.UserName);
+
+            if (!Directory.Exists(userWorkingDirectory))
+            {
+                DirectorySecurity directorySecurity = new DirectorySecurity();
+                InheritanceFlags inheritAll = InheritanceFlags.ContainerInherit | InheritanceFlags.ObjectInherit;
+
+                //Add full control to user
+
+                directorySecurity.AddAccessRule(new FileSystemAccessRule(WindowsIdentity.GetCurrent().Name, FileSystemRights.FullControl, inheritAll, PropagationFlags.None, AccessControlType.Allow));
+
+                //Add full control to administrators
+
+                SecurityIdentifier administratorGroup = new SecurityIdentifier(WellKnownSidType.BuiltinAdministratorsSid, null);
+
+                directorySecurity.AddAccessRule(new FileSystemAccessRule(administratorGroup, FileSystemRights.FullControl, inheritAll, PropagationFlags.None, AccessControlType.Allow));
+
+                Directory.CreateDirectory(userWorkingDirectory, directorySecurity);
+            }
+
+        }
+    }
+}
\ No newline at end of file
diff --git a/CommonCode/ProcessState.cs b/CommonCode/ProcessState.cs
new file mode 100644
index 0000000..e87e492
--- /dev/null
+++ b/CommonCode/ProcessState.cs
@@ -0,0 +1,38 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+namespace Microsoft.Research.Dryad
+{
+    /// <summary>
+    /// Lifecycle state of a process. The high nibble identifies the phase
+    /// (0x0 initial, 0x1 running, 0x2 terminal); the low nibble the state within it.
+    /// </summary>
+    // NOTE(review): the values are not independent bits (e.g. Running == AssignedToNode | Unscheduled),
+    // so combining them with bitwise operators is presumably not intended despite [Flags] - confirm.
+    [System.Flags]
+    public enum ProcessState
+    {
+        // Initial phase
+        Uninitialized = 0x00,
+        Unscheduled = 0x01,
+
+        // "Running" phase
+        AssignedToNode = 0x10,
+        Running = 0x11,
+
+        // Terminal phase
+        SchedulingFailed = 0x20,
+        Completed = 0x21
+    }
+}
diff --git a/CommonCode/QueryUtility.cs b/CommonCode/QueryUtility.cs
new file mode 100644
index 0000000..29c2a40
--- /dev/null
+++ b/CommonCode/QueryUtility.cs
@@ -0,0 +1,85 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+using System;
+
+namespace Microsoft.Research.Dryad
+{
+    class HpcQueryUtility
+    {
+        /// <summary>
+        /// Converts numeric trace level to string representation
+        /// </summary>
+        /// <param name="runtimeTraceLevel">One of the Constants.trace*Num values</param>
+        /// <returns>Matching Constants.trace* name; Constants.traceError for any unrecognized value</returns>
+        internal static string ConvertTraceLevelToString(int runtimeTraceLevel)
+        {
+            switch (runtimeTraceLevel)
+            {
+                case Constants.traceOffNum:
+                    return Constants.traceOff;
+
+                case Constants.traceCriticalNum:
+                    return Constants.traceCritical;
+
+                case Constants.traceWarningNum:
+                    return Constants.traceWarning;
+
+                case Constants.traceInfoNum:
+                    return Constants.traceInfo;
+
+                case Constants.traceVerboseNum:
+                    return Constants.traceVerbose;
+
+                // Error is both an explicit level and the fallback.
+                case Constants.traceErrorNum:
+                default:
+                    return Constants.traceError;
+            }
+        }
+
+        /// <summary>
+        /// Converts a trace level name to its numeric value
+        /// </summary>
+        /// <param name="runtimeTraceLevel">One of the Constants.trace* names</param>
+        /// <returns>Matching Constants.trace*Num value; Constants.traceErrorNum for any unrecognized name</returns>
+        internal static int GetTraceLevelFromString(string runtimeTraceLevel)
+        {
+            switch (runtimeTraceLevel)
+            {
+                case Constants.traceOff:
+                    return Constants.traceOffNum;
+
+                case Constants.traceCritical:
+                    return Constants.traceCriticalNum;
+
+                case Constants.traceWarning:
+                    return Constants.traceWarningNum;
+
+                case Constants.traceInfo:
+                    return Constants.traceInfoNum;
+
+                case Constants.traceVerbose:
+                    return Constants.traceVerboseNum;
+
+                // Error is both an explicit level and the fallback.
+                case Constants.traceError:
+                default:
+                    return Constants.traceErrorNum;
+            }
+        }
+    }
+}
diff --git a/CommonCode/RetryFramework.cs b/CommonCode/RetryFramework.cs
new file mode 100644
index 0000000..76b27a2
--- /dev/null
+++ b/CommonCode/RetryFramework.cs
@@ -0,0 +1,358 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//------------------------------------------------------------------------------
+// Security review: nzeng 01-11-06
+//------------------------------------------------------------------------------
+
+#region Using directives
+
+using System;
+using System.Text;
+using System.Diagnostics;
+using System.Threading;
+
+#endregion
+
+namespace Microsoft.Hpc
+{
+    /// <summary>
+    /// Tracks retry attempts against a maximum retry count and a total wait-time budget,
+    /// delegating the choice of each individual wait to a RetryWaitTimer.
+    /// All times are in milliseconds.
+    /// </summary>
+    internal class RetryManager
+    {
+        public const int InfiniteRetries = -1;
+
+        RetryWaitTimer _waitTimer;
+
+        int _maxRetries;
+        int _totalTimeLimit = Timeout.Infinite;
+
+        int _retryCount = 0;
+        int _totalWaitTime = 0;
+        int _currentWaitTime = 0;
+
+        public RetryManager(RetryWaitTimer waitTimer) : this(waitTimer, InfiniteRetries) { }
+        public RetryManager(RetryWaitTimer waitTimer, int maxRetries) : this(waitTimer, maxRetries, Timeout.Infinite) { }
+
+        /// <summary>
+        /// Creates a retry manager
+        /// </summary>
+        /// <param name="waitTimer">Strategy producing the wait before each retry; must not be null</param>
+        /// <param name="maxRetries">Positive retry limit, or InfiniteRetries</param>
+        /// <param name="totalTimeLimit">Positive total wait budget, or Timeout.Infinite</param>
+        public RetryManager(RetryWaitTimer waitTimer, int maxRetries, int totalTimeLimit)
+        {
+            if (waitTimer == null)
+            {
+                // Report the actual parameter name so ParamName is meaningful to callers.
+                throw new ArgumentNullException("waitTimer");
+            }
+            _waitTimer = waitTimer;
+
+            SetMaxRetries(maxRetries);
+            SetTotalTimeLimit(totalTimeLimit);
+        }
+
+
+        /// <summary>
+        /// Gets the number of retries attempted thus far
+        /// </summary>
+        public int RetryCount { get { return _retryCount; } }
+
+        /// <summary>
+        /// Get the total spent waiting between retries
+        /// </summary>
+        // The misspelled name is kept because external callers may reference it.
+        public int ElaspsedWaitTime { get { return _totalWaitTime; } }
+
+        /// <summary>
+        /// Gets or sets the maximum number of retries
+        /// </summary>
+        public int MaxRetryCount
+        {
+            get { return _maxRetries; }
+            set { SetMaxRetries(value); }
+        }
+
+        /// <summary>
+        /// Gets or sets the total amount of time that may be spent waiting for retries.
+        /// </summary>
+        public int TotalTimeLimit
+        {
+            get { return _totalTimeLimit; }
+            set { SetTotalTimeLimit(value); }
+        }
+
+        void SetMaxRetries(int n)
+        {
+            if (n <= 0 && n != RetryManager.InfiniteRetries)
+            {
+                throw new ArgumentException("The maximum number of retries must be greater than zero, or RetryManager.InfiniteRetries");
+            }
+            _maxRetries = n;
+        }
+
+        void SetTotalTimeLimit(int t)
+        {
+            if (t <= 0 && t != Timeout.Infinite)
+            {
+                throw new ArgumentException("The specified time must be greater than zero, or Timeout.Infinite");
+            }
+            _totalTimeLimit = t;
+        }
+
+
+        /// <summary>
+        /// Returns true if there are more retries left
+        /// </summary>
+        public bool HasAttemptsLeft
+        {
+            get
+            {
+                bool retriesRemain = _maxRetries == RetryManager.InfiniteRetries || _retryCount < _maxRetries;
+                bool timeRemains = _totalTimeLimit == Timeout.Infinite || _totalWaitTime < _totalTimeLimit;
+                return retriesRemain && timeRemains;
+            }
+        }
+
+        /// <summary>
+        /// Get the next wait time, truncated so the total never exceeds the time limit
+        /// </summary>
+        public int NextWaitTime
+        {
+            get
+            {
+                int waitTime = _waitTimer.GetNextWaitTime(_retryCount, _currentWaitTime);
+                if (_totalTimeLimit != Timeout.Infinite)
+                {
+                    // Compare against the remaining budget rather than summing, so a
+                    // very large wait cannot overflow the addition.
+                    int remaining = _totalTimeLimit - _totalWaitTime;
+                    if (waitTime > remaining)
+                    {
+                        waitTime = remaining;
+                    }
+                }
+                return waitTime;
+            }
+        }
+
+        /// <summary>
+        /// Increment the retry count and advance the total wait time without actually waiting
+        /// </summary>
+        public void SimulateNextAttempt()
+        {
+            WaitForNextAttempt(false);
+        }
+
+
+        /// <summary>
+        /// Wait until the next retry by making the current thread sleep for the appropriate amount of time.
+        /// May return immediately if the wait is zero.
+        /// </summary>
+        public void WaitForNextAttempt()
+        {
+            WaitForNextAttempt(true);
+        }
+
+
+        /// <summary>
+        /// Records one retry attempt and optionally sleeps for its wait time
+        /// </summary>
+        /// <param name="doSleep">true to block the current thread for the computed wait</param>
+        void WaitForNextAttempt(bool doSleep)
+        {
+            if (!HasAttemptsLeft)
+            {
+                throw new InvalidOperationException("There are no more retry attempts remaining");
+            }
+
+            _currentWaitTime = NextWaitTime;
+            _retryCount++;
+
+            Debug.Assert(_currentWaitTime >= 0);
+            if (_currentWaitTime > 0)
+            {
+                if (doSleep)
+                {
+                    Thread.Sleep(_currentWaitTime);
+                }
+                _totalWaitTime += _currentWaitTime;
+            }
+        }
+
+        /// <summary>
+        /// Resets the retry manager's retry count and accumulated wait time
+        /// </summary>
+        public void Reset()
+        {
+            _retryCount = 0;
+            _totalWaitTime = 0;
+            _currentWaitTime = 0;
+        }
+    }
+
+
+    /// <summary>
+    /// Defines how long a retry manager will wait between subsequent retries
+    /// </summary>
+    internal abstract class RetryWaitTimer
+    {
+        /// <param name="retryCount">Number of retries already performed</param>
+        /// <param name="currentWaitTime">Wait used for the previous retry (0 before the first)</param>
+        /// <returns>Wait in milliseconds before the next retry</returns>
+        internal abstract int GetNextWaitTime(int retryCount, int currentWaitTime);
+    }
+
+
+    /// <summary>
+    /// Instantly returns without waiting
+    /// </summary>
+    internal class InstantRetryTimer : RetryWaitTimer
+    {
+        internal override int GetNextWaitTime(int retryCount, int currentWaitTime)
+        {
+            return 0;
+        }
+
+        // This class should be a singleton
+        private InstantRetryTimer() { }
+
+        static InstantRetryTimer _instance = new InstantRetryTimer();
+        public static InstantRetryTimer Instance
+        {
+            get { return _instance; }
+        }
+    }
+
+    /// <summary>
+    /// Waits a constant time between subsequent retries
+    /// </summary>
+    internal class PeriodicRetryTimer : RetryWaitTimer
+    {
+        int _period;
+
+        public PeriodicRetryTimer(int period)
+        {
+            if (period < 0)
+            {
+                throw new ArgumentOutOfRangeException("period", "The period must be a non-negative integer (in milliseconds)");
+            }
+            _period = period;
+        }
+
+        internal override int GetNextWaitTime(int retryCount, int currentWaitTime)
+        {
+            return _period;
+        }
+    }
+
+    ///
+    /// A retry timer where wait time at retry n depends on the wait at retry n-1.
+    ///
+    internal abstract class BoundedBackoffRetryTimer : RetryWaitTimer
+    {
+        int _initialWait;
+        int _waitUpperBound;
+
+        /// <param name="initialWait">Wait before the first retry; must be positive (milliseconds)</param>
+        /// <param name="waitUpperBound">Largest wait ever returned; positive, or Timeout.Infinite for no cap</param>
+        protected BoundedBackoffRetryTimer(int initialWait, int waitUpperBound)
+        {
+            if (initialWait <= 0)
+            {
+                throw new ArgumentOutOfRangeException("initialWait", "Initial value must be a positive integer (in milliseconds)");
+            }
+            if (waitUpperBound <= 0 && waitUpperBound != Timeout.Infinite)
+            {
+                // Use the real parameter name so ParamName matches the signature.
+                throw new ArgumentOutOfRangeException("waitUpperBound", "The wait cap must be greater than zero, or Timeout.Infinite");
+            }
+
+            _initialWait = initialWait;
+            _waitUpperBound = waitUpperBound;
+        }
+
+        /// <summary>
+        /// Returns the initial wait for the first retry, otherwise the back-off value
+        /// derived from the previous wait, floored at 0 and capped at the upper bound.
+        /// </summary>
+        internal override int GetNextWaitTime(int retryCount, int currentWaitTime)
+        {
+            if (retryCount == 0)
+            {
+                return _initialWait;
+            }
+
+            int nextWaitTime = GetBackOffValue(currentWaitTime);
+            if (nextWaitTime < 0)
+            {
+                return 0;
+            }
+            if (_waitUpperBound != Timeout.Infinite && nextWaitTime > _waitUpperBound)
+            {
+                return _waitUpperBound;
+            }
+            return nextWaitTime;
+        }
+
+        /// <summary>
+        /// Computes the next wait from the previous one
+        /// </summary>
+        protected abstract int GetBackOffValue(int currentValue);
+    }
+
+
+    /// <summary>
+    /// Wait times will increase exponentially
+    /// </summary>
+    internal class ExponentialBackoffRetryTimer : BoundedBackoffRetryTimer
+    {
+        double _growthFactor;
+
+        public ExponentialBackoffRetryTimer(int initialWait) : this(initialWait, Timeout.Infinite, 2) { }
+        public ExponentialBackoffRetryTimer(int initialWait, int waitUpperBound) : this(initialWait, waitUpperBound, 2) { }
+
+        public ExponentialBackoffRetryTimer(int initialWait, int waitUpperBound, double growthFactor)
+            : base(initialWait, waitUpperBound)
+        {
+            if (growthFactor <= 0)
+            {
+                throw new ArgumentOutOfRangeException("growthFactor", "The growth factor must be a positive value");
+            }
+            _growthFactor = growthFactor;
+        }
+
+        protected override int GetBackOffValue(int currentValue)
+        {
+            // Saturate instead of casting an out-of-range double: the cast result is
+            // unspecified and in practice negative, which the base class turns into a
+            // zero wait (i.e. retries with no back-off at all).
+            double scaled = Math.Round(currentValue * _growthFactor);
+            return scaled >= int.MaxValue ? int.MaxValue : (int)scaled;
+        }
+    }
+
+    /// <summary>
+    /// Wait times will increase exponentially and also vary a bit randomly
+    /// </summary>
+    internal class ExponentialRandomBackoffRetryTimer : ExponentialBackoffRetryTimer
+    {
+        Random _rand = null;
+        public ExponentialRandomBackoffRetryTimer(int initialWait) : this(initialWait, Timeout.Infinite, 2) { }
+        public ExponentialRandomBackoffRetryTimer(int initialWait, int waitUpperBound) : this(initialWait, waitUpperBound, 2) { }
+
+        public ExponentialRandomBackoffRetryTimer(int initialWait, int waitUpperBound, double growthFactor)
+            : base(initialWait, waitUpperBound, growthFactor)
+        {
+            _rand = new Random();
+        }
+
+        protected override int GetBackOffValue(int currentValue)
+        {
+            // Sum in 64 bits and saturate so the random jitter cannot wrap negative.
+            long jittered = (long)base.GetBackOffValue(currentValue) + _rand.Next(0, currentValue);
+            return jittered > int.MaxValue ? int.MaxValue : (int)jittered;
+        }
+    }
+
+
+    /// <summary>
+    /// Wait times will increase linearly
+    /// </summary>
+    internal class LinearBackoffRetryTimer : BoundedBackoffRetryTimer
+    {
+        int _increment;
+
+        public LinearBackoffRetryTimer(int initialWait) : this(initialWait, Timeout.Infinite, initialWait) { }
+        public LinearBackoffRetryTimer(int initialWait, int waitUpperBound) : this(initialWait, waitUpperBound, initialWait) { }
+
+        public LinearBackoffRetryTimer(int initialWait, int waitUpperBound, int increment)
+            : base(initialWait, waitUpperBound)
+        {
+            _increment = increment;
+        }
+
+        protected override int GetBackOffValue(int currentValue)
+        {
+            // Sum in 64 bits and saturate; a negative result (negative increment) is
+            // still passed through and floored to 0 by the base class.
+            long next = (long)currentValue + _increment;
+            return next > int.MaxValue ? int.MaxValue : (int)next;
+        }
+    }
+}
\ No newline at end of file
diff --git a/CommonCode/SchedulerHelper.cs b/CommonCode/SchedulerHelper.cs
new file mode 100644
index 0000000..c17a87d
--- /dev/null
+++ b/CommonCode/SchedulerHelper.cs
@@ -0,0 +1,753 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+namespace Microsoft.Research.Dryad
+{
+    using System;
+    using System.Globalization;
+    using System.ServiceModel;
+    using System.ServiceModel.Channels;
+    using System.Threading;
+    using System.Collections.Generic;
+    using System.Collections.Specialized;
+    using System.Collections.Concurrent;
+    using System.Reflection;
+    using System.Net.Security;
+    using System.Security;
+    using System.Security.AccessControl;
+    using System.Security.Principal;
+    using System.Text;
+    using System.Xml;
+    using System.Xml.Serialization;
+    using Microsoft.Hpc;
+    using Microsoft.Win32;
+    using Microsoft.Research.Dryad.YarnBridge;
+
+    /// <summary>
+    /// Scheduler-independent state of a vertex task
+    /// </summary>
+    // NOTE(review): the values are not independent bits (Finished == Waiting | Running),
+    // so bitwise combination is presumably not intended despite [Flags] - confirm.
+    [Flags]
+    public enum VertexTaskState
+    {
+        NA = 0x00000000,
+        Waiting = 0x00000010,
+        Running = 0x00000020,
+        Finished = 0x00000030,
+        Failed = 0x00000031,
+        Canceled = 0x00000032
+    }
+
+    /// <summary>
+    /// Describes a vertex task transition: old and new state, node and requeue count
+    /// </summary>
+    public class VertexChangeEventArgs : EventArgs
+    {
+        /// <param name="id">Id of the vertex task that changed</param>
+        public VertexChangeEventArgs(int id)
+        {
+            Id = id;
+            OldState = VertexTaskState.NA;
+            OldNode = String.Empty;
+            OldRequeueCount = 0;
+        }
+
+        public int Id
+        {
+            get;
+            private set;
+        }
+
+        public VertexTaskState OldState
+        {
+            get;
+            set;
+        }
+
+        public VertexTaskState NewState
+        {
+            get;
+            set;
+        }
+
+        public string OldNode
+        {
+            get;
+            set;
+        }
+
+        public string NewNode
+        {
+            get;
+            set;
+        }
+
+        public int OldRequeueCount
+        {
+            get;
+            set;
+        }
+
+        public int NewRequeueCount
+        {
+            get;
+            set;
+        }
+    }
+
+    public class VertexComputeNode
+    {
+        public string ComputeNode;
+        public int instanceId;
+        public VertexTaskState State;
+    }
+
+    public delegate void VertexChangeEventHandler(object sender, VertexChangeEventArgs e);
+
+    /// <summary>
+    /// Abstraction over the cluster scheduler used by the graph manager and vertices
+    /// </summary>
+    public interface ISchedulerHelper : IDisposable
+    {
+        event VertexChangeEventHandler OnVertexChange;
+
+        void FinishJob();
+
+        string GetVertexServiceBaseAddress(string nodename, int instanceId);
+
+        NetTcpBinding GetVertexServiceBinding();
+
+        void SetJobProgress(int n, string message);
+
+        bool StartTaskMonitorThread();
+
+        void StopTaskMonitorThread();
+
+        bool WaitForTasksReady();
+
+    }
+
+    /// <summary>
+    /// Creates the process-wide scheduler helper selected by the scheduler type environment variable
+    /// </summary>
+    public class SchedulerHelperFactory
+    {
+        // volatile so the unlocked first check of the double-checked initialization
+        // never observes a partially published instance.
+        private static volatile ISchedulerHelper m_instance = null;
+        private static object m_lock = new object();
+
+        /// <returns>The shared helper; YARN when the variable is unset or names YARN</returns>
+        /// <exception cref="InvalidOperationException">The configured scheduler type is not supported</exception>
+        public static ISchedulerHelper GetInstance()
+        {
+            if (m_instance == null)
+            {
+                lock (m_lock)
+                {
+                    if (m_instance == null)
+                    {
+                        string schedulerType = System.Environment.GetEnvironmentVariable(Constants.schedulerTypeEnvVar);
+                        if (String.IsNullOrEmpty(schedulerType) || schedulerType == Constants.schedulerTypeYarn)
+                        {
+                            m_instance = new YarnSchedulerHelper();
+                        }
+                        else if (schedulerType == Constants.schedulerTypeLocal)
+                        {
+                            m_instance = new LocalSchedulerHelper();
+                        }
+                        else
+                        {
+                            throw new InvalidOperationException(String.Format("Scheduler type {0} is not supported", schedulerType));
+                        }
+                    }
+                }
+            }
+            return m_instance;
+        }
+    }
+
+    /// <summary>
+    /// Scheduler helper for local-process execution; most operations are no-ops
+    /// </summary>
+    public class LocalSchedulerHelper : ISchedulerHelper
+    {
+        protected string m_EnvCcpLocalProcessComputeNodes = System.Environment.GetEnvironmentVariable(Constants.localProcessComputeNodesEnvVar);
+        protected string[] m_LocalProcessComputeNodes = new string[0];
+        private bool m_disposed = false;
+
+        // Dedicated lock: the delegate field itself is null until the first handler is
+        // added, and lock(null) throws ArgumentNullException.
+        private readonly object m_eventLock = new object();
+        private VertexChangeEventHandler m_vertexChangeEvent;
+
+        event VertexChangeEventHandler ISchedulerHelper.OnVertexChange
+        {
+            add
+            {
+                lock (m_eventLock)
+                {
+                    m_vertexChangeEvent += value;
+                }
+            }
+            remove
+            {
+                lock (m_eventLock)
+                {
+                    m_vertexChangeEvent -= value;
+                }
+            }
+        }
+
+        public LocalSchedulerHelper()
+        {
+            if (!String.IsNullOrEmpty(m_EnvCcpLocalProcessComputeNodes))
+            {
+                m_LocalProcessComputeNodes = m_EnvCcpLocalProcessComputeNodes.Split(',');
+            }
+        }
+
+        void IDisposable.Dispose()
+        {
+            Dispose(true);
+            GC.SuppressFinalize(this);
+        }
+
+        private void Dispose(bool disposing)
+        {
+            if (!m_disposed)
+            {
+                if (disposing)
+                {
+                }
+                m_disposed = true;
+            }
+        }
+
+        void ISchedulerHelper.FinishJob()
+        {
+
+        }
+
+        string ISchedulerHelper.GetVertexServiceBaseAddress(string nodename, int instanceId)
+        {
+            return String.Format(Constants.vertexAddrFormat, "localhost", nodename);
+        }
+
+        NetTcpBinding ISchedulerHelper.GetVertexServiceBinding()
+        {
+            NetTcpBinding binding = new NetTcpBinding(SecurityMode.Transport, false);
+            binding.PortSharingEnabled = true;
+            binding.Security.Transport.ClientCredentialType = TcpClientCredentialType.Windows;
+            binding.Security.Transport.ProtectionLevel = ProtectionLevel.None;
+            binding.SendTimeout = Constants.SendTimeout;
+            binding.ReceiveTimeout = Constants.ReceiveTimeout;
+            binding.MaxReceivedMessageSize = Constants.MaxReceivedMessageSize;
+            binding.MaxBufferPoolSize = Constants.MaxBufferPoolSize;
+            binding.MaxConnections = Constants.MaxConnections;
+            binding.ListenBacklog = Constants.ListenBacklog;
+            binding.ReaderQuotas = System.Xml.XmlDictionaryReaderQuotas.Max;
+
+            return binding;
+
+        }
+
+        void ISchedulerHelper.SetJobProgress(int n, string message)
+        {
+        }
+
+        bool ISchedulerHelper.StartTaskMonitorThread()
+        {
+            return true;
+        }
+
+        void ISchedulerHelper.StopTaskMonitorThread()
+        {
+            return;
+        }
+
+        bool ISchedulerHelper.WaitForTasksReady()
+        {
+            return true;
+        }
+
+    }
+
+    public class YarnSchedulerHelper : ISchedulerHelper
+    {
+        public enum YarnTaskState
+        {
+            NA,
+            Scheduling,
+            Running,
+            Completed,
+            Failed
+        }
+
+        public class VertexTask
+        {
+            public VertexTask(int id, string node, YarnTaskState state, int requeueCount, DateTime changeTime)
+            {
+                Id = id;
+                Node = node;
+                State = state;
+                RequeueCount = requeueCount;
+                ChangeTime = changeTime;
+            }
+
+            public int Id
+            {
+                get;
+                set;
+            }
+
+            public string Node
+            {
+                get;
+                set;
+
} + + public YarnTaskState State + { + get; + set; + } + + public int RequeueCount + { + get; + set; + } + + public DateTime ChangeTime + { + get; + set; + } + } + + private VertexTask[] m_vertices = null; + + protected string m_EnvCcpClusterName = System.Environment.GetEnvironmentVariable(Constants.clusterNameEnvVar); + protected int m_EnvCcpJobId = Convert.ToInt32(System.Environment.GetEnvironmentVariable(Constants.jobIdEnvVar)); + + private object m_eventLock = new object(); + private VertexChangeEventHandler m_vertexChangeEvent; + + private int m_minNodes = -1; + private int m_maxNodes = -1; + private int m_startNodes = -1; + private int m_runningTasks = 0; // No longer start at 1 since the GM is not running under a task + private int m_finishedTasks = 0; + private object m_lock = new object(); + private bool m_disposed = false; + private AutoResetEvent m_taskChangeEvt = new AutoResetEvent(false); + + private bool m_taskMonitorThreadRunning = false; + private ManualResetEvent m_threadStopEvt = new ManualResetEvent(false); + private Thread m_taskMonitorThread = null; + private BlockingCollection m_taskUpdateQueue; + + public AMInstance m_appMaster; + + private const int GM_EXITCODE_CANNOT_ACCESS_SCHEDULER = 1000; + + #region Events + + event VertexChangeEventHandler ISchedulerHelper.OnVertexChange + { + add + { + lock (m_eventLock) + { + m_vertexChangeEvent += value; + } + } + remove + { + lock (m_eventLock) + { + m_vertexChangeEvent -= value; + } + } + } + + private AMInstance GetScheduler() + { + return m_appMaster; + } + + #endregion + + #region Properties + + private int JobMinNodes + { + get + { + return m_minNodes; + } + } + + private int JobMaxNodes + { + get + { + return m_maxNodes; + } + } + + private int JobStartNodes + { + get + { + return m_startNodes; + } + } + + #endregion + + public YarnSchedulerHelper() + { + // init the DryadLogger, just to make sure + DryadLogger.Start("xcompute.log"); + m_taskUpdateQueue = new BlockingCollection(); + + // if 
we are not running in a vertex, then init the GM + string jmString = Environment.GetEnvironmentVariable(Constants.jobManager); + if (String.IsNullOrEmpty(jmString)) + { + m_minNodes = int.Parse(Environment.GetEnvironmentVariable("MINIMUM_COMPUTE_NODES")); + m_maxNodes = int.Parse(Environment.GetEnvironmentVariable("MAXIMUM_COMPUTE_NODES")); + m_startNodes = m_minNodes; + + m_vertices = new VertexTask[JobMaxNodes + 2]; + DryadLogger.LogInformation("YarnSchedulerHelper()", "Initializing JAVA GM"); + DryadLogger.LogInformation("YarnSchedulerHelper()", "m_maxNodes: {0}", m_maxNodes); + AMInstance.RegisterGMCallback(new UpdateProcessState(QueueYarnUpdate)); + ((ISchedulerHelper)this).OnVertexChange += new VertexChangeEventHandler(OnVertexChangeHandler); + m_appMaster = new AMInstance(); + + } + else + { + m_vertices = new VertexTask[JobMaxNodes + 2]; + DryadLogger.LogInformation("YarnSchedulerHelper()", "Not initializing JAVA GM"); + } + + + } + + #region Methods + + void IDisposable.Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (!m_disposed) + { + if (disposing) + { + if (m_appMaster != null) + { + m_appMaster.Close(); + } + this.m_taskChangeEvt.Close(); + } + + m_disposed = true; + } + } + + void ISchedulerHelper.FinishJob() + { + m_appMaster.Finish(); + } + + string ISchedulerHelper.GetVertexServiceBaseAddress(string nodename, int instanceId) + { + return String.Format(Constants.vertexAddrFormat, nodename, instanceId); + } + + NetTcpBinding ISchedulerHelper.GetVertexServiceBinding() + { + NetTcpBinding binding = new NetTcpBinding(SecurityMode.Transport, false); + binding.Security.Transport.ClientCredentialType = TcpClientCredentialType.Windows; + binding.Security.Transport.ProtectionLevel = ProtectionLevel.None; + binding.SendTimeout = Constants.VertexSendTimeout; + binding.ReceiveTimeout = Constants.ReceiveTimeout; + binding.MaxReceivedMessageSize = Constants.MaxReceivedMessageSize; + 
binding.MaxBufferPoolSize = Constants.MaxBufferPoolSize; + binding.MaxConnections = Constants.MaxConnections; + binding.ListenBacklog = Constants.ListenBacklog; + binding.ReaderQuotas = System.Xml.XmlDictionaryReaderQuotas.Max; + + return binding; + + } + + void ISchedulerHelper.SetJobProgress(int n, string message) + { + DryadLogger.LogWarning("SetJobProgress", "n: {0} message: {1}", n, message); + } + + + + bool ISchedulerHelper.StartTaskMonitorThread() + { + // We only want to have one of these threads running, in case we get called more than once + if (m_taskMonitorThreadRunning == false) + { + lock (m_lock) + { + if (m_taskMonitorThreadRunning == false) + { + ((ISchedulerHelper)this).OnVertexChange += new VertexChangeEventHandler(OnVertexChangeHandler); + try + { + m_taskMonitorThread = new Thread(new ThreadStart(TaskMonitorThread)); + m_taskMonitorThread.Start(); + m_taskMonitorThreadRunning = true; + return true; + } + catch (Exception e) + { + DryadLogger.LogCritical(0, e, "Failed to start task monitoring thread"); + return false; + } + } + } + } + return true; + } + + void ISchedulerHelper.StopTaskMonitorThread() + { + DryadLogger.LogMethodEntry(); + bool wait = false; + if (m_taskMonitorThreadRunning) + { + lock (m_lock) + { + if (m_taskMonitorThreadRunning) + { + m_threadStopEvt.Set(); + wait = true; + } + } + } + + if (wait) + { + try + { + m_taskMonitorThread.Join(); + } + catch (Exception e) + { + DryadLogger.LogError(0, e, "Failed to wait for task monitor thread to stop."); + } + } + DryadLogger.LogMethodExit(); + } + + private VertexTaskState YarnTaskStateToVertexTaskState(YarnTaskState ts) + { + VertexTaskState vts = VertexTaskState.NA; + if (ts == YarnTaskState.NA) + { + vts = VertexTaskState.NA; + } + else if (ts < YarnTaskState.Running) + { + vts = VertexTaskState.Waiting; + } + else if (ts == YarnTaskState.Running) + { + vts = VertexTaskState.Running; + } + else + { + switch (ts) + { + case YarnTaskState.Completed: + vts = 
VertexTaskState.Finished; + break; + case YarnTaskState.Failed: + vts = VertexTaskState.Failed; + break; + //case TaskState.Canceled: + //case TaskState.Canceling: + // vts = VertexTaskState.Canceled; + // break; + } + } + DryadLogger.LogDebug("Task State", "Mapped ts: {0} to vts: {1}", ts, vts); + return vts; + } + + bool ISchedulerHelper.WaitForTasksReady() + { + // The basic strategy is to wait for the maximum number of vertex tasks which is + // practical. Start by waiting for AllocatedNodes.Count. As tasks fail or are cancelled, + // decrement the number of tasks to wait for until we drop below Min at which time the + // scheduler will end the job. Also, if tasks are rerun, increment the number of tasks to wait for. + do + { + // Event set by the Task Monitor Thread when it finishes processing a batch of changes. + m_taskChangeEvt.WaitOne(); + + // Don't want OnVertexChangeHandler updating these counts while we're checking them + lock (this) + { + DryadLogger.LogInformation("Wait for vertex tasks", + "{0} tasks are running, waiting for at least {1} before starting", + m_runningTasks, m_startNodes); + if (m_runningTasks >= m_startNodes) + { + // We have enough running tasks to start + DryadLogger.LogDebug("Wait for vertex tasks", + "Sufficient number of tasks transitioned to running to begin: {0} running tasks", + m_runningTasks); + return true; + } + } + + } while (true); + } + + public void QueueYarnUpdate(int taskId, int taskState, string nodeName) + { + DryadLogger.LogInformation("QueueYarnUpdate", "Task {0} on node {1} is in state {2}", taskId, nodeName, + taskState); + // Set change event arguments + YarnTaskState yTaskState = (YarnTaskState)taskState; + VertexTask v = new VertexTask(taskId, nodeName, yTaskState, int.MaxValue, DateTime.UtcNow); + m_taskUpdateQueue.Add(v); + } + + public void ProcessYarnUpdate(VertexTask v) + { + DryadLogger.LogInformation("ProcessYarnUpdate", "Task {0} on node {1} is in state {2}", v.Id, v.Node, + v.State); +
VertexChangeEventArgs e = new VertexChangeEventArgs(v.Id); + + e.NewNode = v.Node; + e.NewState = YarnTaskStateToVertexTaskState(v.State); + e.NewRequeueCount = v.RequeueCount; + + if (m_vertices[v.Id] != null) + { + e.OldNode = m_vertices[v.Id].Node; + e.OldState = YarnTaskStateToVertexTaskState(m_vertices[v.Id].State); + e.OldRequeueCount = m_vertices[v.Id].RequeueCount; + } + + if (e.NewRequeueCount != e.OldRequeueCount) + { + DryadLogger.LogInformation("ProcessYarnUpdate", "Task {0} requeue count changed from {1} to {2}", + v.Id, e.OldRequeueCount, e.NewRequeueCount); + } + + // Update current vertex state + m_vertices[v.Id] = v; + m_vertexChangeEvent(this, e); + //m_taskChangeEvt.Set(); + } + + private void TaskMonitorThread() + { + TimeSpan pollInterval = TimeSpan.FromSeconds(1); + TimeSpan maxPollInterval = TimeSpan.FromSeconds(16); + + // The main loop. Each iteration polls for task changes. + while (true) + { + bool foundUpdate = false; + DateTime loopStartTime = DateTime.Now; + // + // Process change results from blocking queue + // + do + { + VertexTask v = null; + if (m_taskUpdateQueue.TryTake(out v, pollInterval)) + { + foundUpdate = true; + ProcessYarnUpdate(v); + } + + } while ((DateTime.Now - loopStartTime) < pollInterval); + + if (foundUpdate) + { + // Notify WaitForTasksReady once for each polling cycle + // so that it gets all the changes in one batch + m_taskChangeEvt.Set(); + } + + // Check to see if we've been told to stop. + // Timeout after pollInterval. + // TODO: For better shutdown perf, we may want to check this at other places + // or just kill the thread - but this provides a more graceful exit. 
+ if (m_threadStopEvt.WaitOne(pollInterval, true)) + { + m_taskMonitorThreadRunning = false; + DryadLogger.LogInformation("Task Monitoring Thread", "Received shutdown event"); + return; + } + + // Double the polling interval each iteration up to maxPollInterval + if (pollInterval < maxPollInterval) + { + double newSeconds = 2 * pollInterval.TotalSeconds; + if (newSeconds < maxPollInterval.TotalSeconds) + { + pollInterval = TimeSpan.FromSeconds(newSeconds); + } + else + { + pollInterval = maxPollInterval; + } + } + + } + } + + private void OnVertexChangeHandler(object sender, VertexChangeEventArgs ve) + { + if (ve.OldState != ve.NewState) + { + // Don't want to update counts while WaitForTasksReady is checking them + lock (this) + { + if (ve.OldState == VertexTaskState.Running) + { + m_runningTasks--; + } + else if (ve.OldState > VertexTaskState.Running) + { + m_finishedTasks--; + // Task transitioning from a completed state so we can increment + // the number of tasks to wait for at startup + m_startNodes++; + } + + if (ve.NewState == VertexTaskState.Running) + { + m_runningTasks++; + } + else if (ve.NewState > VertexTaskState.Running) + { + m_finishedTasks++; + // Task transitioning to a completed state so we need to + // decrement the number of tasks to wait for at startup. + m_startNodes--; + } + } + } + } + + #endregion + } + + +} diff --git a/CommonCode/SoftAffinity.cs b/CommonCode/SoftAffinity.cs new file mode 100644 index 0000000..f77f629 --- /dev/null +++ b/CommonCode/SoftAffinity.cs @@ -0,0 +1,79 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +namespace Microsoft.Research.Dryad +{ + using System; + using System.ServiceModel.Channels; + using System.Runtime.Serialization; + + [DataContract] + public class SoftAffinity : IComparable + { + [DataMember] + private string node = null; + [DataMember] + private ulong weight = ulong.MinValue; + + public SoftAffinity(string n, ulong w) + { + node = n; + weight = w; + } + + // Compare SoftAffinities using their weights; per the IComparable contract any instance sorts after null (or a non-SoftAffinity argument) + int IComparable.CompareTo(object obj) + { + SoftAffinity other = obj as SoftAffinity; + + if (other == null) + { + return 1; + } + + if (this.Weight < other.Weight) + { + return -1; + } + else if (this.Weight > other.Weight) + { + return 1; + } + else + { + return 0; + } + } + + [DataMember] + public string Node + { + get { return node; } + set { node = value; } + } + + [DataMember] + public ulong Weight + { + get { return weight; } + set { weight = value; } + } + } +} diff --git a/Dryad.sln b/Dryad.sln new file mode 100644 index 0000000..aa2f233 --- /dev/null +++ b/Dryad.sln @@ -0,0 +1,359 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio 2012 +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "DryadVertex", "DryadVertex", "{99F5E7FE-ADD4-4E9D-BA90-5D3D3409BEB3}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "VertexHost", "VertexHost", "{37D9C01A-94F3-4B9E-9AA6-C4D3A2B5AD0D}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "System", "System", "{D8B4F38E-2BF7-44A0-BDBF-025B46501DDE}" +EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Vertex", "Vertex", "{DC277807-506B-4F7C-BECD-345079B91044}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "CommonCode", "CommonCode", "{387FAB85-B968-48F6-A2F7-79A1F2CF2F54}" + ProjectSection(SolutionItems) = preProject + CommonCode\AzureUtils.cs = CommonCode\AzureUtils.cs + CommonCode\Constants.cs = CommonCode\Constants.cs + CommonCode\DiscLocalMonitor.cs = CommonCode\DiscLocalMonitor.cs + CommonCode\DryadTracing.cs = CommonCode\DryadTracing.cs + CommonCode\DryadVertexServiceAuthorizationManager.cs = CommonCode\DryadVertexServiceAuthorizationManager.cs + CommonCode\ExecutionHelper.cs = CommonCode\ExecutionHelper.cs + CommonCode\IDryadVertexCallback.cs = CommonCode\IDryadVertexCallback.cs + CommonCode\IDryadVertexService.cs = CommonCode\IDryadVertexService.cs + CommonCode\NativeMethods.cs = CommonCode\NativeMethods.cs + CommonCode\NetShareWrapper.cs = CommonCode\NetShareWrapper.cs + CommonCode\ProcessPathHelper.cs = CommonCode\ProcessPathHelper.cs + CommonCode\ProcessState.cs = CommonCode\ProcessState.cs + CommonCode\QueryUtility.cs = CommonCode\QueryUtility.cs + CommonCode\RetryFramework.cs = CommonCode\RetryFramework.cs + CommonCode\SchedulerHelper.cs = CommonCode\SchedulerHelper.cs + CommonCode\SoftAffinity.cs = CommonCode\SoftAffinity.cs + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "GraphManager", "GraphManager\GraphManager.vcxproj", "{8E30F4A4-603B-4799-A473-6EF5388661BA}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "DryadYarnBridge", "DryadYarnBridge\DryadYarnBridge.vcxproj", "{09FB27C7-D1A5-4A59-B010-67D5886DD9A2}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "linqtodryadjm_managed", "linqtodryadjm_managed_yarn\linqtodryadjm_managed.csproj", "{1311809B-306E-44A4-9D69-8A7BD15123C5}" + ProjectSection(ProjectDependencies) = postProject + {09FB27C7-D1A5-4A59-B010-67D5886DD9A2} = 
{09FB27C7-D1A5-4A59-B010-67D5886DD9A2} + EndProjectSection +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.Research.Dryad.ClusterAdapter", "xcompute_managed\Microsoft.Research.Dryad.ClusterAdapter.csproj", "{F4B04940-67CF-4796-B6D3-3CFD38FB988A}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "YarnQueryNativeClusterAdapter", "xcompute_native\YarnQueryNativeClusterAdapter.vcxproj", "{E092E2B9-D3C9-4CE2-8201-BDA442574C97}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "DryadVertexService", "DryadVertex\service\DryadVertexService.csproj", "{27D89037-8934-45BE-8A44-2561F9330EB7}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "channel", "DryadVertex\VertexHost\system\channel\channel.vcxproj", "{482E0741-E244-4974-97D4-3A7167581E91}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "classlib", "DryadVertex\VertexHost\system\classlib\classlib.vcxproj", "{016E71D3-9A6F-425C-AB4F-8C5EDEFFE7FA}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "common", "DryadVertex\VertexHost\system\common\common.vcxproj", "{57663B94-E11B-431E-BE4B-E2C61112DEC5}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "dprocess", "DryadVertex\VertexHost\system\dprocess\dprocess.vcxproj", "{AA529122-F51C-48D7-A8C1-C0B24F570885}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ManagedWrapperVertex", "DryadVertex\VertexHost\vertex\managedwrappervertex\ManagedWrapperVertex.vcxproj", "{BDEDD3BB-C7E2-498F-A212-F99786C8E23C}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "VertexHost", "DryadVertex\VertexHost\vertex\vertexHost\VertexHost.vcxproj", "{0CF3D1D5-9BBE-4175-979B-EC6138EF4F37}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "WrapperNativeInfo", "DryadVertex\VertexHost\vertex\WrapperNativeInfo\WrapperNativeInfo.vcxproj", "{AB9EA66C-5811-49A7-B002-24203AEB9083}" +EndProject 
+Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "WrapperNativeInfoDll", "DryadVertex\VertexHost\vertex\WrapperNativeInfoDll\WrapperNativeInfoDll.vcxproj", "{3EE0920C-0607-4569-9EC3-5C12BB6EF244}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LinqToDryad", "LinqToDryad\LinqToDryad.csproj", "{D33C34CC-6DB2-417C-88B7-299830711774}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Hdfs", "Hdfs", "{6003A98A-82CB-4385-8B29-70BC28F19076}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "HdfsBridgeManaged", "Hdfs\HdfsBridgeManaged\HdfsBridgeManaged.vcxproj", "{C0F4C1E3-1F9E-4C55-BD6A-0241D35425F5}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "HdfsBridgeNative", "Hdfs\HdfsBridgeNative\HdfsBridgeNative.vcxproj", "{95FBF9B7-9407-4554-A74A-3527839BD1B6}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "include", "include", "{89C5654B-02E4-478D-A7E6-50D79F638B4F}" + ProjectSection(SolutionItems) = preProject + DryadVertex\VertexHost\vertex\include\ChannelTransform.h = DryadVertex\VertexHost\vertex\include\ChannelTransform.h + DryadVertex\VertexHost\vertex\include\CompressionVertex.h = DryadVertex\VertexHost\vertex\include\CompressionVertex.h + DryadVertex\VertexHost\vertex\include\DataBlockItem.h = DryadVertex\VertexHost\vertex\include\DataBlockItem.h + DryadVertex\VertexHost\vertex\include\FifoChannel.h = DryadVertex\VertexHost\vertex\include\FifoChannel.h + DryadVertex\VertexHost\vertex\include\FifoInputChannel.h = DryadVertex\VertexHost\vertex\include\FifoInputChannel.h + DryadVertex\VertexHost\vertex\include\FifoOutputChannel.h = DryadVertex\VertexHost\vertex\include\FifoOutputChannel.h + DryadVertex\VertexHost\vertex\include\GzipCompressionChannelTransform.h = DryadVertex\VertexHost\vertex\include\GzipCompressionChannelTransform.h + DryadVertex\VertexHost\vertex\include\GzipDecompressionChannelTransform.h = 
DryadVertex\VertexHost\vertex\include\GzipDecompressionChannelTransform.h + DryadVertex\VertexHost\vertex\include\InputChannel.h = DryadVertex\VertexHost\vertex\include\InputChannel.h + DryadVertex\VertexHost\vertex\include\ManagedWrapper.h = DryadVertex\VertexHost\vertex\include\ManagedWrapper.h + DryadVertex\VertexHost\vertex\include\NullChannelTransform.h = DryadVertex\VertexHost\vertex\include\NullChannelTransform.h + DryadVertex\VertexHost\vertex\include\OutputChannel.h = DryadVertex\VertexHost\vertex\include\OutputChannel.h + DryadVertex\VertexHost\vertex\include\wrappernativeinfo.h = DryadVertex\VertexHost\vertex\include\wrappernativeinfo.h + EndProjectSection +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Debug|Mixed Platforms = Debug|Mixed Platforms + Debug|Win32 = Debug|Win32 + Debug|x64 = Debug|x64 + Release|Any CPU = Release|Any CPU + Release|Mixed Platforms = Release|Mixed Platforms + Release|Win32 = Release|Win32 + Release|x64 = Release|x64 + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {8E30F4A4-603B-4799-A473-6EF5388661BA}.Debug|Any CPU.ActiveCfg = Debug|x64 + {8E30F4A4-603B-4799-A473-6EF5388661BA}.Debug|Mixed Platforms.ActiveCfg = Debug|x64 + {8E30F4A4-603B-4799-A473-6EF5388661BA}.Debug|Mixed Platforms.Build.0 = Debug|x64 + {8E30F4A4-603B-4799-A473-6EF5388661BA}.Debug|Win32.ActiveCfg = Debug|Win32 + {8E30F4A4-603B-4799-A473-6EF5388661BA}.Debug|Win32.Build.0 = Debug|Win32 + {8E30F4A4-603B-4799-A473-6EF5388661BA}.Debug|x64.ActiveCfg = Debug|x64 + {8E30F4A4-603B-4799-A473-6EF5388661BA}.Debug|x64.Build.0 = Debug|x64 + {8E30F4A4-603B-4799-A473-6EF5388661BA}.Release|Any CPU.ActiveCfg = Release|x64 + {8E30F4A4-603B-4799-A473-6EF5388661BA}.Release|Mixed Platforms.ActiveCfg = Release|x64 + {8E30F4A4-603B-4799-A473-6EF5388661BA}.Release|Mixed Platforms.Build.0 = Release|x64 + {8E30F4A4-603B-4799-A473-6EF5388661BA}.Release|Win32.ActiveCfg = Release|Win32 + 
{8E30F4A4-603B-4799-A473-6EF5388661BA}.Release|Win32.Build.0 = Release|Win32 + {8E30F4A4-603B-4799-A473-6EF5388661BA}.Release|x64.ActiveCfg = Release|x64 + {8E30F4A4-603B-4799-A473-6EF5388661BA}.Release|x64.Build.0 = Release|x64 + {09FB27C7-D1A5-4A59-B010-67D5886DD9A2}.Debug|Any CPU.ActiveCfg = Debug|x64 + {09FB27C7-D1A5-4A59-B010-67D5886DD9A2}.Debug|Mixed Platforms.ActiveCfg = Debug|x64 + {09FB27C7-D1A5-4A59-B010-67D5886DD9A2}.Debug|Mixed Platforms.Build.0 = Debug|x64 + {09FB27C7-D1A5-4A59-B010-67D5886DD9A2}.Debug|Win32.ActiveCfg = Debug|Win32 + {09FB27C7-D1A5-4A59-B010-67D5886DD9A2}.Debug|Win32.Build.0 = Debug|Win32 + {09FB27C7-D1A5-4A59-B010-67D5886DD9A2}.Debug|x64.ActiveCfg = Debug|x64 + {09FB27C7-D1A5-4A59-B010-67D5886DD9A2}.Debug|x64.Build.0 = Debug|x64 + {09FB27C7-D1A5-4A59-B010-67D5886DD9A2}.Release|Any CPU.ActiveCfg = Release|x64 + {09FB27C7-D1A5-4A59-B010-67D5886DD9A2}.Release|Mixed Platforms.ActiveCfg = Release|x64 + {09FB27C7-D1A5-4A59-B010-67D5886DD9A2}.Release|Mixed Platforms.Build.0 = Release|x64 + {09FB27C7-D1A5-4A59-B010-67D5886DD9A2}.Release|Win32.ActiveCfg = Release|Win32 + {09FB27C7-D1A5-4A59-B010-67D5886DD9A2}.Release|Win32.Build.0 = Release|Win32 + {09FB27C7-D1A5-4A59-B010-67D5886DD9A2}.Release|x64.ActiveCfg = Release|x64 + {09FB27C7-D1A5-4A59-B010-67D5886DD9A2}.Release|x64.Build.0 = Release|x64 + {1311809B-306E-44A4-9D69-8A7BD15123C5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {1311809B-306E-44A4-9D69-8A7BD15123C5}.Debug|Any CPU.Build.0 = Debug|Any CPU + {1311809B-306E-44A4-9D69-8A7BD15123C5}.Debug|Mixed Platforms.ActiveCfg = Debug|Any CPU + {1311809B-306E-44A4-9D69-8A7BD15123C5}.Debug|Mixed Platforms.Build.0 = Debug|Any CPU + {1311809B-306E-44A4-9D69-8A7BD15123C5}.Debug|Win32.ActiveCfg = Debug|Any CPU + {1311809B-306E-44A4-9D69-8A7BD15123C5}.Debug|x64.ActiveCfg = Debug|Any CPU + {1311809B-306E-44A4-9D69-8A7BD15123C5}.Debug|x64.Build.0 = Debug|Any CPU + {1311809B-306E-44A4-9D69-8A7BD15123C5}.Release|Any CPU.ActiveCfg = Release|Any CPU + 
{1311809B-306E-44A4-9D69-8A7BD15123C5}.Release|Any CPU.Build.0 = Release|Any CPU + {1311809B-306E-44A4-9D69-8A7BD15123C5}.Release|Mixed Platforms.ActiveCfg = Release|Any CPU + {1311809B-306E-44A4-9D69-8A7BD15123C5}.Release|Mixed Platforms.Build.0 = Release|Any CPU + {1311809B-306E-44A4-9D69-8A7BD15123C5}.Release|Win32.ActiveCfg = Release|Any CPU + {1311809B-306E-44A4-9D69-8A7BD15123C5}.Release|x64.ActiveCfg = Release|Any CPU + {1311809B-306E-44A4-9D69-8A7BD15123C5}.Release|x64.Build.0 = Release|Any CPU + {F4B04940-67CF-4796-B6D3-3CFD38FB988A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {F4B04940-67CF-4796-B6D3-3CFD38FB988A}.Debug|Any CPU.Build.0 = Debug|Any CPU + {F4B04940-67CF-4796-B6D3-3CFD38FB988A}.Debug|Mixed Platforms.ActiveCfg = Debug|Any CPU + {F4B04940-67CF-4796-B6D3-3CFD38FB988A}.Debug|Mixed Platforms.Build.0 = Debug|Any CPU + {F4B04940-67CF-4796-B6D3-3CFD38FB988A}.Debug|Win32.ActiveCfg = Debug|Any CPU + {F4B04940-67CF-4796-B6D3-3CFD38FB988A}.Debug|x64.ActiveCfg = Debug|Any CPU + {F4B04940-67CF-4796-B6D3-3CFD38FB988A}.Debug|x64.Build.0 = Debug|Any CPU + {F4B04940-67CF-4796-B6D3-3CFD38FB988A}.Release|Any CPU.ActiveCfg = Release|Any CPU + {F4B04940-67CF-4796-B6D3-3CFD38FB988A}.Release|Any CPU.Build.0 = Release|Any CPU + {F4B04940-67CF-4796-B6D3-3CFD38FB988A}.Release|Mixed Platforms.ActiveCfg = Release|Any CPU + {F4B04940-67CF-4796-B6D3-3CFD38FB988A}.Release|Mixed Platforms.Build.0 = Release|Any CPU + {F4B04940-67CF-4796-B6D3-3CFD38FB988A}.Release|Win32.ActiveCfg = Release|Any CPU + {F4B04940-67CF-4796-B6D3-3CFD38FB988A}.Release|x64.ActiveCfg = Release|Any CPU + {F4B04940-67CF-4796-B6D3-3CFD38FB988A}.Release|x64.Build.0 = Release|Any CPU + {E092E2B9-D3C9-4CE2-8201-BDA442574C97}.Debug|Any CPU.ActiveCfg = Debug|Win32 + {E092E2B9-D3C9-4CE2-8201-BDA442574C97}.Debug|Mixed Platforms.ActiveCfg = Debug|x64 + {E092E2B9-D3C9-4CE2-8201-BDA442574C97}.Debug|Mixed Platforms.Build.0 = Debug|x64 + {E092E2B9-D3C9-4CE2-8201-BDA442574C97}.Debug|Win32.ActiveCfg = Debug|Win32 + 
{E092E2B9-D3C9-4CE2-8201-BDA442574C97}.Debug|Win32.Build.0 = Debug|Win32 + {E092E2B9-D3C9-4CE2-8201-BDA442574C97}.Debug|x64.ActiveCfg = Debug|x64 + {E092E2B9-D3C9-4CE2-8201-BDA442574C97}.Debug|x64.Build.0 = Debug|x64 + {E092E2B9-D3C9-4CE2-8201-BDA442574C97}.Release|Any CPU.ActiveCfg = Release|Win32 + {E092E2B9-D3C9-4CE2-8201-BDA442574C97}.Release|Mixed Platforms.ActiveCfg = Release|Win32 + {E092E2B9-D3C9-4CE2-8201-BDA442574C97}.Release|Mixed Platforms.Build.0 = Release|Win32 + {E092E2B9-D3C9-4CE2-8201-BDA442574C97}.Release|Win32.ActiveCfg = Release|Win32 + {E092E2B9-D3C9-4CE2-8201-BDA442574C97}.Release|Win32.Build.0 = Release|Win32 + {E092E2B9-D3C9-4CE2-8201-BDA442574C97}.Release|x64.ActiveCfg = Release|x64 + {E092E2B9-D3C9-4CE2-8201-BDA442574C97}.Release|x64.Build.0 = Release|x64 + {27D89037-8934-45BE-8A44-2561F9330EB7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {27D89037-8934-45BE-8A44-2561F9330EB7}.Debug|Any CPU.Build.0 = Debug|Any CPU + {27D89037-8934-45BE-8A44-2561F9330EB7}.Debug|Mixed Platforms.ActiveCfg = Debug|Any CPU + {27D89037-8934-45BE-8A44-2561F9330EB7}.Debug|Mixed Platforms.Build.0 = Debug|Any CPU + {27D89037-8934-45BE-8A44-2561F9330EB7}.Debug|Win32.ActiveCfg = Debug|Any CPU + {27D89037-8934-45BE-8A44-2561F9330EB7}.Debug|x64.ActiveCfg = Debug|Any CPU + {27D89037-8934-45BE-8A44-2561F9330EB7}.Debug|x64.Build.0 = Debug|Any CPU + {27D89037-8934-45BE-8A44-2561F9330EB7}.Release|Any CPU.ActiveCfg = Release|Any CPU + {27D89037-8934-45BE-8A44-2561F9330EB7}.Release|Any CPU.Build.0 = Release|Any CPU + {27D89037-8934-45BE-8A44-2561F9330EB7}.Release|Mixed Platforms.ActiveCfg = Release|Any CPU + {27D89037-8934-45BE-8A44-2561F9330EB7}.Release|Mixed Platforms.Build.0 = Release|Any CPU + {27D89037-8934-45BE-8A44-2561F9330EB7}.Release|Win32.ActiveCfg = Release|Any CPU + {27D89037-8934-45BE-8A44-2561F9330EB7}.Release|x64.ActiveCfg = Release|Any CPU + {27D89037-8934-45BE-8A44-2561F9330EB7}.Release|x64.Build.0 = Release|Any CPU + 
{482E0741-E244-4974-97D4-3A7167581E91}.Debug|Any CPU.ActiveCfg = Debug|Win32 + {482E0741-E244-4974-97D4-3A7167581E91}.Debug|Mixed Platforms.ActiveCfg = Debug|x64 + {482E0741-E244-4974-97D4-3A7167581E91}.Debug|Mixed Platforms.Build.0 = Debug|x64 + {482E0741-E244-4974-97D4-3A7167581E91}.Debug|Win32.ActiveCfg = Debug|Win32 + {482E0741-E244-4974-97D4-3A7167581E91}.Debug|Win32.Build.0 = Debug|Win32 + {482E0741-E244-4974-97D4-3A7167581E91}.Debug|x64.ActiveCfg = Debug|x64 + {482E0741-E244-4974-97D4-3A7167581E91}.Debug|x64.Build.0 = Debug|x64 + {482E0741-E244-4974-97D4-3A7167581E91}.Release|Any CPU.ActiveCfg = Release|Win32 + {482E0741-E244-4974-97D4-3A7167581E91}.Release|Mixed Platforms.ActiveCfg = Release|Win32 + {482E0741-E244-4974-97D4-3A7167581E91}.Release|Mixed Platforms.Build.0 = Release|Win32 + {482E0741-E244-4974-97D4-3A7167581E91}.Release|Win32.ActiveCfg = Release|Win32 + {482E0741-E244-4974-97D4-3A7167581E91}.Release|Win32.Build.0 = Release|Win32 + {482E0741-E244-4974-97D4-3A7167581E91}.Release|x64.ActiveCfg = Release|x64 + {482E0741-E244-4974-97D4-3A7167581E91}.Release|x64.Build.0 = Release|x64 + {016E71D3-9A6F-425C-AB4F-8C5EDEFFE7FA}.Debug|Any CPU.ActiveCfg = Debug|Win32 + {016E71D3-9A6F-425C-AB4F-8C5EDEFFE7FA}.Debug|Mixed Platforms.ActiveCfg = Debug|x64 + {016E71D3-9A6F-425C-AB4F-8C5EDEFFE7FA}.Debug|Mixed Platforms.Build.0 = Debug|x64 + {016E71D3-9A6F-425C-AB4F-8C5EDEFFE7FA}.Debug|Win32.ActiveCfg = Debug|Win32 + {016E71D3-9A6F-425C-AB4F-8C5EDEFFE7FA}.Debug|Win32.Build.0 = Debug|Win32 + {016E71D3-9A6F-425C-AB4F-8C5EDEFFE7FA}.Debug|x64.ActiveCfg = Debug|x64 + {016E71D3-9A6F-425C-AB4F-8C5EDEFFE7FA}.Debug|x64.Build.0 = Debug|x64 + {016E71D3-9A6F-425C-AB4F-8C5EDEFFE7FA}.Release|Any CPU.ActiveCfg = Release|Win32 + {016E71D3-9A6F-425C-AB4F-8C5EDEFFE7FA}.Release|Mixed Platforms.ActiveCfg = Release|Win32 + {016E71D3-9A6F-425C-AB4F-8C5EDEFFE7FA}.Release|Mixed Platforms.Build.0 = Release|Win32 + {016E71D3-9A6F-425C-AB4F-8C5EDEFFE7FA}.Release|Win32.ActiveCfg = 
Release|Win32 + {016E71D3-9A6F-425C-AB4F-8C5EDEFFE7FA}.Release|Win32.Build.0 = Release|Win32 + {016E71D3-9A6F-425C-AB4F-8C5EDEFFE7FA}.Release|x64.ActiveCfg = Release|x64 + {016E71D3-9A6F-425C-AB4F-8C5EDEFFE7FA}.Release|x64.Build.0 = Release|x64 + {57663B94-E11B-431E-BE4B-E2C61112DEC5}.Debug|Any CPU.ActiveCfg = Debug|Win32 + {57663B94-E11B-431E-BE4B-E2C61112DEC5}.Debug|Mixed Platforms.ActiveCfg = Debug|x64 + {57663B94-E11B-431E-BE4B-E2C61112DEC5}.Debug|Mixed Platforms.Build.0 = Debug|x64 + {57663B94-E11B-431E-BE4B-E2C61112DEC5}.Debug|Win32.ActiveCfg = Debug|Win32 + {57663B94-E11B-431E-BE4B-E2C61112DEC5}.Debug|Win32.Build.0 = Debug|Win32 + {57663B94-E11B-431E-BE4B-E2C61112DEC5}.Debug|x64.ActiveCfg = Debug|x64 + {57663B94-E11B-431E-BE4B-E2C61112DEC5}.Debug|x64.Build.0 = Debug|x64 + {57663B94-E11B-431E-BE4B-E2C61112DEC5}.Release|Any CPU.ActiveCfg = Release|Win32 + {57663B94-E11B-431E-BE4B-E2C61112DEC5}.Release|Mixed Platforms.ActiveCfg = Release|Win32 + {57663B94-E11B-431E-BE4B-E2C61112DEC5}.Release|Mixed Platforms.Build.0 = Release|Win32 + {57663B94-E11B-431E-BE4B-E2C61112DEC5}.Release|Win32.ActiveCfg = Release|Win32 + {57663B94-E11B-431E-BE4B-E2C61112DEC5}.Release|Win32.Build.0 = Release|Win32 + {57663B94-E11B-431E-BE4B-E2C61112DEC5}.Release|x64.ActiveCfg = Release|x64 + {57663B94-E11B-431E-BE4B-E2C61112DEC5}.Release|x64.Build.0 = Release|x64 + {AA529122-F51C-48D7-A8C1-C0B24F570885}.Debug|Any CPU.ActiveCfg = Debug|Win32 + {AA529122-F51C-48D7-A8C1-C0B24F570885}.Debug|Mixed Platforms.ActiveCfg = Debug|x64 + {AA529122-F51C-48D7-A8C1-C0B24F570885}.Debug|Mixed Platforms.Build.0 = Debug|x64 + {AA529122-F51C-48D7-A8C1-C0B24F570885}.Debug|Win32.ActiveCfg = Debug|Win32 + {AA529122-F51C-48D7-A8C1-C0B24F570885}.Debug|Win32.Build.0 = Debug|Win32 + {AA529122-F51C-48D7-A8C1-C0B24F570885}.Debug|x64.ActiveCfg = Debug|x64 + {AA529122-F51C-48D7-A8C1-C0B24F570885}.Debug|x64.Build.0 = Debug|x64 + {AA529122-F51C-48D7-A8C1-C0B24F570885}.Release|Any CPU.ActiveCfg = Release|Win32 + 
{AA529122-F51C-48D7-A8C1-C0B24F570885}.Release|Mixed Platforms.ActiveCfg = Release|Win32 + {AA529122-F51C-48D7-A8C1-C0B24F570885}.Release|Mixed Platforms.Build.0 = Release|Win32 + {AA529122-F51C-48D7-A8C1-C0B24F570885}.Release|Win32.ActiveCfg = Release|Win32 + {AA529122-F51C-48D7-A8C1-C0B24F570885}.Release|Win32.Build.0 = Release|Win32 + {AA529122-F51C-48D7-A8C1-C0B24F570885}.Release|x64.ActiveCfg = Release|x64 + {AA529122-F51C-48D7-A8C1-C0B24F570885}.Release|x64.Build.0 = Release|x64 + {BDEDD3BB-C7E2-498F-A212-F99786C8E23C}.Debug|Any CPU.ActiveCfg = Debug|Win32 + {BDEDD3BB-C7E2-498F-A212-F99786C8E23C}.Debug|Mixed Platforms.ActiveCfg = Debug|x64 + {BDEDD3BB-C7E2-498F-A212-F99786C8E23C}.Debug|Mixed Platforms.Build.0 = Debug|x64 + {BDEDD3BB-C7E2-498F-A212-F99786C8E23C}.Debug|Win32.ActiveCfg = Debug|Win32 + {BDEDD3BB-C7E2-498F-A212-F99786C8E23C}.Debug|Win32.Build.0 = Debug|Win32 + {BDEDD3BB-C7E2-498F-A212-F99786C8E23C}.Debug|x64.ActiveCfg = Debug|x64 + {BDEDD3BB-C7E2-498F-A212-F99786C8E23C}.Debug|x64.Build.0 = Debug|x64 + {BDEDD3BB-C7E2-498F-A212-F99786C8E23C}.Release|Any CPU.ActiveCfg = Release|Win32 + {BDEDD3BB-C7E2-498F-A212-F99786C8E23C}.Release|Mixed Platforms.ActiveCfg = Release|Win32 + {BDEDD3BB-C7E2-498F-A212-F99786C8E23C}.Release|Mixed Platforms.Build.0 = Release|Win32 + {BDEDD3BB-C7E2-498F-A212-F99786C8E23C}.Release|Win32.ActiveCfg = Release|Win32 + {BDEDD3BB-C7E2-498F-A212-F99786C8E23C}.Release|Win32.Build.0 = Release|Win32 + {BDEDD3BB-C7E2-498F-A212-F99786C8E23C}.Release|x64.ActiveCfg = Release|x64 + {BDEDD3BB-C7E2-498F-A212-F99786C8E23C}.Release|x64.Build.0 = Release|x64 + {0CF3D1D5-9BBE-4175-979B-EC6138EF4F37}.Debug|Any CPU.ActiveCfg = Debug|Win32 + {0CF3D1D5-9BBE-4175-979B-EC6138EF4F37}.Debug|Mixed Platforms.ActiveCfg = Debug|x64 + {0CF3D1D5-9BBE-4175-979B-EC6138EF4F37}.Debug|Mixed Platforms.Build.0 = Debug|x64 + {0CF3D1D5-9BBE-4175-979B-EC6138EF4F37}.Debug|Win32.ActiveCfg = Debug|Win32 + {0CF3D1D5-9BBE-4175-979B-EC6138EF4F37}.Debug|Win32.Build.0 = 
Debug|Win32 + {0CF3D1D5-9BBE-4175-979B-EC6138EF4F37}.Debug|x64.ActiveCfg = Debug|x64 + {0CF3D1D5-9BBE-4175-979B-EC6138EF4F37}.Debug|x64.Build.0 = Debug|x64 + {0CF3D1D5-9BBE-4175-979B-EC6138EF4F37}.Release|Any CPU.ActiveCfg = Release|Win32 + {0CF3D1D5-9BBE-4175-979B-EC6138EF4F37}.Release|Mixed Platforms.ActiveCfg = Release|Win32 + {0CF3D1D5-9BBE-4175-979B-EC6138EF4F37}.Release|Mixed Platforms.Build.0 = Release|Win32 + {0CF3D1D5-9BBE-4175-979B-EC6138EF4F37}.Release|Win32.ActiveCfg = Release|Win32 + {0CF3D1D5-9BBE-4175-979B-EC6138EF4F37}.Release|Win32.Build.0 = Release|Win32 + {0CF3D1D5-9BBE-4175-979B-EC6138EF4F37}.Release|x64.ActiveCfg = Release|x64 + {0CF3D1D5-9BBE-4175-979B-EC6138EF4F37}.Release|x64.Build.0 = Release|x64 + {AB9EA66C-5811-49A7-B002-24203AEB9083}.Debug|Any CPU.ActiveCfg = Debug|Win32 + {AB9EA66C-5811-49A7-B002-24203AEB9083}.Debug|Mixed Platforms.ActiveCfg = Debug|x64 + {AB9EA66C-5811-49A7-B002-24203AEB9083}.Debug|Mixed Platforms.Build.0 = Debug|x64 + {AB9EA66C-5811-49A7-B002-24203AEB9083}.Debug|Win32.ActiveCfg = Debug|Win32 + {AB9EA66C-5811-49A7-B002-24203AEB9083}.Debug|Win32.Build.0 = Debug|Win32 + {AB9EA66C-5811-49A7-B002-24203AEB9083}.Debug|x64.ActiveCfg = Debug|x64 + {AB9EA66C-5811-49A7-B002-24203AEB9083}.Debug|x64.Build.0 = Debug|x64 + {AB9EA66C-5811-49A7-B002-24203AEB9083}.Release|Any CPU.ActiveCfg = Release|Win32 + {AB9EA66C-5811-49A7-B002-24203AEB9083}.Release|Mixed Platforms.ActiveCfg = Release|Win32 + {AB9EA66C-5811-49A7-B002-24203AEB9083}.Release|Mixed Platforms.Build.0 = Release|Win32 + {AB9EA66C-5811-49A7-B002-24203AEB9083}.Release|Win32.ActiveCfg = Release|Win32 + {AB9EA66C-5811-49A7-B002-24203AEB9083}.Release|Win32.Build.0 = Release|Win32 + {AB9EA66C-5811-49A7-B002-24203AEB9083}.Release|x64.ActiveCfg = Release|x64 + {AB9EA66C-5811-49A7-B002-24203AEB9083}.Release|x64.Build.0 = Release|x64 + {3EE0920C-0607-4569-9EC3-5C12BB6EF244}.Debug|Any CPU.ActiveCfg = Debug|Win32 + {3EE0920C-0607-4569-9EC3-5C12BB6EF244}.Debug|Mixed 
Platforms.ActiveCfg = Debug|x64 + {3EE0920C-0607-4569-9EC3-5C12BB6EF244}.Debug|Mixed Platforms.Build.0 = Debug|x64 + {3EE0920C-0607-4569-9EC3-5C12BB6EF244}.Debug|Win32.ActiveCfg = Debug|Win32 + {3EE0920C-0607-4569-9EC3-5C12BB6EF244}.Debug|Win32.Build.0 = Debug|Win32 + {3EE0920C-0607-4569-9EC3-5C12BB6EF244}.Debug|x64.ActiveCfg = Debug|x64 + {3EE0920C-0607-4569-9EC3-5C12BB6EF244}.Debug|x64.Build.0 = Debug|x64 + {3EE0920C-0607-4569-9EC3-5C12BB6EF244}.Release|Any CPU.ActiveCfg = Release|Win32 + {3EE0920C-0607-4569-9EC3-5C12BB6EF244}.Release|Mixed Platforms.ActiveCfg = Release|Win32 + {3EE0920C-0607-4569-9EC3-5C12BB6EF244}.Release|Mixed Platforms.Build.0 = Release|Win32 + {3EE0920C-0607-4569-9EC3-5C12BB6EF244}.Release|Win32.ActiveCfg = Release|Win32 + {3EE0920C-0607-4569-9EC3-5C12BB6EF244}.Release|Win32.Build.0 = Release|Win32 + {3EE0920C-0607-4569-9EC3-5C12BB6EF244}.Release|x64.ActiveCfg = Release|x64 + {3EE0920C-0607-4569-9EC3-5C12BB6EF244}.Release|x64.Build.0 = Release|x64 + {D33C34CC-6DB2-417C-88B7-299830711774}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D33C34CC-6DB2-417C-88B7-299830711774}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D33C34CC-6DB2-417C-88B7-299830711774}.Debug|Mixed Platforms.ActiveCfg = Debug|Any CPU + {D33C34CC-6DB2-417C-88B7-299830711774}.Debug|Mixed Platforms.Build.0 = Debug|Any CPU + {D33C34CC-6DB2-417C-88B7-299830711774}.Debug|Win32.ActiveCfg = Debug|Any CPU + {D33C34CC-6DB2-417C-88B7-299830711774}.Debug|x64.ActiveCfg = Debug|Any CPU + {D33C34CC-6DB2-417C-88B7-299830711774}.Debug|x64.Build.0 = Debug|Any CPU + {D33C34CC-6DB2-417C-88B7-299830711774}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D33C34CC-6DB2-417C-88B7-299830711774}.Release|Any CPU.Build.0 = Release|Any CPU + {D33C34CC-6DB2-417C-88B7-299830711774}.Release|Mixed Platforms.ActiveCfg = Release|Any CPU + {D33C34CC-6DB2-417C-88B7-299830711774}.Release|Mixed Platforms.Build.0 = Release|Any CPU + {D33C34CC-6DB2-417C-88B7-299830711774}.Release|Win32.ActiveCfg = Release|Any CPU + 
{D33C34CC-6DB2-417C-88B7-299830711774}.Release|x64.ActiveCfg = Release|Any CPU + {D33C34CC-6DB2-417C-88B7-299830711774}.Release|x64.Build.0 = Release|Any CPU + {C0F4C1E3-1F9E-4C55-BD6A-0241D35425F5}.Debug|Any CPU.ActiveCfg = Debug|x64 + {C0F4C1E3-1F9E-4C55-BD6A-0241D35425F5}.Debug|Mixed Platforms.ActiveCfg = Debug|x64 + {C0F4C1E3-1F9E-4C55-BD6A-0241D35425F5}.Debug|Mixed Platforms.Build.0 = Debug|x64 + {C0F4C1E3-1F9E-4C55-BD6A-0241D35425F5}.Debug|Win32.ActiveCfg = Debug|Win32 + {C0F4C1E3-1F9E-4C55-BD6A-0241D35425F5}.Debug|Win32.Build.0 = Debug|Win32 + {C0F4C1E3-1F9E-4C55-BD6A-0241D35425F5}.Debug|x64.ActiveCfg = Debug|x64 + {C0F4C1E3-1F9E-4C55-BD6A-0241D35425F5}.Debug|x64.Build.0 = Debug|x64 + {C0F4C1E3-1F9E-4C55-BD6A-0241D35425F5}.Release|Any CPU.ActiveCfg = Release|x64 + {C0F4C1E3-1F9E-4C55-BD6A-0241D35425F5}.Release|Mixed Platforms.ActiveCfg = Release|x64 + {C0F4C1E3-1F9E-4C55-BD6A-0241D35425F5}.Release|Mixed Platforms.Build.0 = Release|x64 + {C0F4C1E3-1F9E-4C55-BD6A-0241D35425F5}.Release|Win32.ActiveCfg = Release|Win32 + {C0F4C1E3-1F9E-4C55-BD6A-0241D35425F5}.Release|Win32.Build.0 = Release|Win32 + {C0F4C1E3-1F9E-4C55-BD6A-0241D35425F5}.Release|x64.ActiveCfg = Release|x64 + {C0F4C1E3-1F9E-4C55-BD6A-0241D35425F5}.Release|x64.Build.0 = Release|x64 + {95FBF9B7-9407-4554-A74A-3527839BD1B6}.Debug|Any CPU.ActiveCfg = Debug|x64 + {95FBF9B7-9407-4554-A74A-3527839BD1B6}.Debug|Mixed Platforms.ActiveCfg = Debug|x64 + {95FBF9B7-9407-4554-A74A-3527839BD1B6}.Debug|Mixed Platforms.Build.0 = Debug|x64 + {95FBF9B7-9407-4554-A74A-3527839BD1B6}.Debug|Win32.ActiveCfg = Debug|Win32 + {95FBF9B7-9407-4554-A74A-3527839BD1B6}.Debug|Win32.Build.0 = Debug|Win32 + {95FBF9B7-9407-4554-A74A-3527839BD1B6}.Debug|x64.ActiveCfg = Debug|x64 + {95FBF9B7-9407-4554-A74A-3527839BD1B6}.Debug|x64.Build.0 = Debug|x64 + {95FBF9B7-9407-4554-A74A-3527839BD1B6}.Release|Any CPU.ActiveCfg = Release|x64 + {95FBF9B7-9407-4554-A74A-3527839BD1B6}.Release|Mixed Platforms.ActiveCfg = Release|x64 + 
{95FBF9B7-9407-4554-A74A-3527839BD1B6}.Release|Mixed Platforms.Build.0 = Release|x64 + {95FBF9B7-9407-4554-A74A-3527839BD1B6}.Release|Win32.ActiveCfg = Release|Win32 + {95FBF9B7-9407-4554-A74A-3527839BD1B6}.Release|Win32.Build.0 = Release|Win32 + {95FBF9B7-9407-4554-A74A-3527839BD1B6}.Release|x64.ActiveCfg = Release|x64 + {95FBF9B7-9407-4554-A74A-3527839BD1B6}.Release|x64.Build.0 = Release|x64 + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(NestedProjects) = preSolution + {37D9C01A-94F3-4B9E-9AA6-C4D3A2B5AD0D} = {99F5E7FE-ADD4-4E9D-BA90-5D3D3409BEB3} + {27D89037-8934-45BE-8A44-2561F9330EB7} = {99F5E7FE-ADD4-4E9D-BA90-5D3D3409BEB3} + {D8B4F38E-2BF7-44A0-BDBF-025B46501DDE} = {37D9C01A-94F3-4B9E-9AA6-C4D3A2B5AD0D} + {DC277807-506B-4F7C-BECD-345079B91044} = {37D9C01A-94F3-4B9E-9AA6-C4D3A2B5AD0D} + {482E0741-E244-4974-97D4-3A7167581E91} = {D8B4F38E-2BF7-44A0-BDBF-025B46501DDE} + {016E71D3-9A6F-425C-AB4F-8C5EDEFFE7FA} = {D8B4F38E-2BF7-44A0-BDBF-025B46501DDE} + {57663B94-E11B-431E-BE4B-E2C61112DEC5} = {D8B4F38E-2BF7-44A0-BDBF-025B46501DDE} + {AA529122-F51C-48D7-A8C1-C0B24F570885} = {D8B4F38E-2BF7-44A0-BDBF-025B46501DDE} + {BDEDD3BB-C7E2-498F-A212-F99786C8E23C} = {DC277807-506B-4F7C-BECD-345079B91044} + {0CF3D1D5-9BBE-4175-979B-EC6138EF4F37} = {DC277807-506B-4F7C-BECD-345079B91044} + {AB9EA66C-5811-49A7-B002-24203AEB9083} = {DC277807-506B-4F7C-BECD-345079B91044} + {3EE0920C-0607-4569-9EC3-5C12BB6EF244} = {DC277807-506B-4F7C-BECD-345079B91044} + {89C5654B-02E4-478D-A7E6-50D79F638B4F} = {DC277807-506B-4F7C-BECD-345079B91044} + {C0F4C1E3-1F9E-4C55-BD6A-0241D35425F5} = {6003A98A-82CB-4385-8B29-70BC28F19076} + {95FBF9B7-9407-4554-A74A-3527839BD1B6} = {6003A98A-82CB-4385-8B29-70BC28F19076} + EndGlobalSection +EndGlobal diff --git a/Dryad.v11.suo b/Dryad.v11.suo new file mode 100644 index 0000000000000000000000000000000000000000..d3f2b74952a131e5b05ab60796f6f49e53c83c98 GIT binary patch literal 
115712 zcmeHQ33wdEm2R2C7>vW=080`_u}y$rDY_3IBy&i%K(-u9#!fN`$kIrbJ(6tA$XKx* zSO|o0#UT#o!V=a27Dxg)Kb(anWJxxA#36=|YgyvmO}@RpY!X;D7vBHxH|47C>FMc_ zS`H(t^*zl@cXidPSFc{Zdhb=$clV$A(x)Fg`=^%cy3jhtdhO6G>v-e-416cp*D02D z8NTrMYljXU;=N-4xNYI*^i?7hSYr+2FJqNkE!Ge|N31b6W#H=!>$t=1J+r`D-`#st z}> zeEz}AAp4ztnsxdU0B_CBex8g^KI289A>ZUb?_cV=$0sM^9(hAjP5)OEI2$w}@9?+T z27iuw8?*NZ@VN=F8IS@D0)_y?fGvQnfDu3%kO7PWwgD~&uqUxMe+aMxuoJKg@L|9e zfGYtYf zWz(p9r#}xJV*fwZgDyw1|NkkkvC=gryA-T3@O z_WLKZpP$0#t=aFlGav0_%arL?f5)nkt$ zXpDdt^sllqq@4#K#z9`AOlhP*z~!g^#G#RKZ1cOQ|D1^g3jjgOrwscqDnYoS7)ggQ zx(-6;8^Nd6{^WU~|8E|L{z+MTUN++{l$%Dm?Wk)SwC}{oKMYJVXbEbj9M}CW)ZebR z;ugQO>KS!0;=pUV|0~KJouAYF|LE@jZuzqo5+@1yqwWWgKV{Sf=L3Su-#?Q$EyqFY(Y&~W*Xl)1|(5umK_Id zv#PBs>jFva@&6qtpEfvUck@V+bBj8dy6HFm(iT&%66X0F&-il-@=}y7RDNBaa@B2) z{~XE}0)oacD;!MulyL!ZR9Ln6 zRgLks0{^SgKWeQ@56=&CXGe>{f2X1Sm0c*yR=yv7csfEOCaz%RcOl(M;MWbg>|{;o%FOQC+Bn(L3zZEZ`k8M%P!UNe+9-G>YNGuAG5!4{7IlUY(#Gm z?K@>Z#PtBnZf7D(#V$|+K4wd+J`k&6d>Y5AxzIo2e*7s(=?+u&DY^sISC+}I- zyJOa3Bm1~L@O2}rgl!TTf`fzVxp}`6PK!#U<5Xa~)$-_5w^}1BxBt(L$K3PPf9jxr zBTrF$VqfpV|GA6g`=`%d8i_7!}r1|NR$ul`SL&a_1u2W z18wg=^x&U8_c0}__J58c%KlGz{z(Aks-#-JmAMgO+LBNjy4*?zqJOcPEfIa8)0FI%L0lonEBH&Aa z{Q%XOCj0ZdDC=3k_W%a~&jFqXd>`-v;0J&o0$A4{=nz|Mdv(k)Zd78-U>^VRkwcIDcwf!! z2amh@zL}e=?nR|@YOO6imm_!D;u8?_tDu)s`%~5*{cIa_)=Jxslby<0>5-m9d}OS< zH_<*Yw6$wkzIU#v97(3>3%8D6Iy9O}UC6IxR&Q)^B0cgyF>0=u2lBm3k|UYq_Jn+_ z7*1yr>9KSsIhg3#*gZ6qOeHE;wRE+PbPpykA0FAf2(7n{D5? 
z<0Hq8MKYO@feoXXWV&*A13mXZK9ky2fd-w%xLj$8zFK*Ba8&ry zrr!@ub{KQkZhWddmq**%dhkp}3GwsvXQ^2Zy2kMaJq@(1alA3-oXQp$w}-lJf2iUn!I7KM9>Zt}(LqhA;G{zx_@8}`_Pn6{pUS3Z zPRjF4yVDk1ezE)!NIeyW$5i<*&cDya8SN%0DGrcF^(}2lp(UCb-=EzYg`L?|q|97fIvyzZ~Vi83`EE5VZUqNXtIM%U!;A z|C6*{|Ib7Dtjg(tpyhuEY0t}Mwia}z*1#z@{cNf~3fFsaT=vE62gX0UT$att`PBQD z$^R?BVZCbEBWI9FjibIwYgn{t=F|I+}JKa^>2 z1kfLTCV)CI^En%z9G@sB=_j8D;CNjQcpHHHvjDIVa1MaBWRB^SQRf5R0k{CL2v7lF zp6|qG6@d2f8bB?e4p0xE&z62uYVXZ}#egM%rGRAsP5_z5yYU$TL;*2C9MA%21uO@w z0JH%v0$dDe2do6F0(1aY1KtBz1Ly>hrp)&ee6GuWe=k1Y2Y5dq0r&ub^N&jbmjSu~ z8vs3kUO*Di2j~ZE1PlN+0X74uHw^-Y04#^FHT#{CG6NU|Yy(^l*bY$X)aNeT`!L`N zz?A^+@pp2^|7%cgh~t0J{7?Dw#~lCJF6=A3-128U<3IZ(Si!#RObh%YZ?JFinjZg; z?(zT1Nt*w2=GX&?HEL^e+E(bC)zd8g+waWV_oY+6xvp%^wts(Ll$>G*x?$~oWu{)7EnXMK71C*S!9aaMIupU!%;tZeP# z*qGnaxhU(eQFf7Kkngm7R@r0s`u!Z8$A3j*ovYv7_UOy|`)^(Tl>|IsHu`NbdS-2w-TdQkF8gyk+Kh|Dlt}aZcPgV{fQfK zM|-I%Q5rdh`Y&hX5#&Oe%P-n{(zca|>jk@Bh~M(ehreUTo~zoUE-~sRTBhzq9z&S1 zmLuoYkacQ)avW0ip`;Z4X4`MHKnqxlQGW?W#5H&;W@`eX<1B%e(P&3SsBbk$SA#oM zSqtHzLo<;JEm%SkjKox`jOzq^=WKcd(y;CvgXAJk)L3Y_-2|#>?8B#EuU+F~QQCHt zS!W~Qg&9+LU9=`sW2?kp0yJW68LdQX-Eg+2uR_~N96RfogeGM69BD=7LoLuRAI%qH z`|APqXa~?aP;d0lf!2#2dp_3#s;9t_#FZ8m+7NY4)D-=5VvBK(Z2C@}k)jdsQwAK( zoCjdp)U!5sTr2sQ6rc|5SE}iCVO*Hx-ix<2=G3h_?ViLwM~+Tr=kSc2qg|OjZY!Sf z7ki*8{VjOYtPP`U*dORW;EdQ=JLM08)wSY0#4ksguPfRzZiub20@O{R|5Iw|yf`N^OLe!CnfA<_ zw9HFIAyY2WJ3wo!&Y$xsKU^c=Li$jWhv{c>P9f?$YpU(By8J?tAh*8x>9js8O~i}n z=3!kX=bBs>#F2os7fam;c%E_mDT6xa6rfH%-f;50E?GvF+>&()&WD{;V-)=cV3joR z!rQu-w9`6zuse!k!Q8c`+krp#ju z6X`yN-z=LFnXyX5!o-+X45xf_@=?5{y^WHZGxU5OROK%y$DC?A^nnK0EE0CE5*MYS z`O96nu|JbOxmpsnWsaHj$vVBo{=V@1-eMW|siWlc5nom^H5>%hsCkl`*s7Hn_c(6( z)sEJbQf|L?)RducLF!6NQYqGz@{40tkUExXe-zy~NDXPrE5%k6&mFZjO0UYS7CB%` zcc)eygbK8JmLe6LIs|7Oy`ToQlftr!7U0sTa|Ci&X~LAmriNa)o*yikdO9Pp{A)&y zVPz~vSAjjpn@X)o4LEQ%$3G5(&E)2cE) zd)4^p8INffGTWavTgF8M8PmNeH&w^*xjA&oH|l#)rBr*6ia(Z7Ft-P-$9%Tk_Dq00 zhcAu=Q*91LYmc*^DB7smn!DYp8#tjeM*qt{)Ga;mg~6SX>%Mya?NwXP%tRN8|78(M 
z5KS)Z58=ds=}mu46zxf?rtB!in$n)9YeicsvB=oAV77(RBSpWC z##gnWvc8Uj`h$2|Q%~pY-FmXlL83az3Ch-@=;O^>!ZTjD`&paCqe?3s-zTo7&6RvY z+dFX)Zz-dkbNPv~LsNj|aSpv4Fl5k?eiNNa(Ww>rNsm(RnHb1C#?LP7ovRnuYPwS& z(EYej`;*?fk0=jk61|mszKE?rwapoxbDPp-`J*dWYiEn z{T$y++3cOOTA}1UHhb4nW~Sa&cvX+m^WHio9mjm7N0~7q&Qz+e(C2K%5;gZS7QtivSwuQNvtKUkqMXB*!j?qz~ z)C<>iv}9{4GZ(JikZ&?irAFzTDN{2mOo!6< zXGFbod{_M=NKbKl3rjhEM%OzZ&!`k$wVrxoI_J;%I78VFyeLunlF_WSR}H7uXfM7> zy^pG_Zq!pir7B92kcUkFAA1WhqID$#EJL3Um1*Fu}04_o_an;c!mwpVHm#`K)w}yMJ(Q-N&A6xjueQ_it+FeD(X- z*o#)WB`87cJ6={gZ7(q+J$a5BEh5UgY_1tJ(@`GlR5Alr7P-@}SetodeORjl^@_mB z+FQ}Ik%45mXpwsxKD|~2dehBKkJ7irEkymr-sWd@3%Xk)EqlTrs^4n66DeJ|Bu)2x z_C1akCXb3eO6@SB%X6o5x1F*}vHuilS6yHe=u!FFuB-BEMbs^q!A?eC^~NwQ9qA zwMj7QICJFc3tE!R==NYeE_dCMH&(%Nqji~a^W57`BImtxJ8wcQ=j;u--L&1@uPqos z#MM(?I}RP8hFSXRFlz-mGk?K}MXT!`Cu(b;&nE8au6RJ@%}^OaClGTJvHg{n4Xu z=0@2D*IB(NSE|y%J=W*?VpNK$RBN|P%&oOb;Z^I>my)K8@T%cd)aqmp6m!*qmk#Jf zjZ&40#p`^zQZaZQUX+`vYpv&E;57SCacd8zvOR;lL-|q0X$A4(nWL zCvCQ4`nhjU+A4>?iD%nYqkR7NyH3EvtH$A_HvK~CY_WvB6B;Q`?`s17j3KmqkTWFZ zxV@(*2@meJ#K?Pzj5mFDDi%J0===`>-z~sE_OxVNnBJW;2L9##i`=&~g`JU@A8o~q zqvr|V^kFc6%Ra^C*3jyahR1{!X-1(W>|Ug?V5ETTD$L%c@3Rl`#5TqUY($UfLpq*q zDQ89B9~y4k>=IO42ca={9wNm$Y;TvG6y~I?rsMIzH|2w?+Q%w!ctLJcm3kn7N#s-<}n_qQmP=gDe)=AW44^9439vr zSJ*PX9?vWAoU*G5-x`o6j;Brb>0EJqZ$|1GNPNFiTJUZ&Qndi0c8(miV)nN>eDZz; za^yLn)pmJttWh=6 zH{-v5U1G?;74OIF^tH&b+8!@k>{4RLiL~Qb*ns>ikeWDnjnjSJQoagz zqIM~bz?1F5Cyltvr!Byu9^a*98c<>s&;9E{YV%D}fh|JnL~y6oE`_v@*`-8L4(~PL zcRkWG4yhG?e6zJ+DQv+8{1JQV6Kt1i8$Y%SSICieQT(s7>&I4$AwAPn<2y@Zx&C#D z01s*-Y{44*RRFVkd}j^F8F64iim?vl4Yniso;4$eJb$*SU@2@@)}Pdh+Ssz!2>+7D znvfeysj;yj4s0oENDU~Fty+UOV+HF%O0^&#(v39`{v_Tkg?uZ$lH65c*MfbA9L+uB zo9)`v7AU0(Ik(~+_IUPbwm$15oXm52e8*N`tB}Xo{-j(jQu((_Ez+v~&l1@q z$jht&b;<_h&XSosIgmWZ79=LZ)zXLam0}$}a2fN26U$!v`73uk_u0Ew?c4o9BIzkL1Wr|x;{+V&?l&4?W6I_H&do%_34pLzDZ7d?7P|MS-k{_wuX58m?9 zw+|d_zpCtsz0q~;w{ZY=r%~nbnR{t6pUnWc%XjC&cijh8^Id{hS~p%Qtve6q>khO!u##evyyHCyB)7A38 zMfK-ieb=*JanmGq|2+?%y8eN-M`tW7U;LlDulAyiJB^9=>|0*CaOUv?E%$sQaoX30 
zes!3KNjLYQ8HzS;K4seL|MHm2{{E_KVmF+><*&~BsZz9t+YDBKP>AVDO>Lhvi2WLr&MoK;WiDHoaE1yd6e^PamzZ-na(+BFx6@# zbv`sTdNpf{V}_7|&X^7GB_HXWGcU_J z!&@N-(F)h5EJ&g=yc@L4%JbkXz3SW~{Ro<9wpph#=g(6(ZKpXSjZYp8Xj$`g$!dpj zlK~)*6ki}~)4qhP&7X_a+d5j+Y+J^^@>F8ZD@UMN^xE?z`Bi33;N^0Ko{6)z0A&H47(so&39(6Rf_NB?jC(|cQ(W4rb9x#xWE^9vR%|H{Apbl15X25G#+-S6)E z<@)R9Ki=}l;^%vAe|0NEH!g# z{M7Y>*<7Vm>g}Ana~8z;J~eRB=c$#;>_cXQoP%(-RA-mTnN1Wp5XUM&aizvFbC9{D z%po{)p|x4&0<DDRIe@nV&IOzYI3Ms1zy$y{Sp|T3z7wBS0D32C0JQ*0 ziF!Z-pb@~j3Kk?*+UM@P0r7@Bsj$EG`9H2IvM*J?a5a@l68y0Q~@3 za0dXJ0Gk0UhcJ}=PJh8x0JX6+AOjc$sC4S{L%6pCuoJKg@L|9efGYvV0JJZS5>BjR zQG64dU^D31WJ?{|i6~boPeuMw+Ct++$G4DZw5R-A+l&u$oOG4Ig-0-22v^dRK@SsW1>)hL$BjM*&I>pj5q}{)dGt{< z;Z7qo9QsS>SrSi z=Y8FHgzyjMaqswj>&iLb{{DmgTW5WF_b1=U0joe=T2{7py$Pw*{+zA0v4iV&F2_uW z@i>dJ5d!M;!DTSb@NCK|+bFYs|B2p&33!+1PrF?q8~;T1pP;Ur-vGs3EhW~4v$tf< z-mnSi%3)fgY4AxOHb`~75WnS_qIZo|^Rwenvm?7}>s5^!e5xf=b^5+y+o zeyjbR^Eldw7&{=2KsRTSL%c4oqL?`$+HvbXweb zQXuzQAA(eJpKnMy_|-|{Vw%^TZm{C_2zhIPQcdqdzTH92Mjq}RC8UQc9|N_(t*{Ho zPR$}k*^Y0)n`UkN_Z3mD1pR)Y_M>I550sAJKka!!cXQ9=LOEy;&syZe9gp3+F2B71 z1&iN_WI=%(owW7eSA_FAlfv$HGWV@MGbgRNKNGElv<}KzUbVl?n+xC?K};&Wl}X4q zp`UqDW+Cc3Ybs*^$9c^Gg(N|4eZBW4A)VGkI+@ypvvghNX-Mm}Kg?>pO-r&)c}@Z9 zGs@f^UF$AYOg2cxye*-sR0)NIY&?p`}0(EsLM4ia4`{Uv0ovCjRC$ep`Zv%Sqd61cuZcKTrcI=y%x zxgMP!GkU@Mfrm2bZO_llQuC+hN*wwjoER{@ z>965#CGkfz7{uFphY9_r zck4-iPiT85E@JT?#`9c$qI}g9 zV0oNFF9!@6bfVuxr&4rkMSgN>sd~@EK;|)ic46;ay|`A>o%(?8$A#LT^wxc(3vJ&G zdL+?XllDFirYGYJ&pHwtzTtJQsCY@nKq+z^`?_=Wmd2W`Vfv^Pz1(yrf9IDcqmYsDOz+)?8A-DO6iXxWTJ|4Oy@PElwGYw* zNmC~sy=62x+(}dVg19Eo^cOj6t?47&$khXs>6%K;e&g2NRgIkeonDviXg_Z{dZ+cK zA;%rlv!+@;2;JB(C4JJyqz4=@f?ludM#`%7f4Sa8osKfltG9N7R@5+3_@qzTDZ}&k zGfI0(I~gx^Ipm96>C7rWC z$rFmk5+l%k_-&mh>8A12Jml5O6@7!W7&$}BJb(Q_xQL^I>CN`;8I&IGp26{k?La=& zHDgNMs=_yn)yo~By+>5{IK=ndtt{gi+o~J>U4CvfFz_BL+6?KQZ){UJQYYtmO&s4$ z9l$$hwd&Y=Z1%2Y{u#TbWa)Wtosy1YzS5)2C=+KY)mIo3)Niz+Q^u)O=9r-=MGwC2 zOwxbD z^h((g%u_ML)4yJNFx>3wuFhmwUzHxg$YJweM0pAAxSJTC%uUh@*+e8By=_KB)c?q^G#Og{2%nqw6ibKbL1z 
z3a?sEy)m8h=X{)@><34dYTRF<+EBFzzW!s}SffDp7$w@ZFQAKifMje@@@w<7&@k3_H=B|BQbLX{h zz$#e$yYMV`hY6m$B0EY{Vg)~UWUt2RaBd2k1Z`C$MP+tw%X4W4ahL0Pxwemfe_qsA z84tJ`{6OCd$2~LRh1#*-8bs|;QPB`;Cat`l99^~VQhY3FU|{;pyrW-a#UVMA>%en8 ze1&k+rKw2zn>pC_F1%YvJ}6PGp5B z-<6h=Q*KBjwk=m{I#*QYdZSCvb2tmN)onWM+mXIH?b~6L2m5Ndi|bWi$d#+!@-+Y2 zWi4xY%8ofxaF6jy8<>jrV7E=luZ~mC8sJykQneB)drlCpJ(_HSer;X))}3j)H!b7t z*%&q9qlw3P;0XIBoaN&lzgo6N8S~|eLOEH$YgJ${dX&B` zZXxI(_BKDOThQGaY1tEtt?IYhI`y_dC26|nv+vOwX7Z@mqc|U@1xlCaPUpqT)WO() z3R6JiWtQ#Cv-rIo=Lag=j7!e9!y|WzG%@qvFU~tIJ6H>8i&#D@=lOwZWv>76!(6v0 zqqZ3X&v@}MNGwL7Z?P@4vKq>?@ajDu{rK9+J8IR2_iB@1)N$s>)fcoRn=5yM^|;)1 zOWs%o%Z=7$#?5nYJBgh4&h5MjwVbmz=yua~bHBErwI|vMfED#cW)wOc0U z)>@_Ts&(m0NmE96)o?0mb+QMF8B64)1A0-TRHb6^I$y3-44#J<<)-Rd>$y4OtaGlf zS1Wi?Wh&bP!u*2R%VfI(YB->a`}J1-qT{bh3KveaI6lc{(Le zt*k&yc|G=A<4MxZh(Bi(I`1f$uf=Cb%@)cPoG?G3LEm)VLAp8t}g! z?^YvyGyVrXWqX-VsS816o}9GVj_K!llWD6Q{wD6uS`F_R|NC7h;Nex{@KT$8A$7JG zV`J%+mR-JzIYU%7jD32N@ZfGs-S9_AWW4FKQ?c*~MCX49{+BJlK=!m`T$tXSGlm_* zxMG6)mZsnZXMVI5_aXI2c9jJ4x9pQ$ZVjy-X(Pfo0WXn>1tSGUaL;|8D_3}88(U&y zcDJWWe6#F0xBElG&DqvLXw03@NU;vv+qK31@1!h!K|;slfp5wOSGA8-;_!}EQPh`v zu=S(nW*ZkyA=F_M!p%MXMG)E5MJH3CTDA&x8MuO8Zz$xm3lZwG| zu2^AY`WoaQJiHxwaFvw>diO*6s*^)Gn(eT8g1dgnuV%<6o-Ii!PU%f4#b^JgJzHt=^W*P58z)cuFx(tlWiDMNKK|EHk`wnEgtjd~UYobPaIgS;{;|GiHCQ z!zX1nM=(XTBEhWi0)D0-N7L-vZl9Gyi(rDwrQ>&S4EvVRpR_FL1 z#dH5sNNv8!5m8PyZN@td$cd#8)0kaK6u)?{3BT)+o^eR6_~Vr=QH$XztTKYfke+FH$~H@5x&C#D01x&6j^;J^s{m#^rJFS%XT*U8DaJaG z$JmZhq+`v9A$L%3Dp(3Rll3RHqBgcMyes_6(_?uaEK8}eu^kLTY?yGW~$$JqX)9Q8l{QfiS_^?#OF0o=&TN{-3d>a1-Hxv{2fAIb?zFP0X14;z2I=bnr2 zy#H?w-1p_ro@_q5`PX0i!K)YC8-L=jXS~04^;=#vpUvEU?oT$a-r4cwKh~`td;FBR z`Rtm9uV~u%`2ViB@71wS|7!m`yUb^MU#_}&$z^ve-BP` zb#40-n`T4~be;3cx6b|Dtj|3A-isbRrT_Ws27h?p;|Fhf>DvbmwqI5D#NOz-_S-lB zyVIz0_{_aDna^eb+~vFT;JfYv%lR(BE3F$ZmDZgH^K}Pes8zhzope70vv-?}&bbYASqiw}PD)YD#k z#(Xya6KfuN^D8f`c<}2_eW7FB3D@|Q@wEg0{lvjGb--x{N1>?@phN-)!+Q?Q#1E;wLEZ9{kd1)_3T&NG)dim z&%>v#f1vHr84Jr7|L5+jy=dc3W8yvgmX|J^dHg`jJ>N*2_O+p39p+)y$$e;sqK%tR 
znfCg>Jm#{$zv`OU4d-w9tMh)jdz&f)k8=20<>Rg!^Ksu(`S97r$!9+HuQP61{=m;x zzwqZhH*~0HEW>@dad*?rU5@*GXkuLZixUS{K8wD|899nO6F;H<^+0dqj{0)?vYdO) zc{A7Pa|R%Ll5pKVch+XSAWw1SoNEMI@lMn)Vd!co+R28dPmV4_R zDO)!Uva0|XQ>o{sQf?YGdr{@NGS_mRgKJsmIny~O7Ur}QNu3W(tzXUT;-F6of+=v) zvv}?`PhFOEq@1(Rc9FJb_u#=G{um5rS#zkV_t?~zww!C?rX~VN=bW`!)*0SPJBU`e zHf2E)ouyZro1`28(z8t*%bY(?VP$BdN#m2p0b15P%|2>ZakD`ogj72tYtyELtj(W` z)!RBk)od$29xk&c@N&6A&&1j+xy&k<Mq8%uXd(@h#KT#8*mPg>ci#iZ>D{58DG*6jpvt_QJZf5>b?~yvUr<4H`7IsZMz+M{m8Yinb*b+K4Q zLp)kjQCHmO=Z`#?j zIFTOdN%Y#!Q%IW`mhb72%8_Jx`_7#^nxeH$HTbJ&sHv){sIHDzqnyU7iYi>KO^x-9 z&GDLD&nv9znrd3A8tbYnT59U*D(adW>nfTX<8_Rs*E{%2R5LAb+`C&W+MAGKnXLj^ zIJZeGnBj3HV>d2GdogNKmD_=`*$v#KQg4rgr#TC%2FHtaf%-gW7&80dJ$`|)t6oJv^ZqTp>g z*@~wEQFYOrEw=f|#6@gpM;;H;eixh`T(h$S{6;HP7?vbzyJ|NkWxFgy3+eaaSs9Dr z$}mfvC|6D;&uWcCQT&wn8j?C17Xpf{0L=Az@_?`A%j_b4?)7l=46Qo7L zEY&Tailc&z;$h-5(H2n_W7^9}-|>5lg0+a~4q=w+ZjoS;Nm^nu@hDbOhlz{H6~X8c ze!G}B1gC|_i7+Eyg3Q%KPJ~&in=8j7C&I+XY>nb@kroNh=Ua@N2(wgoiv;TzA}1yj zkKlbn<{4q)Vsb?=dWf8uEF6NG_|Y1)E$s z9=$HiSw|2}JN$Wu;tCaMA54##p8t8-^uy!^lQWBzRbkG2f^g;Z{4c+^XL|m(1GdWW zR=>FARq(zsJ^#}_$m#jt4%pAbiYm|bG!)B~Vd7$r9Kl9&(JCejhv2lBp8xHD$1bc@ z2ovYZFmV|#S56iV#d2ksGqE5#mGsFlOLdQ&!DNGI`jd%A@RpDfJWO21%axObLvUJ% ztPF3RPcc@BFiV{XS577#!COMOGE7{?%axObLvUIsy)vv7Tg7l?cZcmO2rxoJ>4|w}fzIn7Ei65p3)beSflW2u=&(O2!N_dXQ@jV)$dE zAy+$a6*qSYkhKW(qsX2QT+hzvLPn-?l|EB&KZN2U4Z|$e&6UM@cVyNRCO#%d6o-qn z2-l*^%6Zb6D`99;je-7b5k`1&?FQR}ZBYU0bKd~2C}Hd;PxocZaG#A}azaWCZx@MR zr3&4{EY-~i!Fz)A$jQW|7|9SO9wsjar-w+%$-*NTErct>+jFKEt_-u(iE`y+;u5SS zq<4gg$3(btvhWB-3*k!cG{^m4xCT&mh~c_G?iND{%~gF|wa2v?El4A)`?%W+SNn3; z8MWpxsDwTd>m?#sPpB(IJcxp_qRgN{DT(q6B_l)9BWNEDN zoW7183Y7Z#<%-%!V|7DgeN%NybM>yB_Im0|k|UYq_Jn+_7*1yr>9KSsIhg3#*gZ6q zOsUVxZ9UuV#nv%dY%PnigJr?R+1czijP~^h8Mr7JgLS$4dtr8-g{Yp96U);Cww zwP4A0bF`tVqPDi7K6m|ibA4;HwHd3!o9n6?E9&a&>nb9Rb&-nN`kKb3+U97yI@+*n z=L%$=;J@g|KyQC?Rd;4!TQcW{cYOlB8yNL6>`wr&uYq^JK3BaTYsFumh`9yJSymZ0 zJu)utwX^^bKfJ|2>9L4W9s6r;EmewWMQ4|6q|nn>)7#rqTT{{0)7^{R40>xSnrpGI 
zK~+_KZCyS0G^ncD6)&`%L5Ye3MGo~T1e>FNT9r?oo&C=J+%?B+w|Dj6qf4M|a3@dh zp`ds6+=e^cc4j2IB})P+y8!xJ8@{Xa<1e`;a!*!>s&gXqpYKYX)x{kqTLBugVdyX1 z<#G^Oirz)1+&T}sOA~aY^X=_LylQRX%+_#nDonrNSW?SMp*DSJ5#2)Sx4T^?ny-|` z_FgkiS~&4_?^c9N?Tfpf@*3{TR9-TW9_>zbW=4AlhAX?0>CC!;q2A%kE8B)L$&tS9 zo@9DaIR@qQ!0^zr>gviW{&6J1aktraS3A<`J#$zCUX6Va3H4RowcW?@m8B~OdPauR z!+mNbvJ++hNvW~nq0Go|DwQ0mTr)bs+W7~)`p3NY1>p(KqJAWxCPezFJfditW z1F2qjlD97H>K^G&W>%-VGkwD&gYM^xRziPUlk7oB0!J(Jy>hvIg&YNZtG5sCNI6zI)VACYZjULRQ}`LlaD#6 zNK)7X$RCOSZBA@1Om2e{LE^C#+K%0r6}|x7IFd?{$I|LqXEK%S$s}`F&6Aao$AqrQ zV#VXjY!0%e(k#~WRQ}^E|Kt<7-=Dhnl*x$cB{{VeHFqA%w<*^)dj~+Z_j=*IauZh^+)^=upL zD3vz*va6Ll=5WS_T8Yjz_J&z3^7y4gqp8$|{90!9I^GC2>#JPV($zZBJ(#?Fcx3Y; zHI{9w!$=lMr;~#lQe%-!W@KQ)XeODi9Nxgav&>R4vbbtTeNGKM<};~X6=+~?gqM37 z{r|*fu?&0Pc&`(}cn)?Cp6Ao6#cxQd@>GVuuB62g^X;$m_{FnS1AP8JTqX(3!0X2ef$t_-u(iE!m) z;t{+hge$|uWxQNDSvUlzg>Yq<5kJAXGR#sZ!j+SWNAQ+VBlvXum)ch~%!r@hTp4Dm z6XD9q#G_cQ3=@~}%F4;YAvi5WR)!h*6oe}yn6K=xcd833>deiR#l>7r$A1YQgc1ssDcVj{6pUhi5~@c1uP#p(E|E|4Ohms6d=<`q5WjB`?c z^dlWVHQx9y@^6s%sm1Uld&T7`{#Sg&Pl+XT(j{N~m+ic!8((EK9Y2+gdNZO{eB#fh y{fBiA}|Md_*^+$<+GUC4|b~uKPH=2j?ZX(yjqJfH!S_(V<>;D1CbA>_x literal 0 HcmV?d00001 diff --git a/DryadVertex/VertexHost/system/channel/channel.vcxproj b/DryadVertex/VertexHost/system/channel/channel.vcxproj new file mode 100644 index 0000000..9521095 --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/channel.vcxproj @@ -0,0 +1,186 @@ + + + + + Debug + Win32 + + + Debug + x64 + + + Release + Win32 + + + Release + x64 + + + + {482E0741-E244-4974-97D4-3A7167581E91} + channel + Win32Proj + + + + StaticLibrary + + + StaticLibrary + + + StaticLibrary + Unicode + true + + + StaticLibrary + Unicode + true + + + + + + + + + + + + + + + + + + + <_ProjectFileVersion>10.0.40219.1 + Debug\ + Debug\ + $(Platform)\$(Configuration)\ + $(Platform)\$(Configuration)\ + Release\ + Release\ + $(Platform)\$(Configuration)\ + $(Platform)\$(Configuration)\ + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + 
AllRules.ruleset + + + + + + Disabled + WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions) + true + EnableFastChecks + MultiThreadedDebugDLL + + + Level3 + EditAndContinue + + + + + X64 + + + Disabled + include;..\common\include;..\classlib\include;src;..\..\..\..\Hdfs\HdfsBridgeNative;%(AdditionalIncludeDirectories) + WIN32;_DEBUG;_LIB;WIN32_LEAN_AND_MEAN;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + false + Default + MultiThreadedDebugDLL + + + Level3 + ProgramDatabase + + + hdfsbridgenative.lib + ..\..\..\..\bin\$(Configuration) + + + + + WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + + + + + X64 + + + WIN32;NDEBUG;_LIB;WIN32_LEAN_AND_MEAN;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + include;..\common\include;..\classlib\include;src;..\..\..\..\Hdfs\HdfsBridge;%(AdditionalIncludeDirectories) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + {e092e2b9-d3c9-4ce2-8201-bda442574c97} + + + + + + \ No newline at end of file diff --git a/DryadVertex/VertexHost/system/channel/include/channelbuffer.h b/DryadVertex/VertexHost/system/channel/include/channelbuffer.h new file mode 100644 index 0000000..02af79c --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/include/channelbuffer.h @@ -0,0 +1,275 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include +#include +#include +#include + +class RChannelBufferReader; + +#define RCHANNEL_BUFFER_OFFSET_UNDEFINED ((UInt64) -1) + +/* this is a dummy class which may be extended in future to allow item + parsers to communicate with the buffer reader to influence i/o + buffer sizes, prefetching strategies, etc. Nobody uses it yet and + NULL should be passed wherever there is a prefetchCookie argument + for now. */ +class RChannelBufferPrefetchInfo +{ +public: + virtual ~RChannelBufferPrefetchInfo(); +}; + + +/* buffers read from a channel are typed with one of the following: */ +enum RChannelBufferType { + /* RChannelBuffer_Data: this is a data buffer which can be cast to + an RChannelBufferData object to retrieve the payload. */ + RChannelBuffer_Data, + + /* the following are all marker buffers which can be cast to an + RChannelBufferMarker object to retrieve the item describing the + marker. */ + + /* RChannelBuffer_Hole: this describes a "hole" in the + underlying data stream. */ + RChannelBuffer_Hole, + + /* RChannelBuffer_EndOfStream: this is a marker indicating that + the end of the stream has been successfully reached. */ + RChannelBuffer_EndOfStream, + + /* RChannelBuffer_Restart: this is a marker indicating that the + remote end has requested a restart. */ + RChannelBuffer_Restart, + + /* RChannelBuffer_Abort: this is a marker indicating that the + underlying stream has suffered an unrecoverable error, restart + is impossible, and no more data is forthcoming.
*/ + RChannelBuffer_Abort +}; + +/* base class for data buffers generated by the byte-oriented read + layer of a channel */ +class RChannelBuffer : public DrRefCounter +{ +public: + static bool IsTerminationBuffer(RChannelBufferType type); + + /* When a consumer has finished using the data in a buffer it + should call this completion handler returning the + buffer. + + prefetchCookie is an optional hint which may be used to + influence subsequent read buffer sizes or prefetching + behaviour. It is dependent on the implementation of the + underlying buffer-oriented i/o class and should be NULL for + now. + + The completion callback mechanism is used to implement flow + control, as the buffer-oriented i/o will only allow the + consumer to hold a bounded number of outstanding buffers before + blocking further reads. + */ + virtual void + ProcessingComplete(RChannelBufferPrefetchInfo* prefetchCookie) = 0; + + /* get the buffer's type. + */ + RChannelBufferType GetType(); + + /* this returns a description of the buffer, used for debugging + and monitoring purposes. + */ + DryadMetaData* GetMetaData(); + + /* this replaces the item's current metadata with a new object */ + void ReplaceMetaData(DryadMetaData* metaData); + +protected: + RChannelBuffer(RChannelBufferType type); + virtual ~RChannelBuffer(); + +private: + RChannelBufferType m_type; + DryadMetaDataRef m_metaData; + DrBListEntry m_listPtr; + friend class DryadBList; +}; + +typedef class DryadBList ChannelBufferList; + +/* this is a default completion handler used to return buffers of type + RChannelBufferDataDefault and RChannelBufferMarkerDefault. It is + called from the ProcessingComplete methods of each of those + classes. 
*/ +class RChannelBufferDefaultHandler +{ +public: + /* the callee owns a reference to buffer after this call */ + virtual void ReturnBuffer(RChannelBuffer* buffer) = 0; +}; + +class ChannelDataBufferList; + +class RChannelBufferData : public RChannelBuffer +{ +public: + /* The payload of a data buffer is a DryadLockedMemoryBuffer. + This cannot be grown, nor can its available size be + modified. It is derived from a DryadFixedMemoryBuffer and hence + it is guaranteed to wrap a single contiguous memory region. It + has the same thread-safety properties as the underlying memory + region, e.g. it can be read concurrently from multiple threads, + and if the threads co-operate they can safely concurrently + write to non-overlapping regions. + + The caller must increment the reference count of the returned + DryadLockedMemoryBuffer if it wishes to use it after the + RChannelBufferData goes out of scope. + */ + virtual DryadLockedMemoryBuffer* GetData() = 0; + + /* this returns a metadata description of an offset in the buffer + (which may include a description of the buffer, along with + e.g. the offset's overall position in the stream), used for + debugging and monitoring purposes. The caller owns a reference + to the returned metadata. If isStart is true the returned + metadata writes the offset using the ItemStreamStartOffset and + ItemBufferStartOffset elements, otherwise it uses the + ItemStreamEndOffset and ItemBufferEndOffset elements. 
+ */ + virtual void GetOffsetMetaData(bool isStart, + UInt64 offset, + DryadMetaDataRef* dstMetaData) = 0; + +protected: + RChannelBufferData(); + virtual ~RChannelBufferData(); + DrBListEntry m_dataListPtr; + friend class ChannelDataBufferList; +}; + +class ChannelDataBufferList : public DrBList +{ +public: + static RChannelBufferData* CastOut(DrBListEntry* item); + static DrBListEntry* CastIn(RChannelBufferData* item); +}; + +class RChannelBufferDataDefault : public RChannelBufferData +{ +public: + static RChannelBufferDataDefault* + Create(DryadLockedMemoryBuffer* dataBuffer, + UInt64 startOffset, + RChannelBufferDefaultHandler* parent); + + /* The payload of a data buffer is a DryadLockedMemoryBuffer. + This cannot be grown, nor can its available size be + modified. It is derived from a DryadFixedMemoryBuffer and hence + it is guaranteed to wrap a single contiguous memory region. It + has the same thread-safety properties as the underlying memory + region, e.g. it can be read concurrently from multiple threads, + and if the threads co-operate they can safely concurrently + write to non-overlapping regions. + + The caller must increment the reference count of the returned + DryadLockedMemoryBuffer if it wishes to use it after the + RChannelBufferData goes out of scope. + */ + DryadLockedMemoryBuffer* GetData(); + + /* this returns a metadata description of an offset in the buffer + (which may include a description of the buffer, along with + e.g. the offset's overall position in the stream), used for + debugging and monitoring purposes. The caller owns a reference + to the returned metadata. If isStart is true the returned + metadata writes the offset using the ItemStreamStartOffset and + ItemBufferStartOffset elements, otherwise it uses the + ItemStreamEndOffset and ItemBufferEndOffset elements. 
+ */ + void GetOffsetMetaData(bool isStart, UInt64 offset, + DryadMetaDataRef* dstMetaData); + + void ProcessingComplete(RChannelBufferPrefetchInfo* prefetchCookie); + +protected: + RChannelBufferDataDefault(DryadLockedMemoryBuffer* dataBuffer, + UInt64 startOffset, + RChannelBufferDefaultHandler* parent); + virtual ~RChannelBufferDataDefault(); + +private: + DryadLockedMemoryBuffer* m_dataBuffer; + UInt64 m_startOffset; + RChannelBufferDefaultHandler* m_parent; +}; + +class RChannelBufferMarker : public RChannelBuffer +{ +public: + /* the caller's reference to item is transferred to the newly + created buffer */ + static RChannelBufferMarker* + Create(RChannelBufferType type, + RChannelItem* item); + + /* Get an item describing the marker event signified by this + buffer, which can be passed to the application. The reference + count of the item is increased before it is returned, so the + caller must decrease it before letting the item go out of + scope. + */ + RChannelItem* GetItem(); + +protected: + RChannelBufferMarker(RChannelBufferType type, + RChannelItem* item); + virtual ~RChannelBufferMarker(); + +private: + RChannelItem* m_item; +}; + +class RChannelBufferMarkerDefault : public RChannelBufferMarker +{ +public: + /* the caller's reference to item is transferred to the newly + created buffer */ + static RChannelBufferMarkerDefault* + Create(RChannelBufferType type, + RChannelItem* item, + RChannelBufferDefaultHandler* parent); + + void ProcessingComplete(RChannelBufferPrefetchInfo* prefetchCookie); + +protected: + RChannelBufferMarkerDefault(RChannelBufferType type, + RChannelItem* item, + RChannelBufferDefaultHandler* parent); + ~RChannelBufferMarkerDefault(); + +private: + RChannelBufferDefaultHandler* m_parent; +}; diff --git a/DryadVertex/VertexHost/system/channel/include/channelinterface.h b/DryadVertex/VertexHost/system/channel/include/channelinterface.h new file mode 100644 index 0000000..571fc72 --- /dev/null +++
b/DryadVertex/VertexHost/system/channel/include/channelinterface.h @@ -0,0 +1,731 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include + +class RChannelBufferPrefetchInfo; +enum TransformType; + +/* this is a stub object that is e.g. passed to parsers and marshalers + when they are initialized. For example, DryadVertexProgramBase + inherits from this interface and is set as the context for all + parsers and marshalers in that vertex so that they can access + configuration parameters, etc. */ +class RChannelContext : public IDrRefCounter +{ +public: + virtual ~RChannelContext(); +}; + +typedef DrRef RChannelContextRef; + +class RChannelInterruptHandler +{ +public: + /* When the byte-oriented layer from which a channel is reading + signals an error condition by delivering a Restart or Abort + buffer, the application may wish to learn this immediately + rather than waiting for the error item to be delivered (for + example, the application may be blocking reads from the channel + indefinitely, yet still want timely error notifications.) In + this case, the application can register an interrupt handler + which will be called back when an error occurs, or during Drain + if no error occurs. + + interruptItem is the termination item which will eventually be + delivered.
The callee of ProcessInterrupt should call IncRef on + interruptItem if it wants to store a reference to the item. + */ + virtual void ProcessInterrupt(RChannelItem* interruptItem) = 0; +}; + + +class RChannelItemArrayReaderHandler +{ +public: + RChannelItemArrayReaderHandler(); + virtual ~RChannelItemArrayReaderHandler(); + + /* This sets the maximum number of items which will be delivered + via the ProcessItemArray callback. The default value is 1. */ + void SetMaximumArraySize(UInt32 maximumArraySize); + UInt32 GetMaximumArraySize(); + + /* When an asynchronous item is ready to be delivered on a channel + this method is called on the handler object passed to the + associated call to RChannelReader::SupplyHandler(). + + If deliveredArray is empty and the asynchronous read was not + cancelled, then the channel has completed, i.e. a termination + item has been delivered to some other synchronous or + asynchronous read method. If there are multiple outstanding + reads the order of delivery is undefined so an empty array may + appear on one handler before the termination item has actually + been delivered. + + If an asynchronous read is cancelled ProcessItemArray is called + before the Cancel call completes. If the array is empty in this + case it is impossible to determine from the ProcessItemArray + callback whether it occurred because of successful cancellation + or because a termination item has been delivered to the + application. + */ + virtual void ProcessItemArray(RChannelItemArray* deliveredArray) = 0; + + /* there are two types of asynchronous handler: queued and + immediate. If the handler is queued, then it is called on a + worker thread which is specifically assigned for processing + items, and the application should perform as much computation + as necessary on the calling thread before returning.
+ + If the handler is immediate then it may be called on a variety + of different threads, including I/O processing threads, and the + application should return as soon as possible. Immediate + handlers are suitable, for example, for waking up a synchronous + thread which has been blocking on a read, or for queueing a + work item to an application's private thread pool if the + standard thread pool is for some reason inappropriate. + + In general rather than implementing the ImmediateDispatch + method the application should just inherit from + RChannelItemReaderHandlerImmediate or + RChannelItemReaderHandlerQueued. + */ + virtual bool ImmediateDispatch() = 0; + +private: + UInt32 m_maximumArraySize; +}; + +/* the base class for a queued item array handler, see + RChannelItemArrayReaderHandler::ImmediateDispatch above. */ +class RChannelItemArrayReaderHandlerQueued : + public RChannelItemArrayReaderHandler +{ +public: + bool ImmediateDispatch(); +}; + +/* the base class for an immediate item array handler, see + RChannelItemArrayReaderHandler::ImmediateDispatch above. */ +class RChannelItemArrayReaderHandlerImmediate : + public RChannelItemArrayReaderHandler +{ +public: + bool ImmediateDispatch(); +}; + +class RChannelItemReaderHandler : + public RChannelItemArrayReaderHandler +{ +public: + virtual ~RChannelItemReaderHandler(); + + /* When an asynchronous item is ready to be delivered on a channel + this method is called on the handler object passed to the + associated call to RChannelReader::SupplyHandler(). + + If deliveredItem is NULL and the asynchronous read was not + cancelled, then the channel has completed, i.e. a termination + item has been delivered to some synchronous or asynchronous + read method. If there are multiple outstanding reads the order + of delivery is undefined so a NULL may appear on one handler + before the termination item has actually been delivered.
+ + If an asynchronous read is cancelled ProcessItem is called with + NULL before the Cancel call completes. In this case it is + impossible to determine from the ProcessItem callback whether + it occurred because of successful cancellation or because a + termination item has been delivered to the application. + + If deliveredItem is non-NULL then the callee now owns a + reference to deliveredItem. + */ + virtual void ProcessItem(RChannelItem* deliveredItem) = 0; + + /* there are two types of asynchronous handler: queued and + immediate. If the handler is queued, then it is called on a + worker thread which is specifically assigned for processing + items, and the application should perform as much computation + as necessary on the calling thread before returning. + + If the handler is immediate then it may be called on a variety + of different threads, including I/O processing threads, and the + application should return as soon as possible. Immediate + handlers are suitable, for example, for waking up a synchronous + thread which has been blocking on a read, or for queueing a + work item to an application's private thread pool if the + standard thread pool is for some reason inappropriate. + + In general rather than implementing the ImmediateDispatch + method the application should just inherit from + RChannelItemReaderHandlerImmediate or + RChannelItemReaderHandlerQueued. + */ + virtual bool ImmediateDispatch() = 0; + + /* implementation of RChannelItemArrayReaderHandler interface */ + void ProcessItemArray(RChannelItemArray* deliveredArray); +}; + +/* the base class for a queued item handler, see + RChannelItemReaderHandler::ImmediateDispatch above. */ +class RChannelItemReaderHandlerQueued : public RChannelItemReaderHandler +{ +public: + bool ImmediateDispatch(); +}; + +/* the base class for an immediate item handler, see + RChannelItemReaderHandler::ImmediateDispatch above.
*/ +class RChannelItemReaderHandlerImmediate : public RChannelItemReaderHandler +{ +public: + bool ImmediateDispatch(); +}; + +class SyncItemReaderBase +{ +public: + virtual DrError ReadItemSync(RChannelItemRef* item /* out */) = 0; +}; + +class RChannelReader : public SyncItemReaderBase +{ +public: + virtual ~RChannelReader(); + + /* a client must call Start to cause the channel to start + generating items for the first time, or after a Drain has + completed before the channel is restarted. + + prefetchCookie is passed to the byte-oriented buffer layer and + should be NULL for now. + */ + virtual void Start(RChannelBufferPrefetchInfo* prefetchCookie) = 0; + + /* SupplyHandler passes a handler which will be "returned" via a + matching call to handler->ProcessItemArray. Multiple handlers + may be outstanding at a given time. The channel blocks when + there are no handlers available, so this is the primary flow + control mechanism to allow the reader to exert back-pressure on + a channel. + + If the channel has not been started or has delivered a + termination item to the application since the last call to + Start, handler->ProcessItemArray will be called with an empty + array on the calling thread before SupplyHandler returns. + + If there is an item waiting to be delivered and handler is an + RChannelItemReaderHandlerImmediate then handler will be called + back with the items on the calling thread before SupplyHandler + returns. + + If handler is an RChannelItemReaderHandlerQueued it will never + be called back on the calling thread. If there are items + waiting to be delivered, a processing request will be queued + for the item with handler, otherwise handler will be queued + waiting for the next item to be ready. + + If thread B submits an RChannelItemReaderHandlerImmediate with + a call to SupplyHandler that is overlapped with thread A's call + to SupplyHandler, B's immediate handler may be called on A's + calling thread before A's call to SupplyHandler returns. 
+ + After a handler has been submitted to SupplyHandler with a + given value of cancelCookie, that handler is guaranteed to be + returned before any matching call to Cancel with the same value + of cancelCookie returns. If there is no item available at the + time the handler is cancelled it will be called with an empty + array. cancelCookie may take any value, including NULL, however + it is safest to use NULL or the address of an object owned by + the caller. The cancellation mechanism is used internally by + the synchronous FetchNextItem call which uses an allocated heap + address as its cancelCookie. Using the heap address of an + object which has not been freed before the call to Cancel will + avoid any danger of a cookie collision. + + For any item A which has already been returned via an async + handler or call to FetchNextItem when SupplyHandler is called + it is guaranteed that A's sequence number is less than or equal + to the sequence number of the item eventually returned on + handler. Beyond this constraint, if multiple handlers are + outstanding at once, or while calls to FetchNextItem are in + progress, the order of delivered items is undefined. + */ + virtual void SupplyHandler(RChannelItemArrayReaderHandler* handler, + void* cancelCookie) = 0; + + /* any handler passed to SupplyHandler with value cancelCookie is + guaranteed to be returned before a call to Cancel completes. */ + virtual void Cancel(void* cancelCookie) = 0; + + /* FetchNextItemArray blocks waiting until an item is available, a + termination item is delivered on another thread, or the timeOut + interval has elapsed (timeOut can be DrTimeInterval_Infinite in + which case FetchNextItemArray will block indefinitely). + + If the timeout expires FetchNextItemArray returns false and + *pItemArray is an empty array. Otherwise FetchNextItemArray returns + true and any returned items are stored in + *pItemArray.
*pItemArray may be empty even if + FetchNextItemArray returns true: this will happen if Start has + not been called or a termination item has already been + delivered on the channel since the last call to Start. At most + maxArraySize items will be delivered in *pItemArray, however + fewer may be delivered. + */ + virtual bool FetchNextItemArray(UInt32 maxArraySize, + RChannelItemArrayRef* pItemArray, + DrTimeInterval timeOut) = 0; + + /* FetchNextItem blocks waiting until an item is available, a + termination item is delivered on another thread, or the timeOut + interval has elapsed (timeOut can be DrTimeInterval_Infinite in + which case FetchNextItem will block indefinitely). + + If the timeout expires FetchNextItem returns false and *outItem + is NULL. Otherwise FetchNextItem returns true and the returned + item is stored in *outItem. *outItem may be NULL even if + FetchNextItem returns true: this will happen if Start has not + been called or a termination item has already been delivered on + the channel since the last call to Start. + */ + bool FetchNextItem(RChannelItemRef* pOutItem, + DrTimeInterval timeOut); + DrError ReadItemSync(RChannelItemRef* pOutItem /* out */); + + /* Instruct the channel to stop reading items and prepare for + Drain. The interruptItem must be of type RChannelItem_Abort, + RChannelItem_Restart, RChannelItem_EndOfStream or + RChannelItem_ParseError. If the channel is an in-process FIFO + or a pipe which has not broken, the interruptItem will be + delivered to the remote end and the remote writer will + therefore be made aware of the channel close request. Interrupt + may be called multiple times after a channel has been Started + and before it has been Drained. The first call will have its + interruptItem delivered; subsequent interruptItems will be + discarded. If multiple calls to Interrupt are overlapped, the + call whose interruptItem gets delivered is undefined. 
The + caller's reference to interruptItem is not modified by this + call. + */ + virtual void Interrupt(RChannelItem* interruptItem) = 0; + + /* Instruct the channel to return all outstanding handlers via + their handler->ProcessItem callbacks in preparation for either + closing or restarting the channel. Once the Drain method + returns all outstanding handler callbacks will have completed + and all waiting calls to FetchNextItem will have been unblocked + (though of course they may not have returned to the calling + thread). For obvious reasons, Drain may not be called from a + handler's ProcessItem callback. + + If the channel is an in-process FIFO or a pipe which has not + broken, and Interrupt has not previously been called, Drain + will send an appropriate termination item to the write end. If + the writer has delivered a termination item to the reader + implementation, an item of the same type will be returned to + the writer. Otherwise an item of type RChannelItem_Abort will be + returned to the writer. If the RChannelReader client has not + read a termination item, then the item delivered by a call to + Drain which is not preceded by Interrupt is undefined, since + the implementation may have received a termination item which + has not yet been forwarded to the client. + */ + virtual void Drain() = 0; + + /* If the writer has sent a termination item down the channel, it + is returned in pWriterDrainItem. If the reader has sent a + termination item to the writer (via Interrupt or Drain) it is + returned in pReaderDrainItem. These termination items are set + to NULL when a channel is restarted. The returned items are + undefined if this call overlaps with a call to Start or Drain, + and it is illegal to call GetTerminationItems after Close has + been called.
*/ + virtual void GetTerminationItems(RChannelItemRef* pWriterDrainItem, + RChannelItemRef* pReaderDrainItem) = 0; + + /* After a call to Drain, this returns a status code and + optionally error metadata corresponding to the final status of + the channel. If the writer sent an error item, + GetTerminationStatus returns the error associated with that + item. Otherwise if the reader sent an Interrupt with an + EndOfStream item or called Drain after receiving an EndOfStream + item from the writer, GetTerminationStatus will return + DrError_EndOfStream (signalling clean termination). Otherwise + GetTerminationStatus returns DryadError_ProcessingInterrupted + which signifies that there was no error reported on the + channel, but the reader requested early termination. */ + DrError GetTerminationStatus(DryadMetaDataRef* pErrorData); + + /* Close may only be called if Start has never been called, or if + Drain has completed since the last call to Start. Close must be + called before the RChannelReader is destroyed. After Close has + been called no other methods may be called on RChannelReader. + */ + virtual void Close() = 0; + + /* Get the total length in bytes of the channel. This is known + only after the channel has been opened. */ + virtual bool GetTotalLength(UInt64* pLen) = 0; + + /* Get the total expected length in bytes of the channel. This may + be known even before the channel is opened. If + GetExpectedLength returns false, *pLen is undefined. */ + virtual bool GetExpectedLength(UInt64* pLen) = 0; + virtual void SetExpectedLength(UInt64 expectedLength) = 0; + + /* Get the URI of the channel. 
*/ + virtual const char* GetURI() = 0; + + void SetTransformType(TransformType t) { m_transform = t; } + TransformType GetTransformType() { return m_transform; } + +private: + TransformType m_transform; +}; + + +class RChannelItemArrayWriterHandler +{ +public: + virtual ~RChannelItemArrayWriterHandler(); + + /* this interface is called whenever an async write (started by + RChannelWriter::WriteItemArray) completes. If WriteItemArray was + called with the flush flag set, this handler will not be + returned until the underlying channel has reported that the + write has completed. Otherwise this handler may be called + before the item has been written to the channel. + + Normally status is RChannelItem_Data, however if the underlying + channel implementation has requested a restart or abort, status + will be RChannelItem_Restart or RChannelItem_Abort + respectively. If WriteItem was called before the channel was + started, or after a termination item had been written, then + status will be RChannelItem_EndOfStream. If status is not + RChannelItem_Data the item may or may not have actually been + written to the output, regardless of whether a flush was + requested (if status is RChannelItem_EndOfStream the item has + definitely not been written to the output). + + If the item marshaler encountered any errors the error + descriptions will be returned in failureArray. This is an array + of items: each element is non-NULL if and only if the + corresponding item in the original list had a marshal error. If + there were no marshaling errors, failureArray will be NULL. + */ + virtual void ProcessWriteArrayCompleted(RChannelItemType status, + RChannelItemArray* + failureArray) = 0; +}; + +class RChannelItemWriterHandler : + public RChannelItemArrayWriterHandler +{ +public: + virtual ~RChannelItemWriterHandler(); + + /* this interface is called whenever an async write (started by + RChannelWriter::WriteItem) completes.
If WriteItem was called + with the flush flag set, this handler will not be returned + until the underlying channel has reported that the write has + completed. Otherwise this handler may be called before the item + has been written to the channel. + + Normally status is RChannelItem_Data, however if the underlying + channel implementation has requested a restart or abort, status + will be RChannelItem_Restart or RChannelItem_Abort + respectively. If WriteItem was called before the channel was + started, or after a termination item had been written, then + status will be RChannelItem_EndOfStream. If status is not + RChannelItem_Data the item may or may not have actually been + written to the output, regardless of whether a flush was + requested (if status is RChannelItem_EndOfStream the item has + definitely not been written to the output). + + If the item marshaler encountered an error the error + description will be returned in marshalFailureItem. The callee + does not own a reference to marshalFailureItem. + */ + virtual void ProcessWriteCompleted(RChannelItemType status, + RChannelItem* marshalFailureItem) = 0; + + /* implementation of RChannelItemArrayWriterHandler interface */ + void ProcessWriteArrayCompleted(RChannelItemType status, + RChannelItemArray* failureArray); +}; + +class SyncItemWriterBase +{ +public: + /* A client should call WriteItemSync or + WriteItemSyncConsumingReference whenever it wants to send + an item to the output channel. The latter consumes a reference + to item, the former does not. */ + virtual void WriteItemSyncConsumingReference(RChannelItemRef& item) = 0; + void WriteItemSyncConsumingFreeReference(RChannelItem* item); + void WriteItemSync(RChannelItem* item); + virtual DrError GetWriterStatus() = 0; +}; + +/* + The RChannelWriter is the primary mechanism for writing + application-specific structured items to an underlying byte-oriented + channel.
+ + The application first calls Start, then sends a series of items, + primarily of type RChannelItem_Data, though markers may be + interspersed, followed by a termination item of type + RChannelItem_Restart, RChannelItem_Abort or + RChannelItem_EndOfStream. In the case of a pipe which has not + broken, the termination item will be sent to the remote end so that + the consuming process learns the reason for the pipe closure. + + After sending a termination item no further items will be sent on + the channel until a call to Drain has completed. After this point + the channel may call Start again in the case of a restart, and start + sending items again. + + Each item is marshaled into the byte-stream using an + application-specific marshaler object which must co-operate with the + application so that bare RChannelItem objects can be cast by it into + objects containing meaningful data. + */ +class RChannelWriter : public SyncItemWriterBase +{ +public: + virtual ~RChannelWriter(); + + /* a client must call Start before writing items to the channel + the first time, or after a Drain has completed before writing + to a restarted channel. + */ + virtual void Start() = 0; + + /* WriteItem queues item for async write to the channel. + handler->ProcessWriteCompleted will be called when the async + write "completes." + + If flushAfter is false the handler may be returned before the + item has actually been written on the underlying channel. If + flushAfter is true the handler will not be returned until the + underlying channel has reported that the item has been + written. Setting flushAfter potentially causes the underlying + channel to write a partial buffer and should be used sparingly + where performance is important. + + WriteItem transfers the caller's reference to item to the + channel and it is guaranteed that handler will be called before + the next call to Drain completes. 
If Start has not yet been + called or a termination item has been queued since the last + call to Start, handler->ProcessWriteCompleted will be called + with status RChannelItem_EndOfStream on the caller's stack + before WriteItem returns. + + Flow control is exercised by having the channel refrain from + calling handler->ProcessWriteCompleted if the underlying + stream has blocked. Therefore the application should limit the + total number of outstanding calls to WriteItem in flight whose + handlers have not yet been returned. Async items are serialized + to a queue and marshaled on a worker thread, and handler is + returned after this marshaling completes, so in the common case + where the underlying channel has not blocked, the application + should allow enough outstanding handlers to account for this + delay in order to avoid "bubbles" in the processing pipeline. + + Once a call to WriteItem passing item A has completed, then any + items which are subsequently submitted to the channel are + guaranteed to appear after A. Beyond this constraint, if calls + to WriteItem and/or WriteItemSync are overlapped the order of + serialization to the channel is undefined. + */ + virtual void WriteItemArray(RChannelItemArrayRef& itemArray, + bool flushAfter, + RChannelItemArrayWriterHandler* handler) = 0; + void WriteItem(RChannelItem* item, + bool flushAfter, + RChannelItemArrayWriterHandler* handler); + + /* WriteItemArraySync submits item to be written to the channel and + may block indefinitely if the underlying channel has blocked. + + If flush is false, WriteItemArraySync may return before the array + has actually been sent to the underlying channel. If flush is + true WriteItemArraySync will not return until the underlying + channel reports that item has been written. Setting flush + potentially causes the underlying channel to write a partial + buffer and should be used sparingly where efficiency is + crucial.
+ + If the marshaler encounters an error when marshaling any item + in the list, and pFailureArray is non-NULL, then pFailureArray + will be filled in with an array of the same size as the + original itemArray. This array will contain non-NULL elements + corresponding to any item in itemArray which had a marshal + error. If pFailureArray is non-NULL and there are no marshaling + errors, *pFailureArray=NULL on return. + + WriteItemArraySync always consumes the caller's reference to + each item in the list and leaves itemArray empty. If Start has + not yet been called or a termination item has been queued since + the last call to Start WriteItemArraySync returns + RChannelItem_EndOfStream, no item is written to the channel, + and *pFailureArray is guaranteed to be NULL if pFailureArray is + non-NULL. Otherwise, normally status is RChannelItem_Data, + however if the underlying channel implementation has requested + a restart or abort, status will be RChannelItem_Restart or + RChannelItem_Abort respectively. If status is + RChannelItem_Restart or RChannelItem_Abort the item in the list + may or may not have actually been written to the output, + regardless of whether a flush was requested. + + Once a call to WriteItemArraySync passing item A has completed, + then any items which are subsequently submitted to the channel + are guaranteed to appear after A. Beyond this constraint, if + calls to WriteItem and/or WriteItemSync are overlapped the + order of serialization to the channel is undefined. + */ + virtual RChannelItemType + WriteItemArraySync(RChannelItemArrayRef& itemArray, + bool flush, + RChannelItemArrayRef* pFailureArray) = 0; + + /* WriteItemSync submits item to be written to the channel and may + block indefinitely if the underlying channel has blocked. + + If flush is false, WriteItemSync may return before item has + actually been sent to the underlying channel.
If flush is true + WriteItemSync will not return until the underlying channel + reports that item has been written. Setting flush potentially + causes the underlying channel to write a partial buffer and + should be used sparingly where efficiency is crucial. + + If the marshaler encounters an error when marshaling item and + pMarshalFailureItem is non-NULL, *pMarshalFailureItem will + contain an item describing that error. In this case item has + not been marshaled to the channel, however *pMarshalFailureItem + has been. If there is no marshaling error and + pMarshalFailureItem is non-NULL, *pMarshalFailureItem is NULL. + + WriteItemSync does not consume the caller's reference to + item. If Start has not yet been called or a termination item + has been queued since the last call to Start WriteItemSync + returns RChannelItem_EndOfStream, item is not written to the + channel, and *pMarshalFailureItem is guaranteed to be NULL if + pMarshalFailureItem is non-NULL. Otherwise, normally status is + RChannelItem_Data, however if the underlying channel + implementation has requested a restart or abort, status will be + RChannelItem_Restart or RChannelItem_Abort respectively. If + status is RChannelItem_Restart or RChannelItem_Abort the item + may or may not have actually been written to the output, + regardless of whether a flush was requested. + + Once a call to WriteItemSync passing item A has completed, then + any items which are subsequently submitted to the channel are + guaranteed to appear after A. Beyond this constraint, if calls + to WriteItem and/or WriteItemSync are overlapped the order of + serialization to the channel is undefined. 
+ */ + RChannelItemType + WriteItemSync(RChannelItem* item, + bool flush, + RChannelItemRef* pMarshalFailureItem); + void WriteItemSyncConsumingReference(RChannelItemRef& item); + + /* Drain may not be called until after a termination item has been + written to the channel using a call to WriteItemArraySync, + WriteItemSync, WriteItemArray or WriteItem which has + returned. Drain will not return until the underlying channel has + drained and all outstanding async handlers submitted to + WriteItem or WriteItemArray have been returned. At this point + any calls to WriteItemSync or WriteItemArraySync will be + unblocked, though of course they may not have returned to the + caller. + + If the application does not wish to block waiting for Drain it + should wait until the handler submitted with the termination + item has returned (or use a blocking call to WriteItemSync or + WriteItemArraySync). This handler will not be returned until the + underlying channel has drained, and it is always the last + handler to be returned, so Drain will not block after it has + been returned. + + The call to Drain will block waiting for the "remote end" of + the channel to signal that the channel has been shut down. In + the case of a file or stream this happens immediately. In the + case of a pipe or FIFO this happens only when the reader end + calls Interrupt or Drain, and the item sent to the reader's + Interrupt command is delivered by the writer's Drain and stored + in *pRemoteStatus if pRemoteStatus is non-NULL. The item will + be type RChannelItem_Abort, RChannelItem_Restart, + RChannelItem_ParseError or RChannelItem_EndOfStream. The + writer's call to Drain will block for at most a time interval + of timeOut before returning. If the timeout period expires + before the remote end's drain item is received and + pRemoteStatus is non-NULL, *pRemoteStatus will be set to NULL.
+ */ + virtual void Drain(DrTimeInterval timeOut, + RChannelItemRef* pRemoteStatus) = 0; + + /* (*pWriterDrainItem) is set to the termination item written to + the channel, or NULL if no termination item has been written + since the most recent call to Start. *pReaderDrainItem is set + to the remote status item returned by the most recent call to + Drain, or NULL if Drain has not been called since the last call + to Start. The return values are undefined if this call overlaps + with a call to Start or Drain, and it is illegal to call + GetTerminationItems after Close has been called. + */ + virtual void GetTerminationItems(RChannelItemRef* pWriterDrainItem, + RChannelItemRef* pReaderDrainItem) = 0; + + /* This returns an error corresponding to the most recent + termination item written to the channel. It is illegal to call + GetTerminationStatus before sending a termination item. */ + DrError GetTerminationStatus(DryadMetaDataRef* pErrorData); + + DrError GetWriterStatus(); + + /* Close may not be called unless Start has never been called or + Drain has completed since the last call to Start. After Close + is called the channel may not be restarted. */ + virtual void Close() = 0; + + /* Get/set a hint about the total length the channel is expected + to be. Some channel implementations can use this to improve + write performance and decrease disk fragmentation. A value of 0 + (the default) means that the size is unknown. */ + virtual UInt64 GetInitialSizeHint() = 0; + virtual void SetInitialSizeHint(UInt64 hint) = 0; + + /* Get the URI of the channel. 
*/ + virtual const char* GetURI() = 0; + + void SetTransformType(TransformType t) { m_transform = t; } + TransformType GetTransformType() { return m_transform; } + +private: + TransformType m_transform; + +}; diff --git a/DryadVertex/VertexHost/system/channel/include/channelitem.h b/DryadVertex/VertexHost/system/channel/include/channelitem.h new file mode 100644 index 0000000..33ef114 --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/include/channelitem.h @@ -0,0 +1,294 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include +#include +#include +#include + +/* + Structured data elements passed on restartable channels are called + Items. + + Every item is derived from the RChannelItem base class and has a + type, a sequence number and a metadata description. The type is + assigned by the creator of the item, and is discussed below. The + sequence number is only used for items being read from a channel and + is assigned automatically by the channel reader mechanism. The + metadata is initially populated by the item creator and may be + initialized from the underlying data stream on a read channel, but + may be added to or modified subsequently during processing. + + Items fall into two fundamental classes: data items and marker + items. Data items have type RChannelItem_Data and carry + application-specific data. 
A data item which is read from a channel + is guaranteed to have been created by an application-specific parser + object which consumes byte-oriented buffers, and may be cast by the + application into a rich type. The sequence number of a data item is + automatically assigned by the channel reader machinery, and data + items read from a channel have unique, densely assigned, increasing + sequence numbers starting at zero. + + All non-data item types are known as marker items and are used to + represent and serialize information such as error conditions and + end-of-stream information. Errors can include information about + "holes" in the underlying data stream (caused e.g. when a subset of + the input data is unavailable) and parse failures. Marker items may + be produced and consumed by a variety of components, and information + is passed between these components using the metadata. There are + convenience methods below for creating standard error marker + items. The sequence number of a marker item is automatically + assigned by the channel reader machinery, and it is equal to the + sequence number of the next data item (or, equivalently, one greater + than the sequence number of the preceding data item). + + On write, all items are passed to an application-specific marshaling + object which knows how to serialize both marker items and data items + with a particular format into byte-oriented buffers. Not all + channels store or transmit marker items, so it is legal for the + marshaler to do nothing when presented with a marker item, but in + general it is expected that applications will make use of marker + items to convey rich monitoring and debugging information on + internal channels between the vertices of a distributed computation. + + A valid complete channel consists of a sequence of "body" items + followed by a single "termination" item. 
A channel reader will + always produce such a sequence unless interrupted by the application + calling Drain prematurely, and a channel writer should be fed such a + sequence. Termination items have type EndOfStream, Abort, Restart, + ParseError and MarshalError: all other items are body items and may + appear in any sequence in a valid channel. No item may appear in a + channel after a termination item. + + The item constructor is private to allow applications to use private + memory management. Each item supplies a Create method and is + reference-counted: a reference is released by calling DecRef on the + item. +*/ + +enum RChannelItemType { + /* RChannelItem_Data: a data item generated by application + specific code in the parser or the application body. The + application may cast this item to a rich derived type. */ + RChannelItem_Data, + + /* RChannelItem_BufferHole: a marker item supplied by the buffer + reader describing a hole in the underlying data stream. */ + RChannelItem_BufferHole, + + /* RChannelItem_ItemHole: a marker item supplied by the parser + describing a hole in the parsed data stream (usually occurring + because of malformed input data which the parser has skipped + over). */ + RChannelItem_ItemHole, + + /* RChannelItem_EndOfStream: a marker item appearing as the last + item in the stream and signaling clean completion. */ + RChannelItem_EndOfStream, + + /* RChannelItem_Restart: a marker item appearing as the last item + in the stream signaling an error condition and requesting that + the channel be restarted if possible. */ + RChannelItem_Restart, + + /* RChannelItem_Abort: a marker item appearing as the last item in + the stream signaling an unrecoverable error condition. */ + RChannelItem_Abort, + + /* RChannelItem_ParseError: a marker item supplied by the parser + signifying an unrecoverable error. */ + RChannelItem_ParseError, + + /* RChannelItem_MarshalError: a marker item supplied by the marshaler + signifying an unrecoverable error.
*/ + RChannelItem_MarshalError +}; + +class RChannelItem; +typedef DrRef RChannelItemRef; + +class DrResettableMemoryReader : public DrMemoryBufferReader +{ +public: + DrResettableMemoryReader(DrMemoryBuffer *pMemoryBuffer); + void ResetToBufferOffset(Size_t offset); +}; + +class RChannelItem : public DrRefCounter +{ +public: + /* returns true if the item can be used as a channel termination + marker, false otherwise */ + static bool IsTerminationItem(RChannelItemType type); + + /* return the type of the item, set at creation time. */ + RChannelItemType GetType(); + + /* The sequence number of the item in the Channel with respect to + all delivered Data items. This sequence number is generated + locally, and so if a Data item appears after any holes in the + stream it may not correspond to the sequence number for this + item when read a second time, or from another reader. + + If the item type is not RChannelItem_Data this number is the + sequence number of the *next* Data item, if any. Data items are + assigned dense, increasing sequence numbers. + */ + UInt64 GetDataSequenceNumber(); + + /* The sequence number of the item in the Channel with respect to + all delivered items. This sequence number is generated locally, + and so if there are any markers or holes in the stream it may + not correspond to the sequence number for this item when read a + second time, or from another reader. Items are assigned dense, + increasing sequence numbers. + */ + UInt64 GetDeliverySequenceNumber(); + + /* this returns a description of the item, used for debugging and + monitoring purposes. When the item is created the metadata is + NULL. 
+ */ + DryadMetaData* GetMetaData(); + + /* these set the item's sequence numbers and are called by the + channel reader */ + void SetDataSequenceNumber(UInt64 dataSequenceNumber); + void SetDeliverySequenceNumber(UInt64 deliverySequenceNumber); + + /* this replaces the item's current metadata with a new object */ + void ReplaceMetaData(DryadMetaData* metaData); + + /* this does a shallow copy. By default this assert-fails, and + concrete Data items which need to be cloned must implement + an item-specific method. */ + virtual void Clone(RChannelItemRef* pClonedItem); + + virtual UInt64 GetNumberOfSubItems() const; + virtual void TruncateSubItems(UInt64 numberOfSubItems); + virtual UInt64 GetItemSize() const; + + virtual DrError DeSerialize(DrResettableMemoryReader* reader, + Size_t availableSize); + virtual DrError DeSerializePartial(DrResettableMemoryReader* reader, + Size_t availableSize); + virtual DrError Serialize(ChannelMemoryBufferWriter* writer); + + /* if the item has metadata which contains a Prop_Dryad_ErrorCode + then this is returned. 
Otherwise the error depends on the type: + RChannelItem_Data: DrError_OK + RChannelItem_BufferHole: DryadError_BufferHole + RChannelItem_ItemHole: DryadError_ItemHole + RChannelItem_EndOfStream: DrError_EndOfStream + RChannelItem_Restart: DryadError_ChannelRestart + RChannelItem_Abort: DryadError_ChannelAbort + RChannelItem_ParseError: DryadError_ItemParseError + RChannelItem_MarshalError: DryadError_ItemMarshalError + */ + DrError GetErrorFromItem(); + + static const UInt64 s_invalidSequenceNumber = ((UInt64) -1); + + static const UInt32 s_defaultItemBatchSize = 16; + static const UInt32 s_defaultRecordBatchSize = 256; + +protected: + RChannelItem(RChannelItemType type); + virtual ~RChannelItem(); + +private: + RChannelItemType m_type; + UInt64 m_dataSequenceNumber; + UInt64 m_deliverySequenceNumber; + DryadMetaDataRef m_metaData; + DrBListEntry m_listPtr; + friend class DryadBList; +}; + +typedef DryadBList RChannelItemList; + +class RChannelMarkerItem : public RChannelItem +{ +public: + /* used to create a standard marker item. If withMetaData is true, + an empty metadata object is created with the item, otherwise + the item's metadata is left NULL. */ + static RChannelMarkerItem* + Create(RChannelItemType type, bool withMetaData); + + /* this does a shallow copy, which shares the same metadata as the + original (with an increased refcount). */ + virtual void Clone(RChannelItemRef* pClonedItem); + + + /* these convenience methods create marker items with common + metadata fields already filled in */ + + /* fills in the metadata element Prop_Dryad_ErrorCode with the + supplied information */ + static RChannelItem* CreateErrorItem(RChannelItemType itemType, + DrError errorCode); + /* fills in the metadata elements Prop_Dryad_ErrorCode and + Prop_Dryad_ErrorString with the supplied information. 
*/ + static RChannelItem* + CreateErrorItemWithDescription(RChannelItemType itemType, + DrError errorCode, + const char* errorDescription); + +protected: + RChannelMarkerItem(RChannelItemType type); + virtual ~RChannelMarkerItem(); +}; + +class RChannelDataItem : public RChannelItem +{ +public: + virtual UInt64 GetNumberOfSubItems() const; + virtual UInt64 GetItemSize() const; + +protected: + RChannelDataItem(); + virtual ~RChannelDataItem(); +}; + +class RChannelItemArray : public DrRefCounter +{ +public: + RChannelItemArray(); + ~RChannelItemArray(); + + void SetNumberOfItems(UInt32 numberOfItems); + void ExtendNumberOfItems(UInt32 numberOfItems); + UInt32 GetNumberOfItems(); + + void TruncateToSize(UInt32 prefix); + void DiscardPrefix(UInt32 prefix); + + RChannelItemRef* GetItemArray(); + +private: + UInt32 m_numberOfItems; + RChannelItemRef* m_baseItemArray; + RChannelItemRef* m_itemArray; +}; + +typedef DrRef RChannelItemArrayRef; diff --git a/DryadVertex/VertexHost/system/channel/include/channelmarshaler.h b/DryadVertex/VertexHost/system/channel/include/channelmarshaler.h new file mode 100644 index 0000000..5328191 --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/include/channelmarshaler.h @@ -0,0 +1,161 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include +#include +#include +#include + +/* The RChannelItemMarshaler is used to supply application-specific + marshaling code to convert an RChannelItem object to a byte stream. + + It is expected that most marshalers will be stateless, but the + marshaler is called sequentially with each item in turn, so it is + possible to implement marshalers with state if desired. + + If an RChannelItemMarshaler object is passed to only one + RChannelWriter object then calls to its methods will never be + overlapped. Methods are called on worker threads and should compute + for as long as necessary before returning. The worker threads are + used for item parsing, item marshaling, item processing and other + actions. + */ +class RChannelItemMarshalerBase : public IDrRefCounter +{ +public: + RChannelItemMarshalerBase(); + virtual ~RChannelItemMarshalerBase(); + + void SetMaxMarshalBatchSize(UInt32 maxMarshalBatchSize); + UInt32 GetMaxMarshalBatchSize(); + + void SetMarshalerContext(RChannelContext* context); + RChannelContext* GetMarshalerContext(); + + /* this indicates, e.g., which channel on a vertex the marshaler + is attached to. It may not be unique in a given process if + there are e.g. multiple vertices. */ + void SetMarshalerIndex(UInt32 index); + UInt32 GetMarshalerIndex(); + + /* Reset is called before any items are marshaled to a stream. If + a channel is restarted, Reset will be called before starting to + write items again. */ + virtual void Reset(); + + /* MarshalItem is called whenever a new item should be written to + the underlying stream. The item should be marshaled using the + DrMemoryWriter writer in such a way that a matching + RChannelItemParser object can unmarshal it subsequently. + + If the marshaler succeeds in writing the item, it should return + DrError_OK and *pFailureItem will be ignored. It is legal to + write no data for some items or item types if the application + semantics permit this. 
+ + If the marshaler encounters an error when marshaling it should + attempt to recover and return a descriptive item of type + RChannelItem_MarshalError in *pFailureItem rather than, for + example, asserting. In this case, any data written to writer + will be discarded, and self will immediately be called back + with the RChannelItem_MarshalError item. Any failure item + returned by this second call will be ignored. + + If MarshalItem returns DryadError_ChannelAbort or + DryadError_ChannelRestart then the appropriate termination item + will be sent to the channel preventing subsequent items from + being marshaled. Any other error code will cause the channel to + continue marshaling items past the error. + + flush is set if the channel will be flushed after this item is + marshaled. This can be ignored by most marshalers. + */ + virtual DrError MarshalItem(ChannelMemoryBufferWriter* writer, + RChannelItem* item, + bool flush, + RChannelItemRef* pFailureItem) = 0; + +private: + UInt32 m_maxMarshalBatchSize; + UInt32 m_index; + RChannelContextRef m_context; +}; + +typedef DrRef RChannelItemMarshalerRef; + +class RChannelItemMarshaler : public RChannelItemMarshalerBase +{ +public: + ~RChannelItemMarshaler(); + DRREFCOUNTIMPL +}; + +class RChannelStdItemMarshalerBase : public RChannelItemMarshalerBase +{ +public: + virtual ~RChannelStdItemMarshalerBase(); + + DrError MarshalItem(ChannelMemoryBufferWriter* writer, + RChannelItem* item, + bool flush, + RChannelItemRef* pFailureItem); + + virtual DrError MarshalMarker(ChannelMemoryBufferWriter* writer, + RChannelItem* item, + bool flush, + RChannelItemRef* pFailureItem); +}; + +class RChannelStdItemMarshaler : public RChannelStdItemMarshalerBase +{ +public: + virtual ~RChannelStdItemMarshaler(); + DRREFCOUNTIMPL +}; + +class DryadMarshalerFactoryBase : public IDrRefCounter +{ +public: + virtual ~DryadMarshalerFactoryBase(); + virtual void MakeMarshaler(RChannelItemMarshalerRef* pMarshaler, + DVErrorReporter* errorReporter) = 
0; +}; + +typedef DrRef DryadMarshalerFactoryRef; + +class DryadMarshalerFactory : public DryadMarshalerFactoryBase +{ +public: + virtual ~DryadMarshalerFactory(); + DRREFCOUNTIMPL +}; + +template class StdMarshalerFactory : public DryadMarshalerFactory +{ +public: + typedef _T Marshaler; + void MakeMarshaler(RChannelItemMarshalerRef* pMarshaler, + DVErrorReporter* errorReporter) + { + pMarshaler->Attach(new Marshaler()); + } +}; diff --git a/DryadVertex/VertexHost/system/channel/include/channelmemorybuffers.h b/DryadVertex/VertexHost/system/channel/include/channelmemorybuffers.h new file mode 100644 index 0000000..171c53d --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/include/channelmemorybuffers.h @@ -0,0 +1,141 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include + +class RChannelBufferData; +class RChannelBufferWriter; + +class ChannelDataBufferList; + +class RChannelReaderBuffer : public DrMemoryBuffer +{ +public: + RChannelReaderBuffer(DryadLockedBufferList* bufferList, + Size_t startOffset, Size_t endOffset); + RChannelReaderBuffer(ChannelDataBufferList* bufferList, + Size_t startOffset, Size_t endOffset); + ~RChannelReaderBuffer(); + + // + // Retrieve pointer to the data stored in memory block at uOffset + // and max size available in this block + // + // Returns NULL if no data at this offset, valid pointer otherwise + // + void *GetDataAddress(Size_t uOffset, + Size_t *puSize, + Size_t *puPriorSize); + + // + // Preallocate enough memory buffers to fit uSize bytes of data. + // + void IncreaseAllocatedSize(Size_t uSize); + +private: + class Current + { + public: + Current(Size_t initialStartOffset, + Size_t finalEndOffset, + DryadLockedMemoryBuffer** bufferArray, + size_t nBuffers, + Size_t offset); + + void* GetDataAddress(DryadLockedMemoryBuffer** bufferArray, + Size_t offset, + Size_t *puSize, + Size_t *puPriorSize); + + private: + void SetTailBufferData(DryadLockedMemoryBuffer** bufferArray, + size_t nBuffers, + Size_t finalEndOffset); + + size_t m_currentBuffer; + Size_t m_currentBufferBase; + Size_t m_currentHeadCutLength; + Size_t m_currentTailOffset; + }; + + void Initialise(Size_t startOffset, Size_t endOffset); + + DryadLockedMemoryBuffer** m_bufferArray; + size_t m_nBuffers; + Size_t m_initialStartOffset; + Size_t m_finalEndOffset; +}; + +class RChannelWriterBuffer : public DrMemoryBuffer +{ +public: + RChannelWriterBuffer(RChannelBufferWriter* bufferProvider, + DryadFixedBufferList* bufferList); + ~RChannelWriterBuffer(); + + // + // Retrieve pointer to the data stored in memory block at uOffset + // and max size available in this block + // + // Returns NULL if no data at this offset, valid pointer otherwise + // + void *GetDataAddress(Size_t uOffset, + Size_t 
*puSize, + Size_t *puPriorSize); + + // + // Preallocate enough memory buffers to fit uSize bytes of data. + // + void IncreaseAllocatedSize(Size_t uSize); + +protected: + void InternalSetAvailableSize(Size_t uSize); + +private: + DryadFixedBufferList* m_bufferList; + + Size_t m_currentBufferOffset; + Size_t m_currentBaseOffset; + Size_t m_baseBufferOffset; + + Size_t m_availableHighWaterMark; + Size_t m_availableStartOffset; + Size_t m_availableBufferOffset; + DrBListEntry* m_lastAvailableBuffer; + + RChannelBufferWriter* m_bufferProvider; +}; + +class ChannelMemoryBufferWriter : public DrMemoryBufferWriter +{ +public: + ChannelMemoryBufferWriter(DrMemoryBuffer* writer, + DryadFixedBufferList* bufferList); + + bool MarkRecordBoundary(); + Size_t GetLastRecordBoundary(); + +private: + Size_t m_initialBoundary; + Size_t m_lastRecordBoundary; + DryadFixedBufferList* m_bufferList; +}; diff --git a/DryadVertex/VertexHost/system/channel/include/channelparser.h b/DryadVertex/VertexHost/system/channel/include/channelparser.h new file mode 100644 index 0000000..7b518c7 --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/include/channelparser.h @@ -0,0 +1,560 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include "channelbuffer.h" +#include "dobjpool.h" +#include +#include "recordarray.h" + +class RChannelRawItemParser; + +/* The classes in this header file must be overridden to implement an + application-specific parser which can take application-specific + data from a byte-oriented channel and parse it into RChannelItem + objects. + + At the start of the header is the RChannelRawItemParser class which + may be overridden directly in the unlikely event that an + application needs full control over the parsing process. Almost all + applications will be able to derive from the RChannelItemParser or + RChannelLengthDelimitedItemParser classes which follow. + RChannelLengthDelimitedItemParser is somewhat simpler to use in the + case that the length of items in the channel can be determined by + reading a prefix of the item. + + Assuming that a parser object is only passed to a single + RChannelReader object, no method calls on the parser will ever be + overlapped. Methods are called on worker threads and should compute + for as long as necessary before returning. The worker threads are + used for item parsing, item marshaling, item processing and other + actions. +*/ + +/* RChannelRawItemParser is the raw interface for unmarshaling + structured items from a byte-stream, however the convenience + classes below will be more suitable for most applications. */ +class RChannelItemParserBase : public IDrRefCounter +{ +public: + RChannelItemParserBase(); + virtual ~RChannelItemParserBase(); + + void SetMaxParseBatchSize(UInt32 maxParseBatchSize); + UInt32 GetMaxParseBatchSize(); + + void SetParserContext(RChannelContext* context); + RChannelContext* GetParserContext(); + + /* this indicates, e.g., which channel on a vertex the parser is + attached to. It may not be unique in a given process if there + are e.g. multiple vertices. 
*/ + void SetParserIndex(UInt32 index); + UInt32 GetParserIndex(); + + /* The Item Parser takes a sequence of data buffers from the + underlying Channel and returns a sequence of items parsed out + of the data. + + A single method call, RawParseItem, implements the API between + the Channel and the parser. In the common case, when + RawParseItem is called the parser returns the next item from + whatever buffer data it has cached: it must maintain enough + state to be able to keep track of which the next item is. + + If the parser does not have enough cached data to return the + next item (e.g. an item straddles two or more data buffers) it + signals this on return from the method, and the next call to + RawParseItem will supply the next buffer in sequence. + + The buffers are supplied as a sequence of RChannelBuffer_Data + and RChannelBuffer_Hole buffers followed by a single + RChannelBuffer_EndOfStream though there may be a restart at any + time. The returned items are of type RChannelItem_Data, + RChannelItem_ItemHole, RChannelItem_BufferHole, + RChannelItem_EndOfStream or RChannelItem_ParseError and between + restarts they obey the pattern sequence + + (Data | ItemHole | BufferHole)* (EndOfStream | ParseError) + + The detailed method description follows: + + The code calling the parser behaves as though it stores a + parserState variable which can take on the values NeedsData, + DataReady and Stopped; this variable is updated after every + call to RawParseItem. The rules for calling ItemParser are as + follows: + + 1) At any time RawParseItem may be called with + resetParser=true. This signifies that the entire channel has + been restarted. In this case inData contains a pointer to a + buffer and the parser should discard all saved state and + restart parsing from the beginning of inData. 
+ + 2) Otherwise resetParser=false and the call to RawParseItem + depends on the value of parserState: + + 2a) parserState = NeedsData: *inData contains the next buffer + in sequence + 2b) parserState = DataReady: inData=NULL + 2c) parserState = Stopped: RawParseItem will never be called + + RawParseItem must set *outPrefetchCookie before returning. + + If RawParseItem returns NULL, then parserState is set to + NeedsData and *outPrefetchCookie is delivered to the buffer + reader when the current buffer (the most recent to have been + passed to RawParseItem) is returned. For now *outPrefetchCookie + should always be NULL. The return value may not be NULL if the + current buffer is of type RChannelBuffer_EndOfStream. + + Otherwise *outPrefetchCookie must be NULL and the returned item + must be of type RChannelItem_Data, RChannelItem_ItemHole, + RChannelItem_BufferHole, RChannelItem_ParseError or + RChannelItem_EndOfStream. + + The type of the returned item should be as follows: + + RChannelItem_Data: the item is the next item successfully + parsed from the stream. + + RChannelItem_ItemHole: the parser encountered malformed data + but believes it has successfully skipped over it and + resynchronized on the start of the next well-formed item. The + returned item contains metadata describing the error and + e.g. the amount of data skipped. + + RChannelItem_BufferHole: the parser has been given a buffer + of type RChannelBuffer_Hole. After it has returned any + partial data in the preceding buffers as RChannelItem_Data or + RChannelItem_ItemHole items, it will return the + RChannelItem_BufferHole which is stored in the + RChannelBuffer_Hole object. This contains the buffer + provider's explanation for the hole. + + RChannelItem_EndOfStream: all the data in the stream has been + consumed. The returned item may only have type + RChannelItem_EndOfStream if the current buffer has type + RChannelBuffer_EndOfStream. 
+ + RChannelItem_ParseError: the parser encountered malformed + data, is unable to recover, and will not be able to return + any more items from the stream. The returned item contains + metadata describing the error. + + If the returned item has type RChannelItem_Data, + RChannelItem_BufferHole or RChannelItem_ItemHole then + parserState is set to DataReady before the next call to + RawParseItem, otherwise parserState is set to Stopped. + */ + virtual RChannelItem* RawParseItem(bool restartParser, + RChannelBuffer* inData, + RChannelBufferPrefetchInfo** + outPrefetchCookie /* out */) = 0; + +private: + UInt32 m_maxParseBatchSize; + UInt32 m_index; + RChannelContextRef m_context; +}; + +typedef DrRef RChannelItemParserRef; + +class RChannelRawItemParser : public RChannelItemParserBase +{ +public: + virtual ~RChannelRawItemParser(); + DRREFCOUNTIMPL +}; + +class RChannelItemTransformerBase : public IDrRefCounter +{ +public: + virtual ~RChannelItemTransformerBase(); + + /* this is called once before the first use of the other + transformer methods. The default implementation does + nothing. Initialize can retrieve the parser context, etc. from + the base class. */ + virtual void InitializeTransformer(RChannelItemParserBase* parent, + DVErrorReporter* errorReporter); + + /* this is called once for each item in the input. It should write + zero or more transformed items to writer. The default + implementation writes each input directly back to writer. */ + virtual void TransformItem(RChannelItemRef& inputItem, + SyncItemWriterBase* writer, + DVErrorReporter* errorReporter); + + /* this is called when the input is exhausted. In some cases, + input may be broken up into fragments (e.g. a cosmos stream + with missing extents). In this case, Flush will be called at + the end of each fragment, and will be followed by another + sequence of Map calls to process the next fragment, followed by + another Flush, until the end of the input is reached. 
The + default implementation does nothing. */ + virtual void FlushTransformer(SyncItemWriterBase* writer, + DVErrorReporter* errorReporter); + + /* this is called before Flush if the input encounters an error, + with the item describing the error (e.g. a corrupt extent in a + cosmos stream). The default implementation does nothing. */ + virtual void ReportTransformerErrorItem(RChannelItemRef& errorItem, + DVErrorReporter* errorReporter); +}; + +typedef DrRef RChannelItemTransformerRef; + +class RChannelItemTransformer : public RChannelItemTransformerBase +{ +public: + virtual ~RChannelItemTransformer(); + + DRREFCOUNTIMPL +}; + + +/* this record contains a single buffer. */ +class RChannelBufferRecord +{ +public: + void SetData(DrMemoryBuffer* data); + DrMemoryBuffer* GetData() const; + +private: + DrRef m_buffer; +}; + +/* this item contains a single buffer, however it is also a record + array that contains a single record holding the same buffer. That + way it can be used in both item-reading and record-reading + interfaces. 
*/ +class RChannelBufferItem : public PackedRecordArray +{ +public: + static RChannelBufferItem* Create(DrMemoryBuffer* buffer); + DrMemoryBuffer* GetData() const; + virtual UInt64 GetItemSize() const; + +private: + RChannelBufferItem(DrMemoryBuffer* buffer); +}; + +class RChannelTransformerParserBase : + public RChannelItemParserBase, + public DVErrorReporter, + public SyncItemWriterBase +{ +public: + RChannelTransformerParserBase(); + virtual ~RChannelTransformerParserBase(); + + void SetTransformer(RChannelItemTransformerBase* transformer); + RChannelItemTransformerBase* GetTransformer(); + + /* the implementation of the base parser interface */ + RChannelItem* RawParseItem(bool restartParser, + RChannelBuffer* inData, + RChannelBufferPrefetchInfo** + outPrefetchCookie /* out */); + + /* the implementation of the item writer interface */ + void WriteItemSyncConsumingReference(RChannelItemRef& item); + DrError GetWriterStatus(); + +private: + RChannelItemTransformerRef m_transformer; + bool m_transformedAny; + RChannelItemList m_itemList; +}; + +class RChannelTransformerParser : public RChannelTransformerParserBase +{ + DRREFCOUNTIMPL +}; + +class RChannelItemParserNoRefImpl : public RChannelItemParserBase +{ +public: + RChannelItemParserNoRefImpl(); + virtual ~RChannelItemParserNoRefImpl(); + + /* ResetParser is called whenever the underlying channel restarts, + and before the first buffer is passed to the parser. On + receiving a call to ResetParser the parser should discard all + internal state and return to the start-of-stream ready + condition. */ + virtual void ResetParser(); + + /* ParseNextItem is called whenever another item is needed, and + the item should be parsed from the buffers in bufferList, + starting at startOffset in the first buffer in the list. The + list will never be empty. It may be convenient to wrap the + bufferList in a DrMemoryBuffer using the RChannelReaderBuffer + class in channelmemorybuffers.h, e.g. 
+ wrapper = new RChannelReaderBuffer(bufferList, + startOffset, + bufferList->back()->GetData()-> + GetAvailableSize()); + + If there is not enough data to parse the next item, + ParseNextItem should return NULL. On the next call the parser + is guaranteed to receive a longer list of buffers starting from + the same point in the stream (with bufferList as a prefix) + unless ResetParser or ParsePartialItem has been called, so it + can cache information about partially parsed items if desired. + + If an item is successfully parsed from availableData it should + be returned with type RChannelItem_Data, and the number of + bytes consumed should be stored in *pOutLength which must be <= + the total available size of the buffers in bufferList - + startOffset. The parser is allowed to behave as if the + application now "owns" this prefix of the available data, + e.g. it can increase the refcounts of buffers in bufferList, + store them inside the returned item, and modify the data within + this prefix. + + If a parse error occurs but the parser can resynchronize at the + start of the next item, ParseNextItem should return an item of + type RChannelItem_ItemHole describing the error and set + *pOutLength to the number of bytes to be skipped before the + start of the next item. + + If a parse error occurs and the parser cannot determine the + start of the next item, ParseNextItem should return an item of + type RChannelItem_ParseError describing the error. In this case + ResetParser will be called before the next call to + ParseNextItem or ParsePartialItem. + */ + virtual RChannelItem* ParseNextItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + Size_t* pOutLength) = 0; + + /* ParsePartialItem is called whenever a buffer of type + RChannelBuffer_Hole or RChannelBuffer_EndOfStream arrives on + the channel and there is unparsed data remaining before this + marker buffer. The remaining data is passed as bufferList and + the buffer is passed as markerBuffer. 
bufferList may be empty. + + If there is no useful data remaining in bufferList, + ParsePartialItem should return NULL. If the data in bufferList + can be interpreted as a valid item, this item should be + returned as an item of type RChannelItem_Data. If the data is + malformed, ParsePartialItem should return an item of type + RChannelItem_ItemHole describing the malformed data. + + After a call to ParsePartialItem, the parser should probably + not be holding on to any state. If the marker was a buffer + hole, the next call to ParseNextItem will be called with + bufferList starting after the hole in the stream. If the + marker was an end of stream buffer, ResetParser will be called + before the next call to ParseNextItem or ParsePartialItem. + */ + virtual RChannelItem* ParsePartialItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + RChannelBufferMarker* + markerBuffer); + + + + /* this implements the RChannelRawItemParser interface */ + RChannelItem* RawParseItem(bool restartParser, + RChannelBuffer* inData, + RChannelBufferPrefetchInfo** outPrefetchCookie); + +protected: + void ResetParserInternal(); + void DiscardBufferPrefix(Size_t discardLength); + RChannelItem* DealWithPartialBuffer(RChannelBufferMarker* mBuffer); + +private: + bool m_needsReset; + bool m_needsData; + RChannelItemRef m_savedItem; + + ChannelDataBufferList m_bufferList; + + Size_t m_bufferListStartOffset; +}; + +class RChannelItemParser : public RChannelItemParserNoRefImpl +{ +public: + virtual ~RChannelItemParser(); + DRREFCOUNTIMPL +}; + + +class RChannelStdItemParserNoRefImpl : public RChannelItemParserNoRefImpl +{ +public: + RChannelStdItemParserNoRefImpl(DObjFactoryBase* factory); + virtual ~RChannelStdItemParserNoRefImpl(); + + void ResetParser(); + + RChannelItem* ParseNextItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + Size_t* pOutLength); + RChannelItem* ParsePartialItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + RChannelBufferMarker* + 
markerBuffer); + +private: + DObjFactoryBase* m_factory; + RChannelItemRef m_pendingErrorItem; +}; + +class RChannelStdItemParser : public RChannelStdItemParserNoRefImpl +{ +public: + RChannelStdItemParser(DObjFactoryBase* factory); + virtual ~RChannelStdItemParser(); + DRREFCOUNTIMPL +}; + + +/* many items are stored with an explicit length available near the + start of the item. The RChannelLengthDelimitedItemParser convenience + class is the simplest way to write a parser to deal with such + items. +*/ +class RChannelLengthDelimitedItemParserNoRefImpl : + public RChannelItemParserNoRefImpl +{ +public: + enum LengthStatus { + LS_ParseError, + LS_NeedsData, + LS_Ok + }; + + RChannelLengthDelimitedItemParserNoRefImpl(); + virtual ~RChannelLengthDelimitedItemParserNoRefImpl(); + + /* GetNextItemLength should attempt to read the length of the next + item from availableData, assuming the next item starts at the + beginning of availableData. availableData is not growable. + + If buffer does not contain enough data to read out the next + item length, GetNextItemLength should return LS_NeedsData and + *pErrorItem should be set to NULL. + + If the prefix of buffer is malformed (does not contain a valid + prefix of an item), GetNextItemLength should return + LS_ParseError and *pErrorItem should contain an item of type + RChannelItem_ParseError describing the error. + + Otherwise, GetNextItemLength should read the item length from + buffer, store it in *pOutLength, return LS_Ok and set + *pErrorItem to NULL. + */ + virtual LengthStatus GetNextItemLength(DrMemoryBuffer* availableData, + Size_t* pOutLength, + RChannelItem** pErrorItem) = 0; + + /* ParseItemWithLength is called after a successful call to + GetNextItemLength, and after cached data of the required length + has been accumulated from the channel. itemData is not growable + and has its AllocatedSize and AvailableSize set to itemLength, + which is the length returned by the previous call to + GetNextItemLength. 
+ + ParseItemWithLength should parse itemData into an RChannelItem + object of type RChannelItem_Data and return it. The parser is + allowed to behave as if the application now "owns" itemData, + e.g. it can increase itemData's refcount and store it inside + the returned item, or modify the data within itemData. + + If a parse error occurs, ParseItemWithLength should return an + item of type RChannelItem_ItemHole describing the error. + + The RChannelLengthDelimitedItemParser class will automatically + add entries to the returned item's metadata indicating what + range of the underlying channel the item was generated from. + */ + virtual RChannelItem* ParseItemWithLength(DrMemoryBuffer* itemData, + Size_t itemLength) = 0; + + /* these implement the RChannelItemParser interface */ + void ResetParser(); + RChannelItem* ParseNextItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + Size_t* pOutLength); + RChannelItem* ParsePartialItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + RChannelBufferMarker* + markerBuffer); + +private: + void AddMetaData(RChannelItem* item, + ChannelDataBufferList* bufferList, + Size_t startOffset, + Size_t endOffset); + RChannelItem* FetchItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + Size_t tailBufferSize); + RChannelItem* MaybeFetchItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + Size_t* pOutLength); + + Size_t m_itemLength; + Size_t m_accumulatedLength; +}; + +class RChannelLengthDelimitedItemParser : + public RChannelLengthDelimitedItemParserNoRefImpl +{ +public: + virtual ~RChannelLengthDelimitedItemParser(); + DRREFCOUNTIMPL +}; + +class DryadParserFactoryBase : public IDrRefCounter +{ +public: + virtual ~DryadParserFactoryBase(); + virtual void MakeParser(RChannelItemParserRef* pParser, + DVErrorReporter* errorReporter) = 0; +}; + + +typedef DrRef DryadParserFactoryRef; + +class DryadParserFactory : public DryadParserFactoryBase +{ +public: + virtual ~DryadParserFactory(); + 
DRREFCOUNTIMPL +}; + +template class StdParserFactory : public DryadParserFactory +{ +public: + typedef _T Parser; + void MakeParser(RChannelItemParserRef* pParser, + DVErrorReporter* errorReporter) + { + pParser->Attach(new Parser()); + } +}; + diff --git a/DryadVertex/VertexHost/system/channel/include/concreterchannel.h b/DryadVertex/VertexHost/system/channel/include/concreterchannel.h new file mode 100644 index 0000000..09aba52 --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/include/concreterchannel.h @@ -0,0 +1,114 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include +#include +#include +#include + +class WorkQueue; + +class RChannelOpenThrottler; + +class ConcreteRChannel +{ +public: + static bool IsNTFSFile(const char* uri); +//JC static bool IsXComputeFile(const char* uri); +//JC static bool IsDryadStream(const char* uri); +//JC static bool IsDryadPipe(const char* uri); + static bool IsDscStream(const char* uri); + static bool IsDscPartition(const char* uri); + static bool IsHdfsFile(const char* uri); + static bool IsHdfsPartition(const char* uri); + static bool IsAzureBlob(const char* uri); + static bool IsUncPath(const char* uri); + static bool IsFifo(const char* uri); + static bool IsNull(const char* uri); + static bool IsNamedPipe(const char* uri); + static bool IsTidyFSStream(const char* uri); + static const UInt32 s_infiniteFifoBuffer = (UInt32) -1; +}; + +class RChannelReaderHolder : public DrRefCounter +{ +public: + virtual ~RChannelReaderHolder(); + + virtual RChannelReader* GetReader() = 0; + virtual void FillInStatus(DryadInputChannelDescription* status) = 0; + virtual void Close() = 0; +}; + +typedef DrRef RChannelReaderHolderRef; + +class RChannelWriterHolder : public DrRefCounter +{ +public: + virtual ~RChannelWriterHolder(); + + virtual RChannelWriter* GetWriter() = 0; + virtual void FillInStatus(DryadOutputChannelDescription* status) = 0; + virtual void Close() = 0; +}; + +typedef DrRef RChannelWriterHolderRef; + +class RChannelFactory +{ +public: + static DrError OpenReader(const char* channelURI, + DryadMetaData* metaData, + RChannelItemParserBase* parser, + UInt32 numberOfReaders, + RChannelOpenThrottler* openThrottler, + UInt32 maxParseBatchSize, + UInt32 maxParseUnitsInFlight, + WorkQueue* workQueue, + DVErrorReporter* errorReporter, + RChannelReaderHolderRef* pHolder /* out */, + LPDWORD localInputChannels); + + static DrError OpenWriter(const char* channelURI, + DryadMetaData* metaData, + RChannelItemMarshalerBase* marshaler, + UInt32 numberOfWriters, + 
RChannelOpenThrottler* openThrottler, + UInt32 maxMarshalBatchSize, + WorkQueue* workQueue, + DVErrorReporter* errorReporter, + RChannelWriterHolderRef* pHolder /* out */); + + /* this returns an ID that is guaranteed different to any other + return value from this call in this process, and can be used + when minting internal fifo names to ensure that two vertices + don't accidentally use the same name */ + static UInt32 GetUniqueFifoId(); + + /* this returns a throttler object that can be used to set the + maximum number of concurrently open readers and writers from a + set. When the throttler is no longer needed it should be passed + back to DiscardOpenThrottler. */ + static RChannelOpenThrottler* MakeOpenThrottler(UInt32 maxOpens, + WorkQueue* workQueue); + static void DiscardOpenThrottler(RChannelOpenThrottler* throttler); +}; diff --git a/DryadVertex/VertexHost/system/channel/include/recordarray.h b/DryadVertex/VertexHost/system/channel/include/recordarray.h new file mode 100644 index 0000000..f518f96 --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/include/recordarray.h @@ -0,0 +1,485 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include "channelparser.h" +#include "channelmarshaler.h" +#include "channelinterface.h" +#include + +class RecordArrayBase : public RChannelDataItem +{ +public: + RecordArrayBase(); + virtual ~RecordArrayBase(); + + /* from the RChannelItem interface */ + UInt64 GetNumberOfSubItems() const; + void TruncateSubItems(UInt64 numberOfSubItems); + virtual UInt64 GetItemSize() const; + + virtual size_t GetRecordSize() const = 0; + + UInt32 GetNumberOfRecords() const; + virtual void SetNumberOfRecords(UInt32 numberOfRecords); + + void* GetRecordArrayUntyped(); + void* GetRecordUntyped(UInt32 index); + void* NextRecordUntyped(); + + UInt32 GetRecordIndex() const; + void SetRecordIndex(UInt32 index); + + bool AtEnd(); + void Truncate(); + void ResetRecordPointer(); + void PopRecord(); + + virtual DrError DeSerialize(DrResettableMemoryReader* reader, + Size_t availableSize); + virtual DrError Serialize(ChannelMemoryBufferWriter* writer); + + void TransferRecord(void* dstRecord, void* srcRecord); + virtual void TransferRecord(RecordArrayBase* dstArray, void* dstRecord, + RecordArrayBase* srcArray, void* srcRecord) = 0; + + void StartSerializing(); + +protected: + void FreeStorage(); + void* TransferTruncatedArray(void* srcArray, UInt32 numberOfRecords); + virtual UInt32 ReadArray(DrMemoryBuffer* buffer, + Size_t startOffset, UInt32 nRecords); + virtual UInt32 AttachArray(DryadLockedMemoryBuffer* buffer, + Size_t startOffset); + virtual DrError ReadFinalArray(DrMemoryBuffer* buffer); + + virtual void* MakeTypedArray(UInt32 numberOfRecords) = 0; + virtual void FreeTypedArray(void* untypedArray) = 0; + + UInt32 m_numberOfRecords; + UInt32 m_nextRecord; + UInt32 m_recordArraySize; + void* m_recordArray; + DrRef m_buffer; + bool m_serializedAny; + + friend class PackedRecordArrayParserBase; + friend class AlternativeRecordMarshalerBase; +}; + +template< class _R > class PackedRecordArray : public RecordArrayBase +{ +public: + typedef _R RecordType; + + 
virtual ~PackedRecordArray() + { + /* call this from here rather than the base class destructor + since it needs to call FreeTypedArray() which is + inaccessible from the base destructor */ + FreeStorage(); + } + + size_t GetRecordSize() const + { + return sizeof(RecordType); + } + + RecordType* GetRecordArray() const + { + return (RecordType *) m_recordArray; + } + + RecordType* NextRecord() + { + return (RecordType *) NextRecordUntyped(); + } + + virtual void TransferRecord(RecordArrayBase* dstArray, void* dst, + RecordArrayBase* srcArray, void* src) + { + RecordType* dstRecord = (RecordType *) dst; + RecordType* srcRecord = (RecordType *) src; + *dstRecord = *srcRecord; + } + +private: + void* MakeTypedArray(UInt32 numberOfRecords) + { + return new RecordType[numberOfRecords]; + } + + void FreeTypedArray(void* untypedArray) + { + RecordType* array = (RecordType *) untypedArray; + delete [] array; + } +}; + +template< class _R > class RecordArray : public PackedRecordArray<_R> +{ +public: + virtual DrError DeSerialize(DrResettableMemoryReader* reader, + Size_t availableSize) + { + if (GetNumberOfRecords() == 0) + { + SetNumberOfRecords(RChannelItem::s_defaultRecordBatchSize); + } + + Size_t sizeUsed = 0; + Size_t remainingSize = availableSize; + DrError err = DrError_OK; + + ResetRecordPointer(); + RecordType* nextRecord; + while (remainingSize > 0 && (nextRecord = NextRecord()) != NULL) + { + err = nextRecord->DeSerialize(reader, remainingSize, false); + if (err != DrError_OK) + { + PopRecord(); + reader->ResetToBufferOffset(sizeUsed); + break; + } + + sizeUsed = reader->GetBufferOffset(); + LogAssert(sizeUsed <= availableSize); + remainingSize = availableSize - sizeUsed; + } + + Truncate(); + ResetRecordPointer(); + + if (err == DrError_EndOfStream && AtEnd() == false) + { + /* AtEnd() == false after Truncate+Reset means there's at + least one item */ + err = DrError_OK; + } + + return err; + } + + virtual DrError DeSerializePartial(DrResettableMemoryReader* 
reader, + Size_t availableSize) + { + SetNumberOfRecords(1); + + ResetRecordPointer(); + RecordType* nextRecord = NextRecord(); + LogAssert(nextRecord != NULL); + + DrError err = nextRecord->DeSerialize(reader, availableSize, true); + + ResetRecordPointer(); + + return err; + } + + virtual DrError Serialize(ChannelMemoryBufferWriter* writer) + { + StartSerializing(); + + DrError err = DrError_OK; + Size_t currentPosition = 0; + + RecordType* nextRecord; + bool filledBuffer = writer->MarkRecordBoundary(); + LogAssert(filledBuffer == false); + while (err == DrError_OK && + filledBuffer == false && + (nextRecord = NextRecord()) != NULL) + { + currentPosition = writer->GetBufferOffset(); + err = nextRecord->Serialize(writer); + if (err == DrError_OK) + { + filledBuffer = writer->MarkRecordBoundary(); + } + } + + if (err != DrError_OK) + { + LogAssert(writer->GetStatus() == DrError_OK); + writer->SetBufferOffset(currentPosition); + return DryadError_ItemMarshalError; + } + + if (filledBuffer) + { + return DrError_IncompleteOperation; + } + + return DrError_OK; + } + + virtual void TransferRecord(RecordArrayBase* dstArray, void* dst, + RecordArrayBase* srcArray, void* src) + { + RecordType* dstRecord = (RecordType *) dst; + RecordType* srcRecord = (RecordType *) src; + dstRecord->TransferFrom(*srcRecord); + } + +private: + virtual UInt32 ReadArray(DrMemoryBuffer* buffer, + Size_t startOffset, UInt32 nRecords) + { + LogAssert(false); + return 0; + } + + virtual UInt32 AttachArray(DryadLockedMemoryBuffer* buffer, + Size_t startOffset) + { + LogAssert(false); + return 0; + } +}; + +template< class _R > class RecordArrayEx : public RecordArray<_R> +{ +public: + virtual DrError DeSerialize(DrResettableMemoryReader* reader, + Size_t availableSize) + { + + if (GetNumberOfRecords() == 0) + { + SetNumberOfRecords(RChannelItem::s_defaultRecordBatchSize); + } + + Size_t sizeUsed = 0; + Size_t remainingSize = availableSize; + DrError err = DrError_OK; + + ResetRecordPointer(); + 
RecordType* nextRecord; + while (remainingSize > 0 && (nextRecord = NextRecord()) != NULL) + { + err = nextRecord->DeSerialize(reader, remainingSize, false); + if (err != DrError_OK) + { + PopRecord(); + reader->ResetToBufferOffset(sizeUsed); + break; + } + + sizeUsed = reader->GetBufferOffset(); + LogAssert(sizeUsed <= availableSize); + remainingSize = availableSize - sizeUsed; + } + + Truncate(); + ResetRecordPointer(); + + if (err == DrError_EndOfStream && AtEnd() == false) + { + /* AtEnd() == false after Truncate+Reset means there's at + least one item */ + err = DrError_OK; + } + + return err; + } + + virtual DrError DeSerializePartial(DrResettableMemoryReader* reader, + Size_t availableSize) + { + SetNumberOfRecords(1); + + ResetRecordPointer(); + RecordType* nextRecord = NextRecord(); + LogAssert(nextRecord != NULL); + + DrError err = nextRecord->DeSerialize(reader, availableSize, true); + + ResetRecordPointer(); + + return err; + } +}; + +template< class _A > class RecordArrayFactory : public DObjFactoryBase +{ +public: + typedef _A ArrayType; + + RecordArrayFactory(UInt32 arraySize) + { + m_arraySize = arraySize; + } + +private: + void* AllocateObjectUntyped() + { + ArrayType* newArray = new ArrayType(); + newArray->SetNumberOfRecords(m_arraySize); + return newArray; + } + + void FreeObjectUntyped(void* object) + { + DrRef typedObject; + typedObject.Attach((ArrayType *) object); + /* let typedObject go out of scope, freeing the reference */ + } + + UInt32 m_arraySize; + + DRREFCOUNTIMPL +}; + +class RecordArrayReaderBase +{ +public: + RecordArrayReaderBase(); + RecordArrayReaderBase(SyncItemReaderBase* reader); + virtual ~RecordArrayReaderBase(); + + void Initialize(SyncItemReaderBase* reader); + + bool Advance(); + UInt32 AdvanceBlock(UInt32 validEntriesRequested); + UInt32 GetValidCount() const; + void PushBack(); + + DrError GetStatus(); + RChannelItem* GetTerminationItem(); + +protected: + void PushBack(bool pushValid); + void DiscardCachedItems(); + 
virtual bool AdvanceInternal(UInt32 slotNumber); + + void** m_currentRecord; + SyncItemReaderBase* m_reader; + RChannelItemRef m_item; + RecordArrayBase* m_arrayItem; + UInt32 m_cacheSize; + RChannelItemRef* m_itemCache; + UInt32 m_valid; + UInt32 m_cachedItemCount; +}; + +template< class _R > class RecordArrayReader : + public RecordArrayReaderBase +{ +public: + typedef _R RecordType; + + RecordArrayReader() {} + RecordArrayReader(SyncItemReaderBase* reader) : + RecordArrayReaderBase(reader) + { + } + + RecordType* operator->() const + { + LogAssert(GetValidCount() > 0); + return (RecordType *) m_currentRecord[0]; + } + + RecordType& operator*() const + { + LogAssert(GetValidCount() > 0); + return *((RecordType *) m_currentRecord[0]); + } + + RecordType& operator[](UInt32 index) const + { + LogAssert(index < GetValidCount()); + return *((RecordType *) m_currentRecord[index]); + } +}; + +class RecordArrayWriterBase +{ +public: + RecordArrayWriterBase(); + RecordArrayWriterBase(SyncItemWriterBase* writer, + DObjFactoryBase* factory); + virtual ~RecordArrayWriterBase(); + + void Initialize(SyncItemWriterBase* writer, DObjFactoryBase* factory); + void SetWriter(SyncItemWriterBase* writer); + + void MakeValid(); + void MakeValidBlock(UInt32 validEntriesRequested); + UInt32 GetValidCount() const; + void PushBack(); + + void* ReadValidUntyped(UInt32 index); + + void Terminate(); + void Flush(); + + DrError GetWriterStatus(); + +protected: + void** m_currentRecord; + +private: + void Destroy(); + void SendCachedItems(); + void AdvanceInternal(UInt32 slotNumber); + + DrRef m_factory; + SyncItemWriterBase* m_writer; + DrRef m_item; + UInt32 m_cacheSize; + DrRef* m_itemCache; + UInt32 m_valid; + UInt32 m_pushBackIndex; + UInt32 m_cachedItemCount; +}; + +template< class _R > class RecordArrayWriter : public RecordArrayWriterBase +{ +public: + typedef _R RecordType; + + RecordArrayWriter() {} + RecordArrayWriter(SyncItemWriterBase* writer, DObjFactoryBase* factory) : + 
RecordArrayWriterBase(writer, factory) + { + } + + RecordType* operator->() const + { + LogAssert(GetValidCount() > 0); + return (RecordType *) m_currentRecord[0]; + } + + RecordType& operator*() const + { + LogAssert(GetValidCount() > 0); + return *((RecordType *) m_currentRecord[0]); + } + + RecordType& operator[](UInt32 index) const + { + LogAssert(index < GetValidCount()); + return *((RecordType *) m_currentRecord[index]); + } +}; diff --git a/DryadVertex/VertexHost/system/channel/include/recorditem.h b/DryadVertex/VertexHost/system/channel/include/recorditem.h new file mode 100644 index 0000000..e454cfc --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/include/recorditem.h @@ -0,0 +1,239 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include "recordparser.h" + +class RecordBundleInterfaceBase +{ +public: + RecordBundleInterfaceBase() + { + m_marshalerFactory.Attach(new MarshalerFactory()); + } + + virtual ~RecordBundleInterfaceBase() {} + + DryadParserFactoryBase* GetParserFactory() + { + return m_parserFactory; + } + + DryadMarshalerFactoryBase* GetMarshalerFactory() + { + return m_marshalerFactory; + } + + DObjFactoryBase* GetRecordFactory() + { + return m_factory; + } + +protected: + class MarshalerFactory : public DryadMarshalerFactory + { + public: + void MakeMarshaler(RChannelItemMarshalerRef* pMarshaler, + DVErrorReporter* errorReporter) + { + pMarshaler->Attach(new RChannelStdItemMarshaler()); + } + }; + + DrRef m_factory; + DryadParserFactoryRef m_parserFactory; + DryadMarshalerFactoryRef m_marshalerFactory; +}; + +template< class _R > class PackedRecordBundleBase : + public RecordBundleInterfaceBase +{ +public: + typedef _R RecordType; + + typedef PackedRecordArray Array; + typedef RecordArrayReader Reader; + typedef RecordArrayWriter WriterBase; + + class Writer : public WriterBase + { + public: + Writer() {} + Writer(PackedRecordBundleBase<_R>* bundle, SyncItemWriterBase* writer) + { + Initialize(bundle, writer); + } + + void Initialize(PackedRecordBundleBase<_R>* bundle, + SyncItemWriterBase* writer) + { + bundle->InitializeWriter(this, writer); + } + }; + + PackedRecordBundleBase() {} + PackedRecordBundleBase(DObjFactoryBase* factory) + { + InitializeBase(factory); + } + virtual ~PackedRecordBundleBase() {} + + void InitializeBase(DObjFactoryBase* factory) + { + m_factory = factory; + m_parserFactory.Attach(new ParserFactory(m_factory)); + } + + void InitializeWriter(WriterBase* writer, + SyncItemWriterBase* channelWriter) + { + writer->Initialize(channelWriter, m_factory); + } + +private: + class ParserFactory : public DryadParserFactory + { + public: + ParserFactory(DObjFactoryBase* factory) + { + m_factory = factory; + } + + void 
MakeParser(RChannelItemParserRef* pParser, + DVErrorReporter* errorReporter) + { + pParser->Attach(new PackedRecordArrayParser(m_factory)); + } + + DObjFactoryRef m_factory; + }; +}; + +template< class _R > class PackedRecordBundle : + public PackedRecordBundleBase<_R> +{ +public: + typedef RecordArrayFactory Factory; + + PackedRecordBundle() + { + Initialize(RChannelItem::s_defaultRecordBatchSize); + } + + PackedRecordBundle(UInt32 maxArraySize) + { + Initialize(maxArraySize); + } + + void Initialize(UInt32 maxArraySize) + { + DrRef factory; + factory.Attach(new Factory(maxArraySize)); + InitializeBase(factory); + } +}; + +template< class _R > class RecordBundleBase : + public RecordBundleInterfaceBase +{ +public: + typedef _R RecordType; + + typedef RecordArray Array; + typedef RecordArrayReader Reader; + typedef RecordArrayWriter WriterBase; + + class Writer : public WriterBase + { + public: + Writer() {} + Writer(RecordBundleBase<_R>* bundle, SyncItemWriterBase* writer) + { + Initialize(bundle, writer); + } + + void Initialize(RecordBundleBase<_R>* bundle, + SyncItemWriterBase* writer) + { + bundle->InitializeWriter(this, writer); + } + }; + + RecordBundleBase() {} + RecordBundleBase(DObjFactoryBase* factory) + { + InitializeBase(factory); + } + virtual ~RecordBundleBase() {} + + void InitializeBase(DObjFactoryBase* factory) + { + m_factory = factory; + m_parserFactory.Attach(new ParserFactory(m_factory)); + } + + void InitializeWriter(WriterBase* writer, + SyncItemWriterBase* channelWriter) + { + writer->Initialize(channelWriter, m_factory); + } + +private: + class ParserFactory : public DryadParserFactory + { + public: + ParserFactory(DObjFactoryBase* factory) + { + m_factory = factory; + } + + void MakeParser(RChannelItemParserRef* pParser, + DVErrorReporter* errorReporter) + { + pParser->Attach(new RChannelStdItemParser(m_factory)); + } + + DObjFactoryRef m_factory; + }; +}; + +template< class _R > class RecordBundle : public RecordBundleBase<_R> +{ 
+public: + typedef RecordArrayFactory Factory; + + RecordBundle() + { + Initialize(RChannelItem::s_defaultRecordBatchSize); + } + + RecordBundle(UInt32 maxArraySize) + { + Initialize(maxArraySize); + } + + void Initialize(UInt32 maxArraySize) + { + DrRef factory; + factory.Attach(new Factory(maxArraySize)); + InitializeBase(factory); + } +}; diff --git a/DryadVertex/VertexHost/system/channel/include/recordparser.h b/DryadVertex/VertexHost/system/channel/include/recordparser.h new file mode 100644 index 0000000..e3b3e63 --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/include/recordparser.h @@ -0,0 +1,201 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include "channelparser.h" + +class PackedRecordArrayParserBase : public RChannelItemParserNoRefImpl +{ +public: + PackedRecordArrayParserBase(DObjFactoryBase* factory); + virtual ~PackedRecordArrayParserBase(); + + RChannelItem* ParseNextItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + Size_t* pOutLength); + RChannelItem* ParsePartialItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + RChannelBufferMarker* markerBuffer); + +private: + DrRef m_factory; +}; + +class PackedRecordArrayParser : public PackedRecordArrayParserBase +{ +public: + PackedRecordArrayParser(DObjFactoryBase* factory); + virtual ~PackedRecordArrayParser(); + DRREFCOUNTIMPL +}; + +typedef DrError RecordDeSerializerFunction(void* record, + DrMemoryBufferReader* reader, + Size_t availableSize, + bool lastRecordInStream); + +class AlternativeRecordParserBase : public RChannelItemParserNoRefImpl +{ +public: + AlternativeRecordParserBase(DObjFactoryBase* factory); + AlternativeRecordParserBase(DObjFactoryBase* factory, + RecordDeSerializerFunction* function); + virtual ~AlternativeRecordParserBase(); + + void ResetParser(); + + RChannelItem* ParseNextItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + Size_t* pOutLength); + RChannelItem* ParsePartialItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + RChannelBufferMarker* + markerBuffer); + +private: + DrError DeSerializeArray(RecordArrayBase* array, + DrResettableMemoryReader* reader, + Size_t availableSize); + virtual DrError DeSerializeUntyped(void* record, + DrMemoryBufferReader* reader, + Size_t availableSize, + bool lastRecordInStream); + + DObjFactoryBase* m_factory; + RChannelItemRef m_pendingErrorItem; + RecordDeSerializerFunction* m_function; +}; + +typedef DrRef AlternativeRecordParserRef; + +class UntypedAlternativeRecordParser : public AlternativeRecordParserBase +{ +public: + UntypedAlternativeRecordParser(DObjFactoryBase* factory, + RecordDeSerializerFunction* 
function); + + DRREFCOUNTIMPL +}; + +class StdAlternativeRecordParserFactory : public DryadParserFactoryBase +{ +public: + StdAlternativeRecordParserFactory(DObjFactoryBase* factory, + RecordDeSerializerFunction* function); + + void MakeParser(RChannelItemParserRef* pParser, + DVErrorReporter* errorReporter); + +private: + DObjFactoryBase* m_factory; + RecordDeSerializerFunction* m_function; + + DRREFCOUNTIMPL +}; + +template< class _R > class AlternativeRecordParser : + public AlternativeRecordParserBase +{ +public: + typedef _R RecordType; + + AlternativeRecordParser(DObjFactoryBase* factory); + + virtual DrError DeSerialize(RecordType* record, + DrMemoryBufferReader* reader, + Size_t availableSize, + bool lastRecordInStream) = 0; + +private: + DrError DeSerializeUntyped(void* record, + DrMemoryBufferReader* reader, + Size_t availableSize, + bool lastRecordInStream) + { + return DeSerialize((RecordType *) record, reader, + availableSize, lastRecordInStream); + } +}; + + +typedef DrError RecordSerializerFunction(void* record, + DrMemoryBufferWriter* writer); + +class AlternativeRecordMarshalerBase : public RChannelItemMarshalerBase +{ +public: + AlternativeRecordMarshalerBase(); + AlternativeRecordMarshalerBase(RecordSerializerFunction* function); + virtual ~AlternativeRecordMarshalerBase(); + + void SetFunction(RecordSerializerFunction* function); + + DrError MarshalItem(ChannelMemoryBufferWriter* writer, + RChannelItem* item, + bool flush, + RChannelItemRef* pFailureItem); + +private: + virtual DrError SerializeUntyped(void* record, + DrMemoryBufferWriter* writer); + + RecordSerializerFunction* m_function; +}; + +typedef DrRef AlternativeRecordMarshalerRef; + +class UntypedAlternativeRecordMarshaler : public AlternativeRecordMarshalerBase +{ +public: + UntypedAlternativeRecordMarshaler(RecordSerializerFunction* function); + + DRREFCOUNTIMPL +}; + +class StdAlternativeRecordMarshalerFactory : public DryadMarshalerFactoryBase +{ +public: + 
StdAlternativeRecordMarshalerFactory(RecordSerializerFunction* function); + + void MakeMarshaler(RChannelItemMarshalerRef* pMarshaler, + DVErrorReporter* errorReporter); + +private: + RecordSerializerFunction* m_function; + + DRREFCOUNTIMPL +}; + +template< class _R > class AlternativeRecordMarshaler : + public AlternativeRecordMarshalerBase +{ +public: + typedef _R RecordType; + + virtual DrError Serialize(RecordType* record, + DrMemoryBufferWriter* writer) = 0; + +private: + DrError SerializeUntyped(void* record, DrMemoryBufferWriter* writer) + { + return Serialize((RecordType *) record, writer); + } +}; diff --git a/DryadVertex/VertexHost/system/channel/src/channelbuffer.cpp b/DryadVertex/VertexHost/system/channel/src/channelbuffer.cpp new file mode 100644 index 0000000..ff4a185 --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/src/channelbuffer.cpp @@ -0,0 +1,196 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include +#include + +#pragma unmanaged + + +RChannelBuffer::RChannelBuffer(RChannelBufferType type) +{ + m_type = type; + DryadMetaData::Create(&m_metaData); +} + +RChannelBuffer::~RChannelBuffer() +{ +} + +DryadMetaData* RChannelBuffer::GetMetaData() +{ + return m_metaData.Ptr(); +} + +void RChannelBuffer::ReplaceMetaData(DryadMetaData* metaData) +{ + m_metaData.Set(metaData); +} + +bool RChannelBuffer::IsTerminationBuffer(RChannelBufferType type) +{ + return (type == RChannelBuffer_Restart || + type == RChannelBuffer_Abort || + type == RChannelBuffer_EndOfStream); +} + +RChannelBufferType RChannelBuffer::GetType() +{ + return m_type; +} + +RChannelBufferData::RChannelBufferData() : RChannelBuffer(RChannelBuffer_Data) +{ +} + +RChannelBufferData::~RChannelBufferData() +{ +} + +// +// Return structure that contains item +// +RChannelBufferData* ChannelDataBufferList::CastOut(DrBListEntry* item) +{ + return DR_GET_CONTAINER(RChannelBufferData, item, m_dataListPtr); +} + +// +// Return pointer to data +// +DrBListEntry* ChannelDataBufferList::CastIn(RChannelBufferData* item) +{ + return &(item->m_dataListPtr); +} + +RChannelBufferDataDefault:: + RChannelBufferDataDefault(DryadLockedMemoryBuffer* dataBuffer, + UInt64 startOffset, + RChannelBufferDefaultHandler* parent) +{ + m_dataBuffer = dataBuffer; + m_startOffset = startOffset; + m_parent = parent; +} + +RChannelBufferDataDefault::~RChannelBufferDataDefault() +{ + m_dataBuffer->DecRef(); +} + +RChannelBufferDataDefault* + RChannelBufferDataDefault::Create(DryadLockedMemoryBuffer* dataBuffer, + UInt64 startOffset, + RChannelBufferDefaultHandler* parent) +{ + return new RChannelBufferDataDefault(dataBuffer, startOffset, + parent); +} + +DryadLockedMemoryBuffer* RChannelBufferDataDefault::GetData() +{ + return m_dataBuffer; +} + +void RChannelBufferDataDefault:: + GetOffsetMetaData(bool isStart, + UInt64 offset, + DryadMetaDataRef* dstMetaData) +{ + DryadMetaDataRef m; + GetMetaData()->Clone(&m); + + 
DryadMTagRef tag; + tag.Attach(DryadMTagUInt64::Create((isStart) ? + Prop_Dryad_ItemBufferStartOffset : + Prop_Dryad_ItemBufferEndOffset, + offset)); + bool brc = m->Append(tag, false); + LogAssert(brc == true); + + UInt64 streamOffset = + (m_startOffset == RCHANNEL_BUFFER_OFFSET_UNDEFINED) ? + RCHANNEL_BUFFER_OFFSET_UNDEFINED : offset + m_startOffset; + + tag.Attach(DryadMTagUInt64::Create((isStart) ? + Prop_Dryad_ItemStreamStartOffset : + Prop_Dryad_ItemStreamEndOffset, + streamOffset)); + brc = m->Append(tag, false); + LogAssert(brc == true); + + *dstMetaData = m; +} + +// +// Mark a buffer complete +// +void RChannelBufferDataDefault::ProcessingComplete(RChannelBufferPrefetchInfo* + /* unused prefetchCookie*/) +{ + m_parent->ReturnBuffer(this); +} + +RChannelBufferMarker::RChannelBufferMarker(RChannelBufferType type, + RChannelItem* item) : + RChannelBuffer(type) +{ + LogAssert(item != NULL); + m_item = item; +} + +RChannelBufferMarker::~RChannelBufferMarker() +{ + m_item->DecRef(); +} + +RChannelItem* RChannelBufferMarker::GetItem() +{ + LogAssert(m_item != NULL); + return m_item; +} + +RChannelBufferMarkerDefault:: + RChannelBufferMarkerDefault(RChannelBufferType type, + RChannelItem* item, + RChannelBufferDefaultHandler* parent) : + RChannelBufferMarker(type, item) +{ + m_parent = parent; +} + +RChannelBufferMarkerDefault::~RChannelBufferMarkerDefault() +{ +} + +RChannelBufferMarkerDefault* + RChannelBufferMarkerDefault::Create(RChannelBufferType type, + RChannelItem* item, + RChannelBufferDefaultHandler* parent) +{ + return new RChannelBufferMarkerDefault(type, item, parent); +} + +void RChannelBufferMarkerDefault:: + ProcessingComplete(RChannelBufferPrefetchInfo* + /* unused prefetchCookie*/) +{ + m_parent->ReturnBuffer(this); +} diff --git a/DryadVertex/VertexHost/system/channel/src/channelbufferhdfs.cpp b/DryadVertex/VertexHost/system/channel/src/channelbufferhdfs.cpp new file mode 100644 index 0000000..d2d326e --- /dev/null +++ 
b/DryadVertex/VertexHost/system/channel/src/channelbufferhdfs.cpp @@ -0,0 +1,1325 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "channelbufferhdfs.h" + +#include + +#pragma unmanaged + +const char* RChannelBufferHdfsReader::s_hdfsPartitionPrefix = "hpchdfspt://"; +const char* RChannelBufferHdfsWriter::s_hdfsFilePrefix = "hpchdfs://"; + +static long s_readBufferSize = 2 * 1024 * 1024; +static long s_writeBufferSize = 256 * 1024; +static LONG s_maxBuffersOut = 4; +static UInt32 s_maxBuffersToBlockWriter = 4; + +static bool +ExtractHdfsReadUri(DrStr64 uri, + DrStr64& headNode, Int32* pHdfsPort, + DrStr64& filePath, + Int64* pOffsetStart, Int32* pLength) +{ + if (!uri.StartsWith(RChannelBufferHdfsReader::s_hdfsPartitionPrefix)) + { + return false; + } + + char* cHeadNode = + uri.GetWritableBuffer(uri.GetLength(), + strlen(RChannelBufferHdfsReader:: + s_hdfsPartitionPrefix)); + + char* colon = strchr(cHeadNode, ':'); + if (colon == NULL) + { + return false; + } + *colon = '\0'; + + char* cPort = colon+1; + + char* slash = strchr(cPort, '/'); + if (slash == NULL) + { + return false; + } + *slash = '\0'; + + char* cPath = slash+1; + + char* quest = strchr(cPath, '?'); + if (quest == NULL) + { + return false; + } + *quest = '\0'; + + char* cOffset = quest+1; + + quest = strchr(cOffset, '?'); + if (quest == NULL) + { + return false; + } + *quest 
= '\0'; + + char* cSize = quest+1; + + headNode.Set(cHeadNode); + + DrError dre = DrStringToInt32(cPort, pHdfsPort); + if (dre != DrError_OK) + { + return false; + } + + filePath.Set(cPath); + + dre = DrStringToInt64(cOffset, pOffsetStart); + if (dre != DrError_OK) + { + return false; + } + + dre = DrStringToInt32(cSize, pLength); + if (dre != DrError_OK) + { + return false; + } + + return true; +} + +RChannelBufferHdfsReader::RChannelBufferHdfsReader(const char* uri) +{ + m_uri.Set(uri); + m_handler = NULL; + m_readThread = INVALID_HANDLE_VALUE; + m_abortHandle = INVALID_HANDLE_VALUE; + m_blockSemaphore = CreateSemaphore(NULL, + s_maxBuffersOut, + s_maxBuffersOut, + NULL); + + /* it's important to initialize here since these are called + sequentially so there's no race on the crappy hdfs + initialization code. Then we need to connect to the server, + since that has to happen first on the thread that creates the + jvm, in order for login to hdfs to work, for unexplained + reasons. */ + bool initialized = HdfsBridgeNative::Initialize(); + if (initialized) + { + DrStr64 headNode; + Int32 hdfsPort; + DrStr64 filePath; + Int64 offsetStart; + Int32 length; + + bool parsed = ExtractHdfsReadUri(m_uri, + headNode, &hdfsPort, + filePath, &offsetStart, &length); + if (parsed) + { + HdfsBridgeNative::Instance* bridge; + bool openedInstance = + HdfsBridgeNative::OpenInstance(headNode.GetString(), + hdfsPort, + &bridge); + if (openedInstance) + { + HdfsBridgeNative::InstanceAccessor ia(bridge); + ia.Dispose(); + } + } + } + +} + +RChannelBufferHdfsReader::~RChannelBufferHdfsReader() +{ + if (m_readThread != INVALID_HANDLE_VALUE) + { + CloseHandle(m_readThread); + } + if (m_abortHandle != INVALID_HANDLE_VALUE) + { + CloseHandle(m_abortHandle); + } +} + +void RChannelBufferHdfsReader:: +Start(RChannelBufferPrefetchInfo* /*unused prefetchCookie*/, + RChannelBufferReaderHandler* handler) +{ + LogAssert(m_handler == NULL); + m_handler = handler; + + LogAssert(m_readThread == 
INVALID_HANDLE_VALUE); + LogAssert(m_abortHandle == INVALID_HANDLE_VALUE); + + { + AutoCriticalSection acs(&m_cs); + + m_totalLength = 0; + m_processedLength = 0; + } + + m_abortHandle = ::CreateEvent(NULL, TRUE, FALSE, NULL); + LogAssert(m_abortHandle != NULL); + + m_readThread = + (HANDLE) ::_beginthreadex(NULL, + 0, + RChannelBufferHdfsReader::ThreadFunc, + this, + 0, + NULL); + LogAssert(m_readThread != 0); +} + +void RChannelBufferHdfsReader::Interrupt() +{ + /* tell the read thread to stop reading */ + BOOL bRet = ::SetEvent(m_abortHandle); + LogAssert(bRet != 0); + + /* then wait for it to exit */ + DWORD dRet = ::WaitForSingleObject(m_readThread, INFINITE); + LogAssert(dRet == WAIT_OBJECT_0); +} + +void RChannelBufferHdfsReader::Drain(RChannelItem* /* unused drainItem */) +{ + Interrupt(); + + for (LONG i=0; iSetChannelTotalLength(m_totalLength); + s->SetChannelProcessedLength(m_processedLength); +} + +bool RChannelBufferHdfsReader::GetTotalLength(UInt64* pLen) +{ + AutoCriticalSection acs(&m_cs); + + *pLen = m_totalLength; + + return true; +} + +void RChannelBufferHdfsReader::ReturnBuffer(RChannelBuffer* buffer) +{ + /* discard buffer */ + buffer->DecRef(); + + BOOL bRet = ReleaseSemaphore(m_blockSemaphore, 1, NULL); + LogAssert(bRet != 0); +} + +unsigned __stdcall RChannelBufferHdfsReader::ThreadFunc(void* arg) +{ + RChannelBufferHdfsReader* self = (RChannelBufferHdfsReader *) arg; + self->ReadThread(); + return 0; +} + + +static RChannelItem* +MakeErrorItem(DrError errorCode, const char* description) +{ + RChannelItem* item = + RChannelMarkerItem::Create(RChannelItem_Abort, true); + DryadMetaData* metaData = item->GetMetaData(); + metaData->AddErrorWithDescription(errorCode, description); + + return item; +} + +static RChannelBuffer* +MakeErrorBuffer(DrError errorCode, const char* description, + RChannelBufferDefaultHandler* handler) +{ + RChannelItem* item = MakeErrorItem(errorCode, description); + + RChannelBuffer* errorBuffer = + 
RChannelBufferMarkerDefault::Create(RChannelBuffer_Abort, + item, + handler); + DryadMetaData* metaData = errorBuffer->GetMetaData(); + metaData->AddErrorWithDescription(errorCode, description); + + return errorBuffer; +} + + +// +// Make an end-of-stream buffer +// +static RChannelItem* MakeEndOfStreamItem() +{ + return RChannelMarkerItem::Create(RChannelItem_EndOfStream, false); +} + +static RChannelBuffer* + MakeEndOfStreamBuffer(RChannelBufferDefaultHandler* handler) +{ + RChannelItem* item = MakeEndOfStreamItem(); + + RChannelBuffer* buffer = + RChannelBufferMarkerDefault::Create(RChannelBuffer_EndOfStream, + item, + handler); + + return buffer; +} + +// +// Create a data buffer for the reader to use +// +static RChannelBufferData* +MakeDataBuffer(UInt64 streamOffset, size_t blockSize, + RChannelBufferDefaultHandler* handler) +{ + DryadAlignedReadBlock* block = + new DryadAlignedReadBlock(blockSize, 0); + RChannelBufferData* dataBuffer = + RChannelBufferDataDefault::Create(block, + streamOffset, + handler); + + DryadMetaData* metaData = dataBuffer->GetMetaData(); + DryadMTagRef tag; + tag.Attach(DryadMTagUInt64::Create(Prop_Dryad_BufferLength, + block->GetAvailableSize())); + metaData->Append(tag, false); + + return dataBuffer; +} + +void RChannelBufferHdfsReader::SendBuffer(RChannelBuffer* buffer, + bool getSemaphore) +{ + if (getSemaphore) + { + HANDLE h[2]; + h[0] = m_abortHandle; + h[1] = m_blockSemaphore; + + DWORD dRet = WaitForMultipleObjects(2, h, FALSE, INFINITE); + if (dRet == WAIT_OBJECT_0) + { + /* we should discard the buffer and exit since we're + shutting down */ + buffer->DecRef(); + return; + } + else + { + LogAssert(dRet == WAIT_OBJECT_0+1); + } + } + + m_handler->ProcessBuffer(buffer); +} + +Int64 RChannelBufferHdfsReader:: +AdjustStartOffset(HdfsBridgeNative::Reader* reader, + const char* fileName, + Int64 startOffset, + Int64 endOffset) +{ + if (startOffset > 0) + { + /* this isn't the first block in the file, so scan to + the start of the 
next record. If the next record + starts at offsetEnd+1 or later then it will be + picked up by the next block reader, so don't keep + looking past there. */ + RChannelBuffer* errorBuffer = NULL; + Int64 offset = ScanForSync(reader, fileName, + startOffset, endOffset+1, &errorBuffer); + + if (offset < -1) + { + /* there was a read error */ + LogAssert(errorBuffer != NULL); + SendBuffer(errorBuffer, true); + } + else if (offset == -1) + { + /* there was no record starting in the selected range */ + LogAssert(errorBuffer == NULL); + DrLogI("Hdfs skipped block %I64d::%I64d because no record sync was found", + startOffset, endOffset); + errorBuffer = MakeEndOfStreamBuffer(this); + SendBuffer(errorBuffer, true); + } + else + { + LogAssert(errorBuffer == NULL); + LogAssert(offset <= endOffset); + DrLogI("Hdfs skipped from %I64d to start at new record at %I64d", + startOffset, offset); + } + + return offset; + } + else + { + LogAssert(startOffset == 0); + DrLogI("Hdfs starting first block at offset 0"); + return startOffset; + } +} + +Int64 RChannelBufferHdfsReader:: +AdjustEndOffset(HdfsBridgeNative::Reader* reader, + const char* fileName, + Int64 endOffset) +{ + RChannelBuffer* errorBuffer = NULL; + Int64 newOffset = ScanForSync(reader, fileName, + endOffset, -1, + &errorBuffer); + + if (newOffset < -1) + { + /* there was a read error */ + LogAssert(errorBuffer != NULL); + SendBuffer(errorBuffer, true); + } + else + { + LogAssert(newOffset >= endOffset); + DrLogI("HDFS file %s scanned past end of block from %I64d to %I64d", + fileName, + endOffset, newOffset); + } + + return newOffset; +} + +Int64 RChannelBufferHdfsReader:: +ReadDataBuffer(HdfsBridgeNative::ReaderAccessor& ra, + const char* fileName, + Int64 offset, + Int64 endOffset) +{ + Int32 sizeToRead = s_readBufferSize; + Int64 sizeLeft = endOffset - offset; + if (sizeLeft < sizeToRead) + { + sizeToRead = (Int32) sizeLeft; + } + + LogAssert(sizeToRead > 0); + + RChannelBuffer* buffer; + + RChannelBufferData* 
dataBuffer = + MakeDataBuffer((UInt64) offset, (size_t) sizeToRead, this); + + DrMemoryBuffer* block = dataBuffer->GetData(); + Size_t available; + void* dst = block->GetDataAddress(0, &available, NULL); + LogAssert(available >= sizeToRead); + + DrLogI("Reading HDFS file %s range %I64d:%d", + fileName, offset, sizeToRead); + long bytesRead = + ra.ReadBlock(offset, (char *) dst, sizeToRead); + + if (bytesRead < -1) + { + char* errorMsg = ra.GetExceptionMessage(); + DrStr64 description; + description.SetF("Can't read HDFS file '%s' at offset %I64d:%d: %s", + fileName, + offset, sizeToRead, errorMsg); + HdfsBridgeNative::DisposeString(errorMsg); + + buffer = + MakeErrorBuffer(DryadError_ChannelReadError, + description.GetString(), + this); + + dataBuffer->DecRef(); + offset = -1; + } + else if (bytesRead == -1) + { + DrStr64 description; + description.SetF("HDFS file '%s' got EOF at offset %I64d:%d", + fileName, + offset, sizeToRead); + + buffer = + MakeErrorBuffer(DryadError_ChannelReadError, + description.GetString(), + this); + + dataBuffer->DecRef(); + offset = -1; + } + else if (bytesRead != sizeToRead) + { + DrStr64 description; + description.SetF("HDFS file '%s' got too few bytes %d at offset %I64d:%d", + fileName, + bytesRead, + offset, sizeToRead); + + buffer = + MakeErrorBuffer(DryadError_ChannelReadError, + description.GetString(), + this); + + dataBuffer->DecRef(); + offset = -1; + } + else + { + buffer = dataBuffer; + offset += bytesRead; + } + + SendBuffer(buffer, false); + + return offset; +} + +void RChannelBufferHdfsReader::ReadThread() +{ + bool initialized = HdfsBridgeNative::Initialize(); + if (!initialized) + { + DrStr64 description; + description.Set("Can't initialize HDFS bridge"); + RChannelBuffer* error = + MakeErrorBuffer(DryadError_ChannelOpenError, + description.GetString(), + this); + SendBuffer(error, true); + return; + } + + DrStr64 headNode; + Int32 hdfsPort; + DrStr64 filePath; + Int64 offsetStart; + Int32 length; + + bool parsed = 
ExtractHdfsReadUri(m_uri, + headNode, &hdfsPort, + filePath, &offsetStart, &length); + if (!parsed) + { + DrStr64 description; + description.SetF("Can't parse HDFS URI '%s'", m_uri.GetString()); + RChannelBuffer* error = + MakeErrorBuffer(DryadError_InvalidChannelURI, + description.GetString(), + this); + SendBuffer(error, true); + return; + } + + { + AutoCriticalSection acs(&m_cs); + m_totalLength = length; + } + + HdfsBridgeNative::Instance* bridge; + bool openedInstance = + HdfsBridgeNative::OpenInstance(headNode.GetString(), + hdfsPort, + &bridge); + if (!openedInstance) + { + DrStr64 description; + description.SetF("Can't open HDFS Bridge '%s:%d'", + headNode.GetString(), hdfsPort); + RChannelBuffer* error = + MakeErrorBuffer(DryadError_ChannelOpenError, + description.GetString(), + this); + SendBuffer(error, true); + return; + } + + HdfsBridgeNative::InstanceAccessor ia(bridge); + + HdfsBridgeNative::Reader* reader; + bool openedReader = ia.OpenReader(filePath.GetString(), &reader); + if (!openedReader) + { + char* errorMsg = ia.GetExceptionMessage(); + DrStr64 description; + description.SetF("Can't open HDFS file '%s': %s", + filePath.GetString(), errorMsg); + HdfsBridgeNative::DisposeString(errorMsg); + + RChannelBuffer* error = + MakeErrorBuffer(DryadError_ChannelOpenError, + description.GetString(), + this); + SendBuffer(error, true); + + ia.Dispose(); + return; + } + + Int64 offsetEnd = offsetStart + length; + Int64 offset = AdjustStartOffset(reader, filePath.GetString(), + offsetStart, offsetEnd); + bool scannedFinal = false; + + if (offset < 0) + { + /* nothing to read here: AdjustStartOffset already sent + the termination item so we can exit */ + ia.Dispose(); + return; + } + + if (offset == offsetEnd) + { + offsetEnd = AdjustEndOffset(reader, filePath.GetString(), + offsetEnd); + if (offsetEnd < 0) + { + /* there was a read error: AdjustEndOffset already sent + the termination item so we can exit */ + ia.Dispose(); + return; + } + + scannedFinal = 
true; + } + + { + AutoCriticalSection acs(&m_cs); + + offsetStart = offset; + LogAssert(offsetEnd >= offsetStart); + m_totalLength = offsetEnd - offsetStart; + } + + HdfsBridgeNative::ReaderAccessor ra(reader); + + while (offset >=0 && offset < offsetEnd) + { + HANDLE h[2]; + h[0] = m_abortHandle; + h[1] = m_blockSemaphore; + + DWORD dRet = WaitForMultipleObjects(2, h, FALSE, INFINITE); + if (dRet == WAIT_OBJECT_0) + { + /* we should exit */ + offset = -1; + break; + } + else + { + LogAssert(dRet == WAIT_OBJECT_0+1); + } + + /* just check we aren't aborted anyway */ + dRet = WaitForSingleObject(m_abortHandle, 0); + if (dRet == WAIT_OBJECT_0) + { + /* give back the semaphore we just took */ + BOOL bRet = ReleaseSemaphore(m_blockSemaphore, 1, NULL); + LogAssert(bRet != 0); + + /* we should exit */ + offset = -1; + break; + } + else + { + LogAssert(dRet == WAIT_TIMEOUT); + } + + offset = ReadDataBuffer(ra, filePath.GetString(), + offset, offsetEnd); + if (offset >= 0) + { + AutoCriticalSection acs(&m_cs); + + m_processedLength = offset - offsetStart; + } + + if (offset == offsetEnd && !scannedFinal) + { + offsetEnd = AdjustEndOffset(reader, filePath.GetString(), + offsetEnd); + if (offsetEnd < 0) + { + /* there was a read error: AdjustEndOffset already + sent the termination item so we can exit */ + ia.Dispose(); + return; + } + + scannedFinal = true; + } + } /* while (offset >=0 && offset < offsetEnd) */ + + ra.Dispose(); + + if (offset >= 0) + { + RChannelBuffer* buffer = MakeEndOfStreamBuffer(this); + SendBuffer(buffer, true); + } + + ia.Dispose(); +} + +static long s_lineRecordScanSize = 4*1024; + +RChannelBufferHdfsReaderLineRecord:: +RChannelBufferHdfsReaderLineRecord(const char* uri) : + RChannelBufferHdfsReader(uri) +{ +} + +Int64 RChannelBufferHdfsReaderLineRecord:: +ScanForSync(HdfsBridgeNative::Reader* reader, + const char* fileName, + Int64 startOffset, Int64 endOffset, + RChannelBuffer** pErrorBuffer) +{ + *pErrorBuffer = NULL; + + char* scanBuffer = new 
char[s_lineRecordScanSize]; + + /* endOffset is -1 if we're scanning indefinitely, otherwise it + must designate a range that ends after startOffset */ + LogAssert(endOffset != 0); + + Int64 foundOffset = -1; + bool foundReturn = false; + + { + HdfsBridgeNative::ReaderAccessor ra(reader); + + do + { + long bytesToRead = s_lineRecordScanSize; + if (endOffset > 0) + { + /* there's a known endStop: let's not go past it */ + Int64 bytesLeft = endOffset - startOffset; + if (bytesLeft < bytesToRead) + { + bytesToRead = (long) bytesLeft; + } + } + + long bytesRead = ra.ReadBlock(startOffset, scanBuffer, + bytesToRead); + + if (bytesRead < -1) + { + char* errorMsg = ra.GetExceptionMessage(); + DrStr64 description; + description.SetF("Can't read HDFS file '%s' at offset %I64d:%d: %s", + fileName, + startOffset, s_lineRecordScanSize, errorMsg); + HdfsBridgeNative::DisposeString(errorMsg); + + *pErrorBuffer = + MakeErrorBuffer(DryadError_ChannelReadError, + description.GetString(), + this); + + /* break from while loop */ + foundOffset = -2; + } + else if (bytesRead == -1) + { + if (endOffset > 0) + { + /* we were supposed to be able to read as far as + endOffset, but hit EOF early */ + char* errorMsg = ra.GetExceptionMessage(); + DrStr64 description; + description.SetF("Got HDFS EOF early for '%s' at offset %I64d, expecting data up to %I64d: %s", + fileName, + startOffset, endOffset, errorMsg); + HdfsBridgeNative::DisposeString(errorMsg); + + *pErrorBuffer = + MakeErrorBuffer(DryadError_ChannelReadError, + description.GetString(), + this); + + /* break from while loop */ + foundOffset = -2; + } + else + { + /* we were scanning indefinitely and hit EOF, + which just means we found the end of the last + record. 
*/ + /* break from while loop */ + foundOffset = startOffset; + } + } + else + { + LogAssert(bytesRead > 0); + + for (long i=0; i 0 && foundOffset >= endOffset) + { + /* we got to the end of the range we were + scanning without finding a new + record */ + LogAssert(foundOffset == endOffset); + LogAssert(startOffset + bytesRead == endOffset); + foundOffset = -1; + } + break; + } + else if (foundReturn) + { + /* we saw a return character the previous char, so + this is the first character in a new line */ + foundOffset = startOffset + i; + break; + } + else if (scanBuffer[i] == '\r') + { + foundReturn = true; + } + } + + startOffset += bytesRead; + } + } while (foundOffset == -1 && + (endOffset < 0 || startOffset < endOffset)); + + if (endOffset > 0) + { + LogAssert(startOffset <= endOffset); + } + } + + delete [] scanBuffer; + + return foundOffset; +} + + + +static bool +ExtractHdfsWriteUri(DrStr64 uri, + DrStr64& headNode, Int32* pHdfsPort, + DrStr64& filePath) +{ + if (!uri.StartsWith(RChannelBufferHdfsWriter::s_hdfsFilePrefix)) + { + return false; + } + + char* cHeadNode = + uri.GetWritableBuffer(uri.GetLength(), + strlen(RChannelBufferHdfsWriter:: + s_hdfsFilePrefix)); + + char* colon = strchr(cHeadNode, ':'); + if (colon == NULL) + { + return false; + } + *colon = '\0'; + + char* cPort = colon+1; + + char* slash = strchr(cPort, '/'); + if (slash == NULL) + { + return false; + } + *slash = '\0'; + + char* cPath = slash+1; + + headNode.Set(cHeadNode); + + DrError dre = DrStringToInt32(cPort, pHdfsPort); + if (dre != DrError_OK) + { + return false; + } + + filePath.Set(cPath); + + return true; +} + +RChannelBufferHdfsWriter::RChannelBufferHdfsWriter(const char* uri) +{ + m_uri.Set(uri); + m_queueHandle = INVALID_HANDLE_VALUE; + m_writeThread = INVALID_HANDLE_VALUE; + m_queueLength = 0; + + /* it's important to initialize here since these are called + sequentially so there's no race on the crappy hdfs + initialization code. 
Then we need to connect to the server, + since that has to happen first on the thread that creates the + jvm, in order for login to hdfs to work, for unexplained + reasons. */ + bool initialized = HdfsBridgeNative::Initialize(); + if (initialized) + { + DrStr64 headNode; + Int32 hdfsPort; + DrStr64 filePath; + + bool parsed = ExtractHdfsWriteUri(m_uri, + headNode, &hdfsPort, + filePath); + if (parsed) + { + HdfsBridgeNative::Instance* bridge; + bool openedInstance = + HdfsBridgeNative::OpenInstance(headNode.GetString(), + hdfsPort, + &bridge); + if (openedInstance) + { + HdfsBridgeNative::InstanceAccessor ia(bridge); + ia.Dispose(); + } + } + } +} + +DryadFixedMemoryBuffer* RChannelBufferHdfsWriter::GetNextWriteBuffer() +{ + return GetCustomWriteBuffer(s_writeBufferSize); +} + +DryadFixedMemoryBuffer* RChannelBufferHdfsWriter:: +GetCustomWriteBuffer(Size_t bufferSize) +{ + return new DryadAlignedWriteBlock(bufferSize, 0); +} + +void RChannelBufferHdfsWriter::Start() +{ + LogAssert(m_writeThread == INVALID_HANDLE_VALUE); + LogAssert(m_queueHandle == INVALID_HANDLE_VALUE); + LogAssert(m_queue.IsEmpty()); + LogAssert(m_queueLength == 0); + + { + AutoCriticalSection acs(&m_cs); + + m_processedLength = 0; + } + + m_queueHandle = ::CreateEvent(NULL, TRUE, FALSE, NULL); + LogAssert(m_queueHandle != NULL); + + m_writeThread = + (HANDLE) ::_beginthreadex(NULL, + 0, + RChannelBufferHdfsWriter::ThreadFunc, + this, + 0, + NULL); + LogAssert(m_writeThread != 0); +} + +unsigned __stdcall RChannelBufferHdfsWriter::ThreadFunc(void* arg) +{ + RChannelBufferHdfsWriter* self = (RChannelBufferHdfsWriter *) arg; + self->WriteThread(); + return 0; +} + +bool RChannelBufferHdfsWriter::Open(HdfsBridgeNative::Instance** pInstance, + HdfsBridgeNative::Writer** pWriter) +{ + bool initialized = HdfsBridgeNative::Initialize(); + if (!initialized) + { + DrStr64 description; + description.Set("Can't initialize HDFS bridge"); + m_completionItem = + MakeErrorItem(DryadError_ChannelOpenError, + 
description.GetString()); + return false; + } + + DrStr64 headNode; + Int32 hdfsPort; + DrStr64 filePath; + + bool parsed = ExtractHdfsWriteUri(m_uri, + headNode, &hdfsPort, + filePath); + if (!parsed) + { + DrStr64 description; + description.SetF("Can't parse HDFS URI '%s'", m_uri.GetString()); + m_completionItem = + MakeErrorItem(DryadError_InvalidChannelURI, + description.GetString()); + return false; + } + + HdfsBridgeNative::Instance* instance; + bool openedInstance = + HdfsBridgeNative::OpenInstance(headNode.GetString(), + hdfsPort, + &instance); + if (!openedInstance) + { + DrStr64 description; + description.SetF("Can't open HDFS Bridge '%s:%d'", + headNode.GetString(), hdfsPort); + m_completionItem = + MakeErrorItem(DryadError_ChannelOpenError, + description.GetString()); + return false; + } + + HdfsBridgeNative::InstanceAccessor ia(instance); + + HdfsBridgeNative::Writer* writer; + bool openedWriter = ia.OpenWriter(filePath.GetString(), &writer); + if (!openedWriter) + { + char* errorMsg = ia.GetExceptionMessage(); + DrStr64 description; + description.SetF("Can't open HDFS file '%s': %s", + filePath.GetString(), errorMsg); + HdfsBridgeNative::DisposeString(errorMsg); + + m_completionItem = + MakeErrorItem(DryadError_ChannelOpenError, + description.GetString()); + + ia.Dispose(); + return false; + } + + *pInstance = instance; + *pWriter = writer; + + return true; +} + +void RChannelBufferHdfsWriter::WriteThread() +{ + HdfsBridgeNative::Instance* instance = NULL; + HdfsBridgeNative::Writer* writer = NULL; + + bool opened = Open(&instance, &writer); + + do + { + DWORD dRet = WaitForSingleObject(m_queueHandle, INFINITE); + LogAssert(dRet == WAIT_OBJECT_0); + + WriteEntry* entry = NULL; + { + AutoCriticalSection acs(&m_cs); + + entry = m_queue.CastOut(m_queue.RemoveHead()); + /* the event shouldn't have been signaled unless the queue + is non-empty */ + LogAssert(entry != NULL); + LogAssert(m_queueLength > 0); + --m_queueLength; + } + + do + { + if 
(m_completionItem != NULL) + { + /* we've had a write error: we'll just reply below + with another error */ + LogAssert(m_completionItem->GetType() != + RChannelItem_EndOfStream); + } + else + { + if (entry->m_type == RChannelItem_Data) + { + /* we got a data item */ + LogAssert(entry->m_buffer != NULL); + + LogAssert(writer != NULL); + HdfsBridgeNative::WriterAccessor wa(writer); + + size_t dataSize; + void *dataAddr = entry->m_buffer-> + GetDataAddress(0, &dataSize, NULL); + Size_t dataToWrite = entry->m_buffer->GetAvailableSize(); + LogAssert(dataToWrite <= dataSize); + bool ret = + wa.WriteBlock((char *)dataAddr, dataToWrite, + entry->m_flush); + if (ret) + { + AutoCriticalSection acs(&m_cs); + + m_processedLength += dataSize; + } + else + { + char* errorMsg = wa.GetExceptionMessage(); + DrStr64 description; + description.SetF("Got HDFS error on write: %s", + errorMsg); + HdfsBridgeNative::DisposeString(errorMsg); + + DrLogE(description.GetString()); + + m_completionItem = + MakeErrorItem(DryadError_ChannelWriteError, + description.GetString()); + } + } + else + { + /* we got a termination item */ + LogAssert(entry->m_buffer == NULL); + + LogAssert(writer != NULL); + HdfsBridgeNative::WriterAccessor wa(writer); + + bool ret = wa.Close(); + if (!ret) + { + char* errorMsg = wa.GetExceptionMessage(); + DrStr64 description; + description.SetF("Got HDFS error on close: %s", + errorMsg); + HdfsBridgeNative::DisposeString(errorMsg); + + DrLogE(description.GetString()); + m_completionItem = + MakeErrorItem(DryadError_ChannelWriteError, + description.GetString()); + } + else + { + DrLogI("Closed HDFS writer"); + m_completionItem = + RChannelMarkerItem::Create(entry->m_type, false); + } + } + } + + RChannelItemType status; + if (m_completionItem == NULL) + { + status = RChannelItem_Data; + } + else + { + status = m_completionItem->GetType(); + LogAssert(status != RChannelItem_Data); + } + + entry->m_handler->ProcessWriteCompleted(status); + delete entry; + entry = NULL; 
+ + if (status == RChannelItem_Data) + { + /* we haven't had an error or termination, so see if + there's another entry in the queue */ + AutoCriticalSection acs(&m_cs); + + entry = m_queue.CastOut(m_queue.RemoveHead()); + if (entry == NULL) + { + /* go to sleep until someone puts another buffer + in the queue */ + LogAssert(m_queueLength == 0); + BOOL bRet = ResetEvent(m_queueHandle); + LogAssert(bRet != 0); + } + else + { + LogAssert(m_queueLength > 0); + --m_queueLength; + } + } + } while (entry != NULL); + } while (m_completionItem == NULL); + + if (opened) + { + /* discard the java objects we're holding onto */ + HdfsBridgeNative::WriterAccessor wa(writer); + wa.Dispose(); + + HdfsBridgeNative::InstanceAccessor ia(instance); + ia.Dispose(); + } +} + +bool RChannelBufferHdfsWriter::AddToQueue(WriteEntry* entry) +{ + { + AutoCriticalSection acs(&m_cs); + + BOOL wasEmpty = m_queue.IsEmpty(); + + m_queue.InsertAsTail(m_queue.CastIn(entry)); + ++m_queueLength; + + if (wasEmpty) + { + LogAssert(m_queueLength == 1); + BOOL bRet = SetEvent(m_queueHandle); + LogAssert(bRet != 0); + } + + /* should block if the queue gets too deep */ + return (m_queueLength > s_maxBuffersToBlockWriter); + } +} + +bool RChannelBufferHdfsWriter:: +WriteBuffer(DryadFixedMemoryBuffer* buffer, + bool flushAfter, + RChannelBufferWriterHandler* handler) +{ + WriteEntry* entry = new WriteEntry; + entry->m_buffer.Attach(buffer); + entry->m_flush = flushAfter; + entry->m_type = RChannelItem_Data; + entry->m_handler = handler; + + return AddToQueue(entry); +} + +void RChannelBufferHdfsWriter:: +ReturnUnusedBuffer(DryadFixedMemoryBuffer* buffer) +{ + buffer->DecRef(); +} + +void RChannelBufferHdfsWriter:: +WriteTermination(RChannelItemType reasonCode, + RChannelBufferWriterHandler* handler) +{ + WriteEntry* entry = new WriteEntry; + /* NULL entry->m_buffer */ + entry->m_flush = false; + entry->m_type = reasonCode; + entry->m_handler = handler; + + AddToQueue(entry); +} + +void 
RChannelBufferHdfsWriter::FillInStatus(DryadChannelDescription* status) +{ + AutoCriticalSection acs(&m_cs); + + status->SetChannelTotalLength(0); + status->SetChannelProcessedLength(m_processedLength); +} + +void RChannelBufferHdfsWriter::Drain(RChannelItemRef* pReturnItem) +{ + /* Drain shouldn't have been called unless a termination item has + been sent, so eventually the writer thread will exit... */ + DWORD dRet = WaitForSingleObject(m_writeThread, INFINITE); + LogAssert(dRet == WAIT_OBJECT_0); + + LogAssert(m_queue.IsEmpty()); + LogAssert(m_queueLength == 0); + + /* and that's it, nothing more to do */ + LogAssert(m_completionItem != NULL); + *pReturnItem = m_completionItem; + m_completionItem = NULL; +} + +void RChannelBufferHdfsWriter::Close() +{ + LogAssert(m_queueHandle != INVALID_HANDLE_VALUE); + LogAssert(m_writeThread != INVALID_HANDLE_VALUE); + + BOOL bRetval = CloseHandle(m_queueHandle); + LogAssert(bRetval != 0); + m_queueHandle = INVALID_HANDLE_VALUE; + + bRetval = CloseHandle(m_writeThread); + LogAssert(bRetval != 0); + m_writeThread = INVALID_HANDLE_VALUE; + + LogAssert(m_completionItem == NULL); +} + +UInt64 RChannelBufferHdfsWriter::GetInitialSizeHint() +{ + return 0; +} + +void RChannelBufferHdfsWriter::SetInitialSizeHint(UInt64 /*hint*/) +{ +} diff --git a/DryadVertex/VertexHost/system/channel/src/channelbufferhdfs.h b/DryadVertex/VertexHost/system/channel/src/channelbufferhdfs.h new file mode 100644 index 0000000..b2326b1 --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/src/channelbufferhdfs.h @@ -0,0 +1,157 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include "channelreader.h" +#include "channelwriter.h" + +#include + +class RChannelBufferHdfsReader + : public RChannelBufferReader, public RChannelBufferDefaultHandler +{ +public: + static const char* s_hdfsPartitionPrefix; + + RChannelBufferHdfsReader(const char* uri); + virtual ~RChannelBufferHdfsReader(); + + void Start(RChannelBufferPrefetchInfo* prefetchCookie, + RChannelBufferReaderHandler* handler); + + void Interrupt(); + + void FillInStatus(DryadChannelDescription* status); + + void Drain(RChannelItem* drainItem); + + void Close(); + + bool GetTotalLength(UInt64* pLen); + + /* the RChannelBufferDefaultHandler interface */ + void ReturnBuffer(RChannelBuffer* buffer); + +private: + virtual Int64 ScanForSync(HdfsBridgeNative::Reader* reader, + const char* fileName, + Int64 startOffset, Int64 endOffset, + RChannelBuffer** pErrorBuffer) = 0; + + static unsigned __stdcall ThreadFunc(void* a); + void SendBuffer(RChannelBuffer* buffer, bool getSemaphore); + Int64 AdjustStartOffset(HdfsBridgeNative::Reader* reader, + const char* fileName, + Int64 startOffset, + Int64 endOffset); + Int64 AdjustEndOffset(HdfsBridgeNative::Reader* reader, + const char* fileName, + Int64 endOffset); + Int64 ReadDataBuffer(HdfsBridgeNative::ReaderAccessor& ra, + const char* fileName, + Int64 offset, + Int64 endOffset); + void ReadThread(); + + DrStr64 m_uri; + RChannelBufferReaderHandler* m_handler; + HANDLE m_readThread; + HANDLE m_blockSemaphore; + HANDLE m_abortHandle; + + UInt64 m_totalLength; + 
UInt64 m_processedLength; + UInt32 m_buffersOut; + CRITSEC m_cs; +}; + +class RChannelBufferHdfsReaderLineRecord : public RChannelBufferHdfsReader +{ +public: + RChannelBufferHdfsReaderLineRecord(const char* uri); + +private: + Int64 ScanForSync(HdfsBridgeNative::Reader* reader, + const char* fileName, + Int64 startOffset, Int64 endOffset, + RChannelBuffer** pErrorBuffer); +}; + + +class RChannelBufferHdfsWriter : public RChannelBufferWriter +{ +public: + static const char* s_hdfsFilePrefix; + + RChannelBufferHdfsWriter(const char* uri); + + DryadFixedMemoryBuffer* GetNextWriteBuffer(); + DryadFixedMemoryBuffer* GetCustomWriteBuffer(Size_t bufferSize); + + void Start(); + + bool WriteBuffer(DryadFixedMemoryBuffer* buffer, + bool flushAfter, + RChannelBufferWriterHandler* handler); + + void ReturnUnusedBuffer(DryadFixedMemoryBuffer* buffer); + + void WriteTermination(RChannelItemType reasonCode, + RChannelBufferWriterHandler* handler); + + void FillInStatus(DryadChannelDescription* status); + + void Drain(RChannelItemRef* pReturnItem); + + void Close(); + + /* Get/set a hint about the total length the channel is expected + to be. Some channel implementations can use this to improve + write performance and decrease disk fragmentation. A value of 0 + (the default) means that the size is unknown. 
*/ + UInt64 GetInitialSizeHint(); + void SetInitialSizeHint(UInt64 hint); + +private: + struct WriteEntry + { + DrRef m_buffer; + bool m_flush; + RChannelItemType m_type; + RChannelBufferWriterHandler* m_handler; + DrBListEntry m_listPtr; + }; + + static unsigned __stdcall ThreadFunc(void* arg); + void WriteThread(); + bool Open(HdfsBridgeNative::Instance** pInstance, + HdfsBridgeNative::Writer** pWriter); + bool AddToQueue(WriteEntry* entry); + + DrStr64 m_uri; + DryadBList m_queue; + UInt32 m_queueLength; + HANDLE m_queueHandle; + HANDLE m_writeThread; + RChannelItemRef m_completionItem; + UInt64 m_processedLength; + CRITSEC m_cs; +}; diff --git a/DryadVertex/VertexHost/system/channel/src/channelbuffernativereader.cpp b/DryadVertex/VertexHost/system/channel/src/channelbuffernativereader.cpp new file mode 100644 index 0000000..a69137b --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/src/channelbuffernativereader.cpp @@ -0,0 +1,1755 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
*/

#define _CRT_RAND_S
// NOTE(review): the header name after each #include below is missing in this
// copy of the file (the <...> text appears to have been stripped) -- restore
// the original header names before building.
#include

#include
#include
#include
#include
#include
#include

#pragma unmanaged


//
// Create a new generic read handler.
//   requestOffset - offset handed to the port handler's InitializeInternal
//   streamOffset  - value later reported by GetStreamOffset()
//   dataSize      - size of the read block to allocate; no block is
//                   allocated when this is 0
//   dataAlignment - alignment passed to DryadAlignedReadBlock
//
RChannelBufferReaderNative::ReadHandler::ReadHandler(UInt64 requestOffset,
                                                     UInt64 streamOffset,
                                                     UInt32 dataSize,
                                                     size_t dataAlignment)

{
    //
    // Initialize the handler with the offset and data size
    //
    this->DryadNativePort::Handler::InitializeInternal(dataSize,
                                                       requestOffset);

    m_streamOffset = streamOffset;
    m_buffer = NULL;

    //
    // Create a fixed size buffer if there is any data to read
    //
    if (dataSize > 0)
    {
        m_block = new DryadAlignedReadBlock(dataSize, dataAlignment);
    }
    else
    {
        m_block = NULL;
    }

    m_isLastDataBuffer = false;
}

//
// The channel buffer must already have been handed off with
// TransferChannelBuffer(); the handler's reference on the read block (if any)
// is dropped here.
//
RChannelBufferReaderNative::ReadHandler::~ReadHandler()
{
    LogAssert(m_buffer == NULL);
    if (m_block != NULL)
    {
        m_block->DecRef();
    }
}

UInt64 RChannelBufferReaderNative::ReadHandler::GetStreamOffset()
{
    return m_streamOffset;
}

//
// Return the raw data pointer of the read block; asserts that a block exists
//
void* RChannelBufferReaderNative::ReadHandler::GetData()
{
    LogAssert(m_block != NULL);
    return m_block->GetData();
}

//
// Return the read block itself; asserts that a block exists
//
DryadAlignedReadBlock* RChannelBufferReaderNative::ReadHandler::GetBlock()
{
    LogAssert(m_block != NULL);
    return m_block;
}

//
// Attach the channel buffer produced for this read. Only one buffer may be
// attached at a time.
//
void RChannelBufferReaderNative::ReadHandler::
    SetChannelBuffer(RChannelBuffer* buffer)
{
    LogAssert(m_buffer == NULL);
    m_buffer = buffer;
}

//
// Return pointer to current buffer and forget about it
//
RChannelBuffer* RChannelBufferReaderNative::ReadHandler::
    TransferChannelBuffer()
{
    LogAssert(m_buffer != NULL);
    RChannelBuffer* retval = m_buffer;
    m_buffer = NULL;
    return retval;
}

//
// Mark the attached data buffer as the last data buffer. May be called at
// most once, and only while a buffer of type RChannelBuffer_Data is attached.
//
void RChannelBufferReaderNative::ReadHandler::SignalLastDataBuffer()
{
    LogAssert(m_isLastDataBuffer == false);
    LogAssert(m_buffer != NULL);
    LogAssert(m_buffer->GetType() == RChannelBuffer_Data);
    m_isLastDataBuffer = true;
}

bool RChannelBufferReaderNative::ReadHandler::IsLastDataBuffer()
{
    return m_isLastDataBuffer;
}

//
// Construct a reader in the closed state. The manual-reset event used to wait
// for returned buffers is created here; the handler-return event is created
// later, in StartNativeReader.
//
RChannelBufferReaderNative::
    RChannelBufferReaderNative(UInt32 prefetchBuffers,
                               DryadNativePort* port,
                               WorkQueue* workQueue,
                               RChannelOpenThrottler* openThrottler,
                               bool supportsLazyOpen)
{
    m_prefetchBuffers = prefetchBuffers;
    m_port = port;
    m_workQueue = workQueue;
    m_openThrottler = openThrottler;
    m_supportsLazyOpen = supportsLazyOpen;

    m_state = S_Closed;
    m_openState = OS_Closed;
    m_drainingOpenQueue = false;
    m_sentLastBuffer = false;
    m_outstandingHandlers = 0;
    m_outstandingBuffers = 0;

    m_totalLength = 0;
    m_nextStreamOffsetToProcess = 0;

    m_handlerReturnEvent = INVALID_HANDLE_VALUE;
    m_handler = NULL;
    m_bufferReturnEvent = ::CreateEvent(NULL, TRUE, FALSE, NULL);
    LogAssert(m_bufferReturnEvent != NULL);
}

//
// The reader must be fully closed, with its error buffer already detached
// (see CloseNativeReader), before it is destroyed.
//
RChannelBufferReaderNative::~RChannelBufferReaderNative()
{
    {
        AutoCriticalSection acs(&m_baseDR);

        LogAssert(m_state == S_Closed);
        LogAssert(m_openState == OS_Closed);
        LogAssert(m_drainingOpenQueue == false);
        BOOL bRet = ::CloseHandle(m_bufferReturnEvent);
        LogAssert(bRet != 0);
        LogAssert(m_errorBuffer == NULL);
    }
}

//
// Expose the reader's critical section so derived classes can take it
//
CRITSEC* RChannelBufferReaderNative::GetBaseDR()
{
    return &m_baseDR;
}

//
// Remember the first error buffer reported; later ones are ignored
//
void RChannelBufferReaderNative::SetErrorBuffer(RChannelBuffer* buffer)
{
    {
        AutoCriticalSection acs(&m_baseDR);

        if (m_errorBuffer == NULL)
        {
            m_errorBuffer = buffer;
        }
    }
}

//
// Report the total length and the offset processed so far
//
void RChannelBufferReaderNative::
    FillInStatus(DryadChannelDescription* status)
{
    {
        AutoCriticalSection acs(&m_baseDR);

        status->SetChannelTotalLength(m_totalLength);
        status->SetChannelProcessedLength(m_nextStreamOffsetToProcess);
    }
}

void RChannelBufferReaderNative::AssociateHandleWithPort(HANDLE h)
{
    m_port->AssociateHandle(h);
}

// NOTE(review): no lock is taken here; the visible caller (the file reader's
// LazyOpenFile) holds baseDR -- confirm for other callers.
void RChannelBufferReaderNative::SetTotalLength(UInt64 totalLength)
{
    m_totalLength = totalLength;
}

//
// Always succeeds; *pLen receives whatever SetTotalLength last stored
// (0 if it was never called)
//
bool RChannelBufferReaderNative::GetTotalLength(UInt64* pLen)
{
    *pLen = m_totalLength;
    return true;
}

/* called with baseDR held */
void RChannelBufferReaderNative::OpenNativeReader()
{
    LogAssert(m_state == S_Closed);
    LogAssert(m_openState == OS_Closed);
    LogAssert(m_drainingOpenQueue == false);
    m_state = S_Stopped;
    m_openState = OS_Stopped;
}

/* called with baseDR held. Moves the reader from stopped to closed and
   hands the stored error buffer (possibly NULL) to the caller */
RChannelBuffer* RChannelBufferReaderNative::CloseNativeReader()
{
    LogAssert(m_state == S_Stopped);
    LogAssert(m_openState == OS_Stopped);
    LogAssert(m_drainingOpenQueue == false);
    m_state = S_Closed;
    m_openState = OS_Closed;

    RChannelBuffer* errorBuffer = m_errorBuffer.Detach();

    return errorBuffer;
}

//
// Move the reader from stopped to running. Under the lock the per-run state
// is reset and either the initial reads are prepared or, if an error buffer
// is already stored, that buffer is queued as the only (and last) buffer.
// The reads and the error buffer are then issued outside the lock.
//
void RChannelBufferReaderNative::
    StartNativeReader(RChannelBufferReaderHandler* handler)
{
    HandlerList requestList;
    ChannelBufferList sendErrorBufferList;

    {
        AutoCriticalSection acs(&m_baseDR);

        LogAssert(m_state == S_Stopped);
        LogAssert(m_openState == OS_Stopped);
        LogAssert(m_drainingOpenQueue == false);
        LogAssert(m_sentLastBuffer == false);
        LogAssert(m_outstandingHandlers == 0);
        LogAssert(m_outstandingBuffers == 0);
        LogAssert(m_reorderMap.empty());
        LogAssert(m_handlerReturnEvent == INVALID_HANDLE_VALUE);
        m_handlerReturnEvent = ::CreateEvent(NULL, TRUE, FALSE, NULL);
        LogAssert(m_handlerReturnEvent != NULL);
        LogAssert(m_handler == NULL);

        m_handler = handler;

        m_latch.Start();
        m_nextStreamOffsetToProcess = 0;

        m_fetching = true;

        bool lazyOpenDone = false;

        if (m_supportsLazyOpen)
        {
            m_openState = OS_NotOpened;
        }
        else
        {
            m_openState = OS_Opened;
            lazyOpenDone = true;
        }

        if (m_errorBuffer == NULL)
        {
            LogAssert(m_prefetchBuffers > 0);
            UInt32 i;
            // NOTE(review): the text between "i" and "IncRef()" on the next
            // line is missing in this copy of the file (the rest of the
            // prefetch loop and the start of the error-buffer branch appear
            // to have been lost) -- restore it from the original source.
            for (i=0; iIncRef();
            sendErrorBufferList.InsertAsTail(sendErrorBufferList.
                                             CastIn(m_errorBuffer));
            m_latch.AcceptList(&sendErrorBufferList);

            m_outstandingBuffers = 1;
            m_sentLastBuffer = true;
        }

        m_state = S_Running;
    }

    //
    // Outside the lock: issue every read that was prepared above
    //
    DrBListEntry* listEntry = requestList.GetHead();
    while (listEntry != NULL)
    {
        ReadHandler* requestHandler = requestList.CastOut(listEntry);
        listEntry = requestList.GetNext(listEntry);
        requestList.Remove(requestList.CastIn(requestHandler));
        requestHandler->QueueRead(m_port);
    }

    //
    // Outside the lock: deliver whatever is in the send list, then ask the
    // latch (under the lock) for more; repeat until the list stays empty
    //
    while (sendErrorBufferList.IsEmpty() == false)
    {
        listEntry = sendErrorBufferList.GetHead();
        while (listEntry != NULL)
        {
            RChannelBuffer* buffer = sendErrorBufferList.CastOut(listEntry);
            listEntry = sendErrorBufferList.GetNext(listEntry);
            sendErrorBufferList.Remove(sendErrorBufferList.CastIn(buffer));
            m_handler->ProcessBuffer(buffer);
        }

        {
            AutoCriticalSection acs(&m_baseDR);

            m_latch.TransferList(&sendErrorBufferList);
        }
    }
}

//
// Make sure an open has been requested for a lazily-opened channel.
// Returns true if the handler was put on the open-waiting list (the open is
// still pending or the waiting list is being drained); the caller must not
// touch the handler in that case. Returns false if the open has already
// completed (successfully or not) and the handler's details were filled in.
//
bool RChannelBufferReaderNative::EnsureOpenForRead(ReadHandler* handler)
{
    bool consumedHandler = false;
    bool queueOpen = false;

    //
    // Enter a critical section and decide what to do with the handler
    //
    {
        AutoCriticalSection acs(&m_baseDR);

        LogAssert(m_supportsLazyOpen);

        if (m_openState == OS_NotOpened)
        {
            //
            // This is the first read ever queued on this channel. Start the
            // open ourselves if there is no throttler or QueueOpen returns
            // true; otherwise the throttler will call us back when the file
            // is ready to be opened
            //
            if (m_openThrottler == NULL ||
                m_openThrottler->QueueOpen(this))
            {
                queueOpen = true;
            }

            m_openState = OS_Waiting;
        }

        if (m_openState == OS_Waiting || m_drainingOpenQueue)
        {
            //
            // the open has been queued but hasn't yet completed, so
            // add this handler to a list of waiters that will get
            // sent once the file is open
            //
            m_openWaitingList.InsertAsTail(m_openWaitingList.
                                           CastIn(handler));

            //
            // true return means the caller shouldn't do anything with
            // the handler right now since it has been queued.
            //
            consumedHandler = true;
        }
        else
        {
            LogAssert(m_openState == OS_Opened || m_openState == OS_OpenError);

            //
            // the file is successfully opened or had an open error:
            // either way, it's as open as it's going to get: fill in
            // the appropriate handles
            //
            FillInOpenedDetails(handler);
        }
    }

    if (queueOpen)
    {
        //
        // We are responsible for starting the open: queue it up so we make
        // sure it's on one of our threads. This causes
        // OpenAfterThrottle to be called on a worker thread
        //
        RChannelOpenThrottler::Dispatch* dispatch =
            new RChannelOpenThrottler::Dispatch(this);
        m_workQueue->EnQueue(dispatch);
    }

    return consumedHandler;
}

//
// Called when all handlers are done to close up reader
// we return true if the file was open when we entered, and was
// closed during the progress of the function. This value is
// needed so that the caller can notify the throttler outside the
// lock when a file gets closed
//
bool RChannelBufferReaderNative::FinishUsingFile()
{
    //
    // we shouldn't get to this routine if we are waiting for a
    // blocked open from the throttling mechanism
    //
    LogAssert(m_openState != OS_Waiting);
    if (m_drainingOpenQueue)
    {
        LogAssert(m_openWaitingList.IsEmpty());
    }

    if (m_openState == OS_NotOpened)
    {
        //
        // if we never opened the file, transition straight to stopped
        //
        m_openState = OS_Stopped;
        return false;
    }
    else if (m_openState == OS_Opened)
    {
        //
        // If opened, close file and stop
        //
        EagerCloseFile();
        m_openState = OS_Stopped;
        return true;
    }
    else
    {
        //
        // If state isn't opened or not opened, it should be in error
        // ie should not have been stopped already. The state is left
        // as OS_OpenError.
        //
        LogAssert(m_openState == OS_OpenError);
        return false;
    }
}

//
// Perform the deferred open, then repeatedly drain the open-waiting list:
// under the lock the waiting handlers are filled in and copied out, and
// outside the lock their reads are queued. Returns once the waiting list is
// found empty.
//
void RChannelBufferReaderNative::OpenAfterThrottle()
{
    bool openSuccess = LazyOpenFile();
    if (!openSuccess && m_openThrottler != NULL)
    {
        /* this means we tried and failed to open the file. Let the
           throttler know (outside the lock) so it can queue up
           the next one */
        m_openThrottler->NotifyFileCompleted();
    }

    bool firstTime = true;

    while (true)
    {
        HandlerList sendList;

        {
            AutoCriticalSection acs(&m_baseDR);

            LogAssert(m_supportsLazyOpen);

            if (m_openState == OS_Waiting)
            {
                LogAssert(firstTime);
                LogAssert(m_drainingOpenQueue == false);
                m_drainingOpenQueue = true;
            }

            if (m_openWaitingList.IsEmpty())
            {
                /* there's supposed to be something waiting when we
                   get opened */
                LogAssert(firstTime == false);
                LogAssert(m_drainingOpenQueue == true);
                m_drainingOpenQueue = false;

                /* this drain loop is the last thing we were waiting
                   for */
                if (m_outstandingBuffers == 0 && m_state == S_Stopping)
                {
                    BOOL bRet = ::SetEvent(m_bufferReturnEvent);
                    LogAssert(bRet != 0);
                }

                return;
            }
            else
            {
                if (firstTime == true)
                {
                    LogAssert(m_openState == OS_Waiting);

                    if (openSuccess)
                    {
                        m_openState = OS_Opened;
                    }
                    else
                    {
                        m_openState = OS_OpenError;
                    }

                    firstTime = false;
                }

                LogAssert(m_openState == OS_Opened ||
                          m_openState == OS_OpenError);

                /* copy out everything that's currently waiting */

                while (m_openWaitingList.IsEmpty() == false)
                {
                    ReadHandler* nextRequest =
                        m_openWaitingList.
                        CastOut(m_openWaitingList.RemoveHead());
                    FillInOpenedDetails(nextRequest);
                    sendList.InsertAsTail(sendList.CastIn(nextRequest));
                }
            }

            /* the openWaitingList is now empty but we're going to
               leave the lock while remaining in the
               m_drainingOpenQueue state and send off the
               handlers. Then we'll go around the loop again in case
               somebody added more handlers to the openWaitingList
               while we were doing the send. */
        }

        while (sendList.IsEmpty() == false)
        {
            ReadHandler* nextRequest =
                sendList.CastOut(sendList.RemoveHead());
            nextRequest->QueueRead(m_port);
        }
    }
}

//
// Add all buffers waiting to be sent to sending queue and return any finished buffers
// to pool.
// Close the file if everything is done.
//
//   handler           - the read handler whose I/O just completed; it is
//                       always consumed (deleted here or parked in the
//                       reorder map)
//   makeNewHandler    - true if another read may be started to replace it
//   fillInReadHandler - optional replacement handler supplied by the caller;
//                       queued if the reader is running, deleted otherwise
//
void RChannelBufferReaderNative::DispatchBuffer(ReadHandler* handler,
                                                bool makeNewHandler,
                                                ReadHandler* fillInReadHandler)
{
    ChannelBufferList sendBufferList;
    ChannelBufferList returnBufferList;
    ReadHandler* newHandler = NULL;
    bool performedClose = false;

    {
        AutoCriticalSection acs(&m_baseDR);

        if (m_state == S_Running)
        {
            if (m_sentLastBuffer)
            {
                //
                // If running and the last buffer has already been sent, add the
                // read handler's buffer to the return buffer list
                //
                LogAssert(m_reorderMap.empty());
                RChannelBuffer* buffer = handler->TransferChannelBuffer();
                returnBufferList.InsertAsTail(returnBufferList.
                                              CastIn(buffer));
                delete handler;
                handler = NULL;
            }
            else
            {
                //
                // If running and the last buffer has not yet been sent, park the
                // handler in the reorder map keyed by its stream offset
                //
                // NOTE(review): the template arguments of std::pair are
                // missing in this copy of the file -- restore them from the
                // original source.
                std::pair retval;
                UInt64 streamOffset = handler->GetStreamOffset();
                retval = m_reorderMap.insert(std::make_pair(streamOffset,
                                                            handler));
                LogAssert(retval.second == true);
                handler = NULL;

                //
                // Walk the reorder map from the front for as long as the
                // first entry is the next offset to process and the last
                // buffer has not been sent
                //
                OffsetHandlerMap::iterator iter;
                for (iter = m_reorderMap.begin();
                     m_sentLastBuffer == false &&
                         iter != m_reorderMap.end() &&
                         iter->first == m_nextStreamOffsetToProcess;
                     iter = m_reorderMap.erase(iter))
                {
                    //
                    // Get the read handler from the reorder map, and add its
                    // buffer to the send buffer list
                    //
                    ReadHandler* nextHandler = iter->second;
                    RChannelBuffer* nextBuffer =
                        nextHandler->TransferChannelBuffer();

                    sendBufferList.InsertAsTail(sendBufferList.
                                                CastIn(nextBuffer));

                    RChannelBufferType t = nextBuffer->GetType();
                    if (RChannelBuffer::IsTerminationBuffer(t))
                    {
                        //
                        // If buffer is to be used for termination, it's the last one
                        //
                        m_sentLastBuffer = true;
                    }
                    else
                    {
                        //
                        // If the buffer is not used for termination, update the stream offset by
                        // the size of the buffer for next processing step
                        //
                        m_nextStreamOffsetToProcess +=
                            (UInt64)
                            nextHandler->GetBlock()->GetAvailableSize();

                        if (nextHandler->IsLastDataBuffer())
                        {
                            //
                            // If this is the last data buffer, then add a new end-of-stream buffer
                            // and mark last buffer as sent
                            //
                            UInt64 endOfStreamOffset =
                                m_nextStreamOffsetToProcess;
                            RChannelBuffer* endBuffer =
                                MakeEndOfStreamBuffer(endOfStreamOffset);
                            sendBufferList.InsertAsTail(sendBufferList.
                                                        CastIn(endBuffer));

                            m_sentLastBuffer = true;
                        }
                    }

                    delete nextHandler;
                }

                //
                // If the last buffer has been put in the send buffer list
                // (either termination or end-of-stream), every handler still
                // in the reorder map is discarded and its buffer is added to
                // the return buffer list
                //
                if (m_sentLastBuffer)
                {
                    for (iter = m_reorderMap.begin();
                         iter != m_reorderMap.end();
                         iter = m_reorderMap.erase(iter))
                    {
                        ReadHandler* nextHandler = iter->second;
                        RChannelBuffer* nextBuffer =
                            nextHandler->TransferChannelBuffer();
                        returnBufferList.InsertAsTail(returnBufferList.
                                                      CastIn(nextBuffer));
                        delete nextHandler;
                    }
                }
            }

            //
            // Update count of buffers currently waiting to be processed
            //
            m_outstandingBuffers += sendBufferList.CountLinks();
            m_outstandingBuffers += returnBufferList.CountLinks();

            //
            // Add send buffer list to send latch for processing
            //
            m_latch.AcceptList(&sendBufferList);

            LogAssert(m_outstandingHandlers > 0);
            --m_outstandingHandlers;

            //
            // If the last buffer has been sent and no more handlers are
            // outstanding, we are finished with the file
            //
            if (m_sentLastBuffer == true && m_outstandingHandlers == 0)
            {
                performedClose = FinishUsingFile();
            }

            UInt32 buffersInFlight =
                m_outstandingBuffers + m_outstandingHandlers;
            if (makeNewHandler)
            {
                //
                // If the number of buffers used is less than prefetch count
                // and currently fetching, create another one if requested
                //
                if (m_fetching == true &&
                    buffersInFlight < m_prefetchBuffers)
                {
                    LogAssert(m_state == S_Running);

                    bool lazyOpenDone = ((m_openState == OS_Opened ||
                                          m_openState == OS_OpenError) &&
                                         !m_drainingOpenQueue);

                    newHandler = GetNextReadHandler(lazyOpenDone);
                    ++m_outstandingHandlers;
                }
            }
            else if (fillInReadHandler == NULL)
            {
                //
                // we've got all the buffers out of the stream now
                //
                m_fetching = false;
            }

            if (fillInReadHandler != NULL)
            {
                //
                // If a new read handler was provided, increment the count
                //
                ++m_outstandingHandlers;
            }
        }
        else
        {
            //
            // If reader is not running, make sure it's stopping
            //
            LogAssert(m_state == S_Stopping);

            //
            // Transfer the read handler buffer into the return buffer list and update the count of
            // outstanding buffers
            //
            RChannelBuffer* buffer = handler->TransferChannelBuffer();
            returnBufferList.InsertAsTail(returnBufferList.CastIn(buffer));
            delete handler;
            handler = NULL;

            m_outstandingBuffers += returnBufferList.CountLinks();

            //
            // In this case, clear up the replacement read handler
            //
            if (fillInReadHandler != NULL)
            {
                delete fillInReadHandler;
                fillInReadHandler = NULL;
            }

            LogAssert(m_outstandingHandlers > 0);
            --m_outstandingHandlers;

            //
            // If no more handlers remaining, close the file and set handler return event
            //
            if (m_outstandingHandlers == 0)
            {
                performedClose = FinishUsingFile();
                BOOL bRet = ::SetEvent(m_handlerReturnEvent);
                LogAssert(bRet != 0);
            }

            LogAssert(newHandler == NULL);
        }

        //
        // Leave critical section
        //
    }

    //
    // For each entry in the return buffer list, hand it to ReturnBuffer
    //
    DrBListEntry* listEntry;
    listEntry = returnBufferList.GetHead();
    while (listEntry != NULL)
    {
        RChannelBuffer* buffer = returnBufferList.CastOut(listEntry);
        listEntry = returnBufferList.GetNext(listEntry);
        returnBufferList.Remove(returnBufferList.CastIn(buffer));
        ReturnBuffer(buffer);
    }

    //
    // todo: why do this more than once?
    // NOTE(review): the loop repeats because TransferList below may refill
    // sendBufferList from the latch -- confirm against the latch
    // implementation.
    //
    while (sendBufferList.IsEmpty() == false)
    {
        //
        // If there are any buffers in the sendBuffer list, pass them to the
        // handler for processing
        //
        listEntry = sendBufferList.GetHead();
        while (listEntry != NULL)
        {
            RChannelBuffer* buffer = sendBufferList.CastOut(listEntry);
            listEntry = sendBufferList.GetNext(listEntry);
            sendBufferList.Remove(sendBufferList.CastIn(buffer));
            m_handler->ProcessBuffer(buffer);
        }

        {
            //
            // Take the lock and let the latch transfer entries into the
            // (now empty) send buffer list
            //

            AutoCriticalSection acs(&m_baseDR);
            m_latch.TransferList(&sendBufferList);
        }
    }

    //
    // If a new read handler has been created, have it start reading
    //
    if (newHandler != NULL)
    {
        newHandler->QueueRead(m_port);
    }

    //
    // If a new read handler was supplied, have it start reading
    //
    if (fillInReadHandler != NULL)
    {
        fillInReadHandler->QueueRead(m_port);
    }

    //
    // The file was open previously, but is now closed. If we're
    // being throttled, let the throttler know so it can open the
    // next file in the queue
    //
    if (performedClose && m_openThrottler != NULL)
    {
        m_openThrottler->NotifyFileCompleted();
    }
}

//
// Stop the reader. Blocks until the latch has been interrupted and every
// outstanding read handler has come back through DispatchBuffer. Returned
// buffers are not waited for here (see DrainNativeReader).
//
void RChannelBufferReaderNative::Interrupt()
{
    bool mustWaitForLatch = false;
    HANDLE handlerEvent = INVALID_HANDLE_VALUE;
    ChannelBufferList returnBufferList;

    //
    // Enter a critical section and switch the reader to the stopping state
    //
    {
        AutoCriticalSection acs(&m_baseDR);

        //
        // Interrupt the latch; the return value says whether we must wait
        // on it afterwards
        //
        mustWaitForLatch = m_latch.Interrupt();

        //
        // If any read handlers are still outstanding, reset the handler
        // return event and take a private duplicate of its handle to wait on
        //
        if (m_outstandingHandlers > 0)
        {
            LogAssert(m_handlerReturnEvent != INVALID_HANDLE_VALUE);
            BOOL bRet = ::ResetEvent(m_handlerReturnEvent);
            LogAssert(bRet != 0);
            bRet = ::DuplicateHandle(GetCurrentProcess(),
                                     m_handlerReturnEvent,
                                     GetCurrentProcess(),
                                     &handlerEvent,
                                     0,
                                     FALSE,
                                     DUPLICATE_SAME_ACCESS);
            LogAssert(bRet != 0);
            LogAssert(handlerEvent != INVALID_HANDLE_VALUE);
        }

        if (m_state == S_Stopped)
        {
            //
            // If this is stopped, there must be nothing to wait for: neither
            // the latch nor any outstanding handler
            //
            LogAssert(mustWaitForLatch == false);
            LogAssert(handlerEvent == INVALID_HANDLE_VALUE);
        }
        else
        {
            //
            // If not stopped, must be running or already stopping
            //
            LogAssert(m_state == S_Running || m_state == S_Stopping);
            m_state = S_Stopping;
            m_fetching = false;

            //
            // For each read handler parked in the reorder map, move its
            // buffer into the return list and delete the handler
            //
            OffsetHandlerMap::iterator iter;
            for (iter = m_reorderMap.begin();
                 iter != m_reorderMap.end();
                 iter = m_reorderMap.erase(iter))
            {
                ReadHandler* nextHandler = iter->second;
                RChannelBuffer* nextBuffer =
                    nextHandler->TransferChannelBuffer();
                returnBufferList.InsertAsTail(returnBufferList.
                                              CastIn(nextBuffer));
                delete nextHandler;
            }

            //
            // Account for the buffers that are about to be returned.
            // NOTE(review): CountLinks() is used here as the number of
            // buffers in the list, as in DispatchBuffer -- confirm against
            // DrBList.
            //
            m_outstandingBuffers += returnBufferList.CountLinks();
        }
    }

    //
    // Outside the lock: hand each collected buffer to ReturnBuffer
    //
    DrBListEntry* listEntry = returnBufferList.GetHead();
    while (listEntry != NULL)
    {
        RChannelBuffer* buffer = returnBufferList.CastOut(listEntry);
        listEntry = returnBufferList.GetNext(listEntry);
        returnBufferList.Remove(returnBufferList.CastIn(buffer));
        ReturnBuffer(buffer);
    }

    //
    // Wait for the latch if needed
    //
    if (mustWaitForLatch)
    {
        m_latch.Wait();
    }

    //
    // Wait for the outstanding read handlers to return if needed, then
    // close our duplicate of the event handle
    //
    if (handlerEvent != INVALID_HANDLE_VALUE)
    {
        DWORD dRet = ::WaitForSingleObject(handlerEvent, INFINITE);
        LogAssert(dRet == WAIT_OBJECT_0);
        BOOL bRet = ::CloseHandle(handlerEvent);
        LogAssert(bRet != 0);
    }
}

/* this needs to be overwritten by derived classes that implement lazy
   open/eager close (i.e. files but not pipes). The base version asserts. */
bool RChannelBufferReaderNative::LazyOpenFile()
{
    LogAssert(false);
    return true;
}

/* this is only ever called as part of the lazy open mechanism, to
   fill in details of handlers after a deferred open has
   completed. Therefore if it's not overwritten, it asserts */
void RChannelBufferReaderNative::FillInOpenedDetails(ReadHandler* handler)
{
    LogAssert(false);
}

//
// this does nothing by default: it needs to be overwritten by derived
// classes that implement lazy open/eager close (i.e.
// files but not
// pipes)
//
void RChannelBufferReaderNative::EagerCloseFile()
{
}

//
// Stop the reader and wait until it is completely idle: Interrupt() waits for
// the outstanding read handlers, then this waits for every outstanding buffer
// to be returned (and for any open-queue drain to finish) before resetting
// the reader to the stopped state.
//
void RChannelBufferReaderNative::DrainNativeReader()
{
    //
    // Interrupt the reader to avoid additional reads
    //
    Interrupt();

    bool mustWaitForBuffers = false;
    bool performedClose = false;

    {
        AutoCriticalSection acs(&m_baseDR);

        LogAssert(m_state == S_Stopping);

        if (m_outstandingBuffers > 0 || m_drainingOpenQueue)
        {
            //
            // Need to wait for outstanding buffers to be handled
            //
            mustWaitForBuffers = true;
            BOOL bRet = ::ResetEvent(m_bufferReturnEvent);
            LogAssert(bRet != 0);
        }
        else if (m_openState != OS_Stopped)
        {
            //
            // Nothing is outstanding but the open state has not reached
            // stopped yet, so finish with the file now
            //
            performedClose = FinishUsingFile();
        }
    }

    if (performedClose && m_openThrottler != NULL)
    {
        //
        // the file was open previously, and is now closed. If we're
        // being throttled, let the throttler know so it can open the
        // next file in the queue
        //
        m_openThrottler->NotifyFileCompleted();
    }

    if (mustWaitForBuffers)
    {
        //
        // If there were outstanding buffers or the queue was still draining, wait
        // for all buffers to be returned
        //
        DWORD dRet = ::WaitForSingleObject(m_bufferReturnEvent, INFINITE);
        LogAssert(dRet == WAIT_OBJECT_0);
    }

    {
        AutoCriticalSection acs(&m_baseDR);

        //
        // Ensure that everything was shut down correctly and
        // clean up all additional resources (handles and latch)
        //
        LogAssert(m_outstandingHandlers == 0);
        LogAssert(m_outstandingBuffers == 0);
        LogAssert(m_drainingOpenQueue == false);
        LogAssert(m_reorderMap.empty());
        m_latch.Stop();
        BOOL bRet = ::CloseHandle(m_handlerReturnEvent);
        LogAssert(bRet != 0);
        m_handlerReturnEvent = INVALID_HANDLE_VALUE;
        m_handler = NULL;

        LogAssert(m_state == S_Stopping);
        LogAssert(m_openState == OS_OpenError ||
                  m_openState == OS_Stopped);
        m_state = S_Stopped;
        m_openState = OS_Stopped;
        m_sentLastBuffer = false;
    }
}

// NOTE(review): read and written without taking baseDR -- confirm callers
// only change it while the reader is not running.
void RChannelBufferReaderNative::SetPrefetchBufferCount(UInt32 numberOfBuffers)
{
    m_prefetchBuffers = numberOfBuffers;
}

UInt32 RChannelBufferReaderNative::GetPrefetchBufferCount()
{
    return m_prefetchBuffers;
}

//
// Make an end-of-stream buffer. The stream offset is recorded in the
// metadata of both the marker item and the buffer.
//
RChannelBuffer* RChannelBufferReaderNative::
    MakeEndOfStreamBuffer(UInt64 streamOffset)
{
    RChannelItem* item =
        RChannelMarkerItem::Create(RChannelItem_EndOfStream, true);
    DryadMetaData* metaData = item->GetMetaData();
    MakeFileMetaData(metaData, streamOffset);

    RChannelBuffer* buffer =
        RChannelBufferMarkerDefault::Create(RChannelBuffer_EndOfStream,
                                            item,
                                            this);
    metaData = buffer->GetMetaData();
    MakeFileMetaData(metaData, streamOffset);

    return buffer;
}

//
// Make an abort buffer carrying errorCode and the stream offset, and register
// it via SetErrorBuffer (which keeps only the first error reported)
//
RChannelBuffer* RChannelBufferReaderNative::
    MakeErrorBuffer(UInt64 streamOffset, DrError errorCode)
{
    RChannelItem* item =
        RChannelMarkerItem::Create(RChannelItem_Abort, true);
    DryadMetaData* metaData = item->GetMetaData();
    MakeFileMetaData(metaData, streamOffset);
    metaData->AddError(errorCode);

    RChannelBuffer* errorBuffer =
        RChannelBufferMarkerDefault::Create(RChannelBuffer_Abort,
                                            item,
                                            this);
    metaData = errorBuffer->GetMetaData();
    MakeFileMetaData(metaData, streamOffset);
    metaData->AddError(errorCode);

    SetErrorBuffer(errorBuffer);
    return errorBuffer;
}

//
// Make an abort buffer describing a failed open. Unlike MakeErrorBuffer, no
// stream offset is recorded and the buffer is not passed to SetErrorBuffer.
//
RChannelBuffer* RChannelBufferReaderNative::
    MakeOpenErrorBuffer(DrError errorCode, const char* description)
{
    RChannelItem* item =
        RChannelMarkerItem::Create(RChannelItem_Abort, true);
    DryadMetaData* metaData = item->GetMetaData();
    metaData->AddErrorWithDescription(errorCode, description);

    RChannelBuffer* errorBuffer =
        RChannelBufferMarkerDefault::Create(RChannelBuffer_Abort,
                                            item,
                                            this);
    metaData = errorBuffer->GetMetaData();
    metaData->AddErrorWithDescription(errorCode, description);

    return errorBuffer;
}

//
// Create a data buffer for the reader to use
//
+RChannelBuffer* RChannelBufferReaderNative:: + MakeDataBuffer(UInt64 streamOffset, DryadLockedMemoryBuffer* block) +{ + RChannelBuffer* dataBuffer = + RChannelBufferDataDefault::Create(block, + streamOffset, + this); + + DryadMetaData* metaData = dataBuffer->GetMetaData(); + MakeFileMetaData(metaData, streamOffset); + DryadMTagRef tag; + tag.Attach(DryadMTagUInt64::Create(Prop_Dryad_BufferLength, + block->GetAvailableSize())); + metaData->Append(tag, false); + + block->IncRef(); + + return dataBuffer; +} + +// +// Create a new read handler if needed and do the accounting for returning a buffer +// +void RChannelBufferReaderNative::ReturnBuffer(RChannelBuffer* buffer) +{ + // + // Decrement reference count to buffer + // + buffer->DecRef(); + + // + // Enter critical section and decrement number of outstanding buffers + // If this is the last buffer and everything else done, set buffer return event + // + ReadHandler* newHandler = NULL; + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_state == S_Running || m_state == S_Stopping); + + // + // Decrement outstand buffer count + // + LogAssert(m_outstandingBuffers > 0); + --m_outstandingBuffers; + + if (m_outstandingBuffers == 0 && + m_drainingOpenQueue == false && + m_state == S_Stopping) + { + // + // If no more outstanding buffers, done with open queue, and stopping, + // then set buffer return event + // + BOOL bRet = ::SetEvent(m_bufferReturnEvent); + LogAssert(bRet != 0); + } + else + { + // + // If not done, enumerate buffers still being processed and create a read handle if + // number of buffers currently working is less than number of prefetch buffers allowed + // + UInt32 buffersInFlight = + m_outstandingBuffers + m_outstandingHandlers; + + if (m_fetching == true && + buffersInFlight < m_prefetchBuffers) + { + LogAssert(m_state == S_Running); + + bool lazyOpenDone = ((m_openState == OS_Opened || + m_openState == OS_OpenError) && + !m_drainingOpenQueue); + + newHandler = 
GetNextReadHandler(lazyOpenDone); + ++m_outstandingHandlers; + } + } + } + + // + // If a new read handler was created, have it read from the port + // + if (newHandler != NULL) + { + newHandler->QueueRead(m_port); + } +} + + +// +// Create a new file read handler and initialize buffer +// +RChannelBufferReaderNativeFile::FileReadHandler:: + FileReadHandler(HANDLE fileHandle, + bool detailsPresent, + UInt64 streamOffset, + UInt32 dataSize, + size_t dataAlignment, + RChannelBufferReaderNativeFile* parent) : + RChannelBufferReaderNative::ReadHandler(streamOffset, streamOffset, + dataSize, dataAlignment) +{ + m_parent = parent; + m_fileHandle = fileHandle; + m_detailsPresent = detailsPresent; +} + +void RChannelBufferReaderNativeFile::FileReadHandler::SetFileHandle(HANDLE h) +{ + m_fileHandle = h; + LogAssert(m_detailsPresent == false); + m_detailsPresent = true; +} + +HANDLE RChannelBufferReaderNativeFile::FileReadHandler::GetFileHandle() +{ + return m_fileHandle; +} + +// +// Under normal circumstances, create a read buffer and queue the file read into it +// Also deals with completion logic (normal or error) +// +void RChannelBufferReaderNativeFile::FileReadHandler:: + QueueRead(DryadNativePort* port) +{ + if (m_detailsPresent == false) + { + // + // this handler was created before the file was opened. Let's + // see if we're the first handler to come by, and if so try to + // open the file + // + bool waitForThrottledOpen = m_parent->EnsureOpenForRead(this); + if (waitForThrottledOpen) + { + // + // If we have to wait for opening, + // do nothing right now. This will be sent to the port + // eventually (on another thread) when the file is finally + // opened. + // + return; + } + else + { + LogAssert(m_detailsPresent); + } + } + + + if (m_fileHandle == INVALID_HANDLE_VALUE) + { + // + // If file handle is invalid, then file open has failed. 
+ // report error and 0 bytes read + // + ProcessIO(DrError_EndOfStream, 0); + } + else + { + // + // If file is valid, queue up a read + // + port->QueueNativeRead(GetFileHandle(), this); + } +} + +// +// Create data or error buffer and handle it depending on the provided status +// +void RChannelBufferReaderNativeFile::FileReadHandler:: + ProcessIO(DrError cse, UInt32 numBytes) +{ + bool makeNewHandler = false; + + // todo: decide if we want remove commented code +// DrLogI( +// "Native read completed", +// "name = %s handle=%p, offset=%I64u, numBytes=%u, err=%s", +// m_parent->m_fileNameA, +// m_fileHandle, GetStreamOffset(), numBytes, DRERRORSTRING(cse)); + + if (m_parent->m_fileIsPipe && + cse == DrErrorFromWin32(ERROR_BROKEN_PIPE)) + { + // + // if pipe read fails with ERROR_BROKEN_PIPE => assume pipe is closed + // + cse = DrError_EndOfStream; + } + + + if (cse == DrError_EndOfStream) + { + // + // If there is an end of stream message, no bytes should be processed + // and final buffer should be created + // + LogAssert(numBytes == 0); + + SetChannelBuffer(m_parent->MakeFinalBuffer(GetStreamOffset())); + } + else + { + LogAssert(m_fileHandle != INVALID_HANDLE_VALUE); + + if (cse == DrError_OK) + { + // + // If the processing under normal circumstances, + // create a channel buffer to use for IO + // + DryadAlignedReadBlock* block = GetBlock(); + LogAssert(numBytes > 0); + LogAssert(numBytes <= block->GetAllocatedSize()); + block->Trim(numBytes); + + SetChannelBuffer(m_parent-> + MakeDataBuffer(GetStreamOffset(), block)); + + // + // todo: I don't understand why this is known to be the last data buffer when numBytes != allocated block size + // but a new handler should be made when numBytes == allocated block size + // ask JC + // + if (numBytes == block->GetAllocatedSize()) + { + makeNewHandler = true; + } + else + { + SignalLastDataBuffer(); + } + } + else + { + // + // If reporting error other than end of stream, no bytes should be processed + // and final 
buffer should be created + // + SetChannelBuffer(m_parent-> + MakeErrorBuffer(GetStreamOffset(), cse)); + } + } + + // + // Cause all buffers to be read and handle completion logic + // + m_parent->DispatchBuffer(this, makeNewHandler, NULL); +} + + +RChannelBufferReaderNativeFile:: + RChannelBufferReaderNativeFile(UInt32 bufferSize, + size_t bufferAlignment, + UInt32 prefetchBuffers, + DryadNativePort* port, + WorkQueue* workQueue, + RChannelOpenThrottler* openThrottler) : + RChannelBufferReaderNative(prefetchBuffers, port, + workQueue, openThrottler, + true) +{ + m_bufferSize = bufferSize; + m_bufferAlignment = bufferAlignment; + m_nextOffsetToRequest = 0; + m_fileHandle = INVALID_HANDLE_VALUE; + m_fileNameA = new char[MAX_PATH]; + m_fileNameW = new wchar_t[MAX_PATH]; + m_wideFileName = false; + m_fileIsPipe = false; +} + +RChannelBufferReaderNativeFile::~RChannelBufferReaderNativeFile() +{ + LogAssert(m_fileHandle == INVALID_HANDLE_VALUE); + delete [] m_fileNameA; + delete [] m_fileNameW; +} + +void RChannelBufferReaderNativeFile::FillInStatus(DryadChannelDescription* status) +{ + RChannelBufferReaderNative::FillInStatus(status); + { + AutoCriticalSection acs(&m_baseDR); + status->SetChannelURI(m_fileNameA); + } +} + +bool RChannelBufferReaderNativeFile::LazyOpenFile() +{ + DWORD flags = 0; + + if (m_fileIsPipe) + { + flags = FILE_FLAG_OVERLAPPED; + } + else + { + flags = FILE_FLAG_NO_BUFFERING | FILE_FLAG_OVERLAPPED; + } + + HANDLE h = INVALID_HANDLE_VALUE; + int attemptCount = 0; + + DrLogI( "Opening native file. Filename %s (%swide-char)", m_fileNameA, + m_wideFileName ? 
"" : "not "); + + while (h == INVALID_HANDLE_VALUE && attemptCount < 3) + { + attemptCount++; + + if (m_wideFileName) + { + h = ::CreateFileW(m_fileNameW, + GENERIC_READ, + FILE_SHARE_READ, + NULL, + OPEN_EXISTING, + flags, + NULL); + } + else + { + h = ::CreateFileA(m_fileNameA, + GENERIC_READ, + FILE_SHARE_READ, + NULL, + OPEN_EXISTING, + flags, + NULL); + } + + if (h == INVALID_HANDLE_VALUE) + { + // sleep for attemptCount * (3000 + rand value in the range 0 < x < 2000) milliseconds + unsigned int sleepTime = 0; + errno_t err = rand_s(&sleepTime); + if (err == 0) + { + sleepTime = (unsigned int) ((double)sleepTime / + (double) UINT_MAX * 2000.0); + } + else + { + DrLogE("rand_s call failed, adding 2000 ms."); + sleepTime = 2000; + } + sleepTime = attemptCount * (3000 + sleepTime); + DrLogI("Native file open failed", + "Filename %s (%swide-char) Retrying after %d ms", m_fileNameA, + m_wideFileName ? "" : "not ", sleepTime); + Sleep(sleepTime); + } + + } + + { + AutoCriticalSection acs(GetBaseDR()); + + if (h == INVALID_HANDLE_VALUE) + { + DrLogI( "Native file open failed. Filename %s (%swide-char)", m_fileNameA, + m_wideFileName ? "" : "not "); + + DrError err = DrGetLastError(); + DrStr64 description; + description.SetF("Can't open native file '%s' to read", + m_fileNameA); + m_openErrorBuffer.Attach(MakeOpenErrorBuffer(err, description)); + + return false; + } + else + { + // todo: remove comments if we're not logging this +// DrLogI( "Native file open succeeded", +// "Filename %s (%swide-char)", m_fileNameA, +// m_wideFileName ? 
"" : "not "); + + m_fileHandle = h; + + LARGE_INTEGER fileSize; + BOOL bRet = ::GetFileSizeEx(m_fileHandle, &fileSize); + LogAssert(bRet != 0); + + AssociateHandleWithPort(h); + SetTotalLength(fileSize.QuadPart); + + return true; + } + } +} + +// +// Close file when done +// +void RChannelBufferReaderNativeFile::EagerCloseFile() +{ + LogAssert(m_fileHandle != INVALID_HANDLE_VALUE); + LogAssert(m_openErrorBuffer == NULL); + + DrLogI( "Closing native file. Name %s (%swide-char)", m_fileNameA, + m_wideFileName ? "" : "not "); + + BOOL bRet = ::CloseHandle(m_fileHandle); + LogAssert(bRet != 0); + m_fileHandle = INVALID_HANDLE_VALUE; +} + +bool RChannelBufferReaderNativeFile::OpenA(const char* pathName) +{ + { + AutoCriticalSection acs(GetBaseDR()); + DrStr128 mappedPath; + + LogAssert(m_fileHandle == INVALID_HANDLE_VALUE); + LogAssert(m_nextOffsetToRequest == 0); + + /* DRYADONLY DrNetworkToLocal(0, pathName, mappedPath); */ + HRESULT hr = ::StringCbCopyA(m_fileNameA, MAX_PATH, pathName); + LogAssert(SUCCEEDED(hr)); + hr = ::StringCbLengthA(m_fileNameA, MAX_PATH, &m_fileNameLength); + LogAssert(SUCCEEDED(hr)); + m_wideFileName = false; + m_openErrorBuffer == NULL; + + if (ConcreteRChannel::IsNamedPipe(m_fileNameA)) + { + m_fileIsPipe = true; + + // Resize the buffer and change open + SetPrefetchBufferCount(1); // to make pipe data access sequential + UInt32 bufferAlignment = 1024; + UInt32 buffSize = 64*bufferAlignment; + if (m_bufferSize > buffSize) + { + m_bufferSize = buffSize; + m_bufferAlignment = bufferAlignment; + DrLogI("Reduced input buffer size for pipe. 
Size now %u", m_bufferSize); + } + } + + OpenNativeReader(); + } + + return true; +} + +/* JC +bool RChannelBufferReaderNativeFile::OpenW(const wchar_t* pathName) +{ + { + AutoCriticalSection acs(GetBaseDR()); + DrStr128 strPathName; + + LogAssert(m_fileHandle == INVALID_HANDLE_VALUE); + LogAssert(m_nextOffsetToRequest == 0); + + // DRYADONLY DrNetworkToLocal(0, DRWSTRINGTOUTF8(pathName), strPathName) + HRESULT hr = ::StringCbCopyW(m_fileNameW, MAX_PATH, DRUTF8TOWSTRING(strPathName)); + LogAssert(SUCCEEDED(hr)); + m_wideFileName = true; + + LogAssert(strPathName.GetString() != NULL); + LogAssert(strPathName.GetLength() < MAX_PATH-1); + hr = ::StringCbCopyA(m_fileNameA, MAX_PATH, strPathName.GetString()); + LogAssert(SUCCEEDED(hr)); + hr = ::StringCbLengthA(m_fileNameA, MAX_PATH, &m_fileNameLength); + LogAssert(SUCCEEDED(hr)); + m_openErrorBuffer == NULL; + + if (ConcreteRChannel::IsNamedPipe(m_fileNameA)) + { + m_fileIsPipe = true; + + // Resize the buffer and change open + SetPrefetchBufferCount(1); // to make pipe data access sequential + UInt32 bufferAlignment = 1024; + UInt32 buffSize = 64*bufferAlignment; + if (m_bufferSize > buffSize) + { + m_bufferSize = buffSize; + m_bufferAlignment = bufferAlignment; + DrLogI( + "Reduced input buffer size for pipe", + "Size now %u", m_bufferSize); + } + } + + OpenNativeReader(); + } + + return true; +} +*/ + +void RChannelBufferReaderNativeFile:: + Start(RChannelBufferPrefetchInfo* /*prefetchCookie*/, + RChannelBufferReaderHandler* handler) +{ + { + AutoCriticalSection acs(GetBaseDR()); + + LogAssert(m_nextOffsetToRequest == 0); + } + + StartNativeReader(handler); +} + +// +// Drain read channel +// +void RChannelBufferReaderNativeFile::Drain(RChannelItem* drainItem) +{ + DrainNativeReader(); + + { + AutoCriticalSection acs(GetBaseDR()); + + m_nextOffsetToRequest = 0; + } +} + +void RChannelBufferReaderNativeFile::Close() +{ + DrRef errorBuffer; + + { + AutoCriticalSection acs(GetBaseDR()); + + LogAssert(m_fileHandle == 
INVALID_HANDLE_VALUE); + + errorBuffer.Attach(CloseNativeReader()); + } +} + +void RChannelBufferReaderNativeFile:: + MakeFileMetaData(DryadMetaData* metaData, UInt64 streamOffset) +{ + DryadMTagRef tag; + tag.Attach(DryadMTagString::Create(Prop_Dryad_ChannelURI, + m_fileNameA)); + metaData->Append(tag, false); + + tag.Attach(DryadMTagUInt64::Create(Prop_Dryad_ChannelBufferOffset, + streamOffset)); + metaData->Append(tag, false); +} + +// +// Return new read handler +// +RChannelBufferReaderNative::ReadHandler* + RChannelBufferReaderNativeFile::GetNextReadHandler(bool lazyOpenDone) +{ + RChannelBufferReaderNative::ReadHandler* handler; + + { + AutoCriticalSection acs(GetBaseDR()); + + // + // Create a read handler + // + handler = new FileReadHandler(m_fileHandle, + lazyOpenDone, + m_nextOffsetToRequest, + m_bufferSize, + m_bufferAlignment, + this); + m_nextOffsetToRequest += m_bufferSize; + } + + return handler; +} + +// +// Make an error buffer or return one if it already exists +// +RChannelBuffer* RChannelBufferReaderNativeFile:: + MakeFinalBuffer(UInt64 streamOffset) +{ + { + AutoCriticalSection acs(GetBaseDR()); + + if (m_openErrorBuffer == NULL) + { + return MakeEndOfStreamBuffer(streamOffset); + } + else + { + m_openErrorBuffer->IncRef(); + return m_openErrorBuffer; + } + } +} + +void RChannelBufferReaderNativeFile::FillInOpenedDetails(ReadHandler* h) +{ + FileReadHandler* handler = dynamic_cast(h); + + if (m_openErrorBuffer == NULL) + { + LogAssert(m_fileHandle != INVALID_HANDLE_VALUE); + } + else + { + LogAssert(m_fileHandle == INVALID_HANDLE_VALUE); + } + + handler->SetFileHandle(m_fileHandle); +} diff --git a/DryadVertex/VertexHost/system/channel/src/channelbuffernativereader.h b/DryadVertex/VertexHost/system/channel/src/channelbuffernativereader.h new file mode 100644 index 0000000..5a56661 --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/src/channelbuffernativereader.h @@ -0,0 +1,241 @@ +/* +Copyright (c) Microsoft Corporation + +All rights 
reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include "dryadnativeport.h" +#include "channelreader.h" +#include "concreterchannelhelpers.h" +#include +#include + +#pragma warning(disable:4995) +#include + +class RChannelBufferReaderNative : + public RChannelBufferReader, public RChannelBufferDefaultHandler, + public RChannelThrottledStream +{ +public: + class ReadHandler : public DryadNativePort::Handler + { + public: + ReadHandler(UInt64 requestOffset, UInt64 streamOffset, + UInt32 bufferSize, + size_t bufferAlignment); + virtual ~ReadHandler(); + + UInt64 GetStreamOffset(); + void* GetData(); + DryadAlignedReadBlock* GetBlock(); + RChannelBuffer* TransferChannelBuffer(); + bool IsLastDataBuffer(); + + virtual void QueueRead(DryadNativePort* port) = 0; + + protected: + void SetChannelBuffer(RChannelBuffer* buffer); + void SignalLastDataBuffer(); + + private: + UInt64 m_streamOffset; + DryadAlignedReadBlock* m_block; + RChannelBuffer* m_buffer; + bool m_isLastDataBuffer; + DrBListEntry m_listPtr; + friend class DryadBList; + }; + + RChannelBufferReaderNative(UInt32 prefetchBuffers, + DryadNativePort* port, + WorkQueue* workQueue, + RChannelOpenThrottler* openThrottler, + bool supportsLazyOpen); + virtual ~RChannelBufferReaderNative(); + + void Interrupt(); + + void ReturnBuffer(RChannelBuffer* buffer); + + virtual void FillInStatus(DryadChannelDescription* status); + + bool 
EnsureOpenForRead(ReadHandler* handler); + + void OpenAfterThrottle(); + +protected: + void StartNativeReader(RChannelBufferReaderHandler* handler); + void DrainNativeReader(); + + void SetPrefetchBufferCount(UInt32 numberOfBuffers); + UInt32 GetPrefetchBufferCount(); + + void AssociateHandleWithPort(HANDLE h); + void SetTotalLength(UInt64 totalLength); + bool GetTotalLength(UInt64* pLen); + + /* called with baseDR held */ + void OpenNativeReader(); + /* called with baseDR held */ + RChannelBuffer* CloseNativeReader(); + + void SetErrorBuffer(RChannelBuffer* buffer); + + void DispatchBuffer(ReadHandler* handler, bool makeNewHandler, + ReadHandler* fillInReadHandler); + + RChannelBuffer* MakeEndOfStreamBuffer(UInt64 streamOffset); + RChannelBuffer* MakeErrorBuffer(UInt64 streamOffset, + DrError errorCode); + RChannelBuffer* MakeOpenErrorBuffer(DrError errorCode, + const char* description); + RChannelBuffer* MakeDataBuffer(UInt64 streamOffset, + DryadLockedMemoryBuffer* block); + virtual void MakeFileMetaData(DryadMetaData* metaData, + UInt64 streamOffset) = 0; + virtual bool LazyOpenFile(); + virtual void FillInOpenedDetails(ReadHandler* handler); + virtual void EagerCloseFile(); + + CRITSEC* GetBaseDR(); + +private: + enum State { + S_Closed, + S_Stopped, + S_Running, + S_Stopping + }; + + enum OpenState { + OS_Closed, + OS_NotOpened, + OS_Waiting, + OS_Opened, + OS_OpenError, + OS_Stopped + }; + + typedef DryadBList HandlerList; + typedef std::map OffsetHandlerMap; + + bool FinishUsingFile(); + virtual ReadHandler* GetNextReadHandler(bool lazyOpenDone) = 0; + + + UInt32 m_prefetchBuffers; + + RChannelBufferReaderHandler* m_handler; + DryadNativePort* m_port; + WorkQueue* m_workQueue; + RChannelOpenThrottler* m_openThrottler; + bool m_supportsLazyOpen; + + State m_state; + OpenState m_openState; + bool m_fetching; + bool m_sentLastBuffer; + UInt32 m_outstandingHandlers; + UInt32 m_outstandingBuffers; + UInt64 m_nextStreamOffsetToProcess; + UInt64 m_totalLength; + 
OffsetHandlerMap m_reorderMap; + DrRef m_errorBuffer; + HANDLE m_handlerReturnEvent; + HANDLE m_bufferReturnEvent; + HandlerList m_openWaitingList; + bool m_drainingOpenQueue; + + DryadOrderedSendLatch m_latch; + +protected: + CRITSEC m_baseDR; + + friend class ReadHandler; +}; + +class RChannelBufferReaderNativeFile : public RChannelBufferReaderNative +{ +public: + class FileReadHandler : public RChannelBufferReaderNative::ReadHandler + { + public: + FileReadHandler(HANDLE fileHandle, + bool detailsPresent, + UInt64 streamOffset, + UInt32 dataSize, size_t dataAlignment, + RChannelBufferReaderNativeFile* parent); + + void ProcessIO(DrError cse, UInt32 numBytes); + void SetFileHandle(HANDLE h); + HANDLE GetFileHandle(); + void QueueRead(DryadNativePort* port); + + private: + HANDLE m_fileHandle; + bool m_detailsPresent; + RChannelBufferReaderNativeFile* m_parent; + }; + + RChannelBufferReaderNativeFile(UInt32 bufferSize, + size_t bufferAlignment, + UInt32 prefetchBuffers, + DryadNativePort* port, + WorkQueue* workQueue, + RChannelOpenThrottler* openThrottler); + ~RChannelBufferReaderNativeFile(); + + virtual void FillInStatus(DryadChannelDescription* status); + + bool OpenA(const char* pathName); +//JC bool OpenW(const wchar_t* pathName); + + void Start(RChannelBufferPrefetchInfo* prefetchCookie, + RChannelBufferReaderHandler* handler); + + void Drain(RChannelItem* drainItem); + + void Close(); + +protected: + void EagerCloseFile(); + bool LazyOpenFile(); + void ResetFileNameAndErrorStateA(const char *fileName); + +private: + void MakeFileMetaData(DryadMetaData* metaData, UInt64 streamOffset); + RChannelBuffer* MakeFinalBuffer(UInt64 streamOffset); + void FillInOpenedDetails(ReadHandler* handler); + + ReadHandler* GetNextReadHandler(bool lazyOpenDone); + + UInt32 m_bufferSize; + size_t m_bufferAlignment; + UInt64 m_nextOffsetToRequest; + HANDLE m_fileHandle; + char* m_fileNameA; + size_t m_fileNameLength; + bool m_fileIsPipe; + wchar_t* m_fileNameW; + bool 
m_wideFileName; + DrRef m_openErrorBuffer; + + friend class FileReadHandler; +}; diff --git a/DryadVertex/VertexHost/system/channel/src/channelbuffernativewriter.cpp b/DryadVertex/VertexHost/system/channel/src/channelbuffernativewriter.cpp new file mode 100644 index 0000000..ef78417 --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/src/channelbuffernativewriter.cpp @@ -0,0 +1,3228 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include +#include +#include +#include +#include +#include +//YARN #include + +#include +#include + +#pragma unmanaged + +static const UInt64 s_fileExtendChunk = 256 * 1024 * 1024; + +static const char* s_dscPartitionPrefix = "hpcdscpt://"; + +// This is used in the call to DscGetWritePath to make sure the +// current node isn't completely full. We don't know the initial +// size hint from the vertex code at the time we call DscGetWritePath, +// so can't use a better estimate without major code refactoring. 
+static const UInt64 c_estimatedInitialFileSize = 1; + +//JCstatic DryadStreamPropertyUpdater g_streamUpdater; +bool RChannelBufferWriterNativeFile::s_triedToSetPrivilege = false; +bool RChannelBufferWriterNativeFile::s_setPrivilege = false; +CRITSEC RChannelBufferWriterNativeFile::s_privilegeDR; + +RChannelBufferWriterNative::WriteHandler:: + WriteHandler(DryadFixedMemoryBuffer* block, + UInt64 streamOffset, + bool flushAfter, + RChannelBufferWriterHandler* handler, + RChannelBufferWriterNative* parent) +{ + m_block = block; + if (m_block != NULL) + { + Size_t availableLength = m_block->GetAvailableSize(); + LogAssert(availableLength < 0x100000000); + m_writeLength = (UInt32) availableLength; + m_data = m_block->GetWriteAddress(0, m_writeLength, &availableLength); + LogAssert(availableLength >= (Size_t) m_writeLength); + } + else + { + m_writeLength = 0; + m_data = NULL; + } + + m_streamOffset = streamOffset; + m_flush = flushAfter; + m_handler = handler; + m_parent = parent; +} + +RChannelBufferWriterNative::WriteHandler::~WriteHandler() +{ + if (m_block != NULL) + { + m_block->DecRef(); + } +} + +UInt32 RChannelBufferWriterNative::WriteHandler::GetWriteLength() +{ + return m_writeLength; +} + +UInt64 RChannelBufferWriterNative::WriteHandler::GetStreamOffset() +{ + return m_streamOffset; +} + +DryadFixedMemoryBuffer* RChannelBufferWriterNative::WriteHandler::GetBlock() +{ + return m_block; +} + +RChannelBufferWriterNative* RChannelBufferWriterNative::WriteHandler:: + GetParent() +{ + return m_parent; +} + +void RChannelBufferWriterNative::WriteHandler:: + ProcessingComplete(RChannelItemType statusCode) +{ + LogAssert(m_handler != NULL); + m_handler->ProcessWriteCompleted(statusCode); + m_handler = NULL; +} + +void* RChannelBufferWriterNative::WriteHandler::GetData() +{ + return m_data; +} + +bool RChannelBufferWriterNative::WriteHandler::IsFlush() +{ + return m_flush; +} + + +RChannelBufferWriterNative:: + RChannelBufferWriterNative(UInt32 
outstandingWritesLowWatermark, + UInt32 outstandingWritesHighWatermark, + DryadNativePort* port, + RChannelOpenThrottler* openThrottler, + bool supportsLazyOpen) +{ + m_lowWatermark = outstandingWritesLowWatermark; + m_highWatermark = outstandingWritesHighWatermark; + m_port = port; + m_openThrottler = openThrottler; + m_supportsLazyOpen = supportsLazyOpen; + if (!m_supportsLazyOpen) + { + LogAssert(m_openThrottler == NULL); + } + + LogAssert(m_highWatermark > 0); + + m_outstandingBuffers = 0; + m_outstandingIOs = 0; + m_outstandingHandlerSends = 0; + m_outstandingTermination = false; + m_blocking = false; + m_flushing = false; + m_terminationHandler = NULL; + m_terminationEvent = ::CreateEvent(NULL, TRUE, FALSE, NULL); + LogAssert(m_terminationEvent != NULL); + + m_processedLength = 0; + m_initialSizeHint = 0; + + m_state = S_Closed; + m_openState = OS_Closed; + m_drainingOpenQueue = false; +} + +RChannelBufferWriterNative::~RChannelBufferWriterNative() +{ + LogAssert(m_state == S_Closed); + LogAssert(m_openState == OS_Closed); + LogAssert(m_drainingOpenQueue == false); + BOOL bRet = ::CloseHandle(m_terminationEvent); + LogAssert(bRet != 0); +} + +void RChannelBufferWriterNative::SetInitialSizeHint(UInt64 hint) +{ + m_initialSizeHint = hint; +} + +UInt64 RChannelBufferWriterNative::GetInitialSizeHint() +{ + return m_initialSizeHint; +} + +void RChannelBufferWriterNative::Start() +{ + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_state == S_Stopped); + LogAssert(m_openState == OS_Stopped); + LogAssert(m_drainingOpenQueue == false); + LogAssert(m_outstandingBuffers == 0); + LogAssert(m_outstandingIOs == 0); + LogAssert(m_outstandingHandlerSends == 0); + LogAssert(m_outstandingTermination == false); + LogAssert(m_blocking == false); + LogAssert(m_flushing == false); + LogAssert(m_blockedList.IsEmpty()); + LogAssert(m_pendingWriteSet.empty()); + LogAssert(m_pendingUnblockSet.empty()); + LogAssert(m_terminationHandler == NULL); + + m_openErrorItem = NULL; + + 
StartConcreteWriter(&m_completionItem); + + LogAssert(m_completionItem == NULL || + m_completionItem->GetType() != RChannelItem_EndOfStream); + + m_processedLength = 0; + + m_state = S_Running; + if (m_supportsLazyOpen) + { + m_openState = OS_NotOpened; + } + else + { + m_openState = OS_Opened; + } + + // todo: remove comment if not logging +// DrLogD( "Started", +// "this=%p", this); + } +} + +CRITSEC* RChannelBufferWriterNative::GetBaseDR() +{ + return &m_baseDR; +} + +DryadNativePort* RChannelBufferWriterNative::GetPort() +{ + return m_port; +} + +void RChannelBufferWriterNative::SetOpenErrorItem(RChannelItem* errorItem) +{ + LogAssert(m_openErrorItem == NULL); + m_openErrorItem = errorItem; +} + +bool RChannelBufferWriterNative::OpenError() +{ + return (m_openErrorItem == NULL) ? false : true; +} + +void RChannelBufferWriterNative::OpenInternal() +{ + LogAssert(m_state == S_Closed); + LogAssert(m_openState == OS_Closed); + LogAssert(m_drainingOpenQueue == false); + m_state = S_Stopped; + m_openState = OS_Stopped; +} + +bool RChannelBufferWriterNative::EnsureOpenForWrite(WriteHandler* handler) +{ + bool consumedHandler = false; + bool openFailed = false; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_supportsLazyOpen); + + if (m_openState == OS_NotOpened) + { + /* this is the first write ever queued on this channel */ + if (m_openThrottler == NULL || + m_openThrottler->QueueOpen(this)) + { + /* we aren't throttling opens or it's ready to open + now, so open the file now */ + if (LazyOpenFile()) + { + m_openState = OS_Opened; + } + else + { + m_openState = OS_OpenError; + openFailed = true; + } + } + else + { + /* the throttler will call us back (outside all + locks) when the file gets opened. 
*/ + m_openState = OS_Waiting; + } + } + + if (m_openState == OS_Waiting || m_drainingOpenQueue) + { + /* the open has been queued but hasn't yet completed, so + add this handler to a list of waiters that will get + sent once the file is open */ + m_openWaitingList.InsertAsTail(m_openWaitingList. + CastIn(handler)); + + /* true return means the caller shouldn't do anything with + the handler right now since it has been queued. */ + consumedHandler = true; + } + else + { + LogAssert(m_openState == OS_Opened || m_openState == OS_OpenError); + /* the file is successfully opened or had an open error: + either way, it's as open as it's going to get: fill in + the appropriate handles */ + FillInOpenedDetails(handler); + } + } + + if (openFailed && m_openThrottler != NULL) + { + /* the throttler said we could open the file, and the open + failed, so tell it the file isn't open. Call into the + throttler outside the lock since this may cause another + file to be opened */ + m_openThrottler->NotifyFileCompleted(); + } + + return consumedHandler; +} + +bool RChannelBufferWriterNative::FinishUsingFile() +{ + /* we shouldn't get to this routine if we are waiting for a + blocked open from the throttling mechanism */ + LogAssert(m_openState == OS_NotOpened || + m_openState == OS_OpenError || + m_openState == OS_Opened); + + if (m_drainingOpenQueue) + { + LogAssert(m_openWaitingList.IsEmpty()); + } + + /* this is the value that will be returned, and it is true if the + file was open when we entered, and was closed during the + progress of the function. This value is needed so that the + caller can notify the throttler outside the lock when a file + gets closed */ + bool performedClose = (m_openState == OS_Opened); + + if (m_openState == OS_NotOpened) + { + /* this means nothing was written before we were asked to + close, so we never did the lazy open */ + LogAssert(OpenError() == false); + + /* open now and close immediately. 
This writes a 0-length + output otherwise people downstream that try to read the + output will be sorely disappointed. Ignore throttling since + it makes life so much easier. */ + if (LazyOpenFile()) + { + m_openState = OS_Opened; + } + else + { + m_openState = OS_OpenError; + } + } + + if (m_openState == OS_Opened) + { + LogAssert(OpenError() == false); + EagerCloseFile(); + m_openState = OS_Stopped; + } + else + { + LogAssert(m_openState == OS_OpenError); + } + + return performedClose; +} + +void RChannelBufferWriterNative::OpenAfterThrottle() +{ + bool firstTime = true; + bool drained = false; + + while (!drained) + { + bool openFailure = false; + WriteHandlerList sendList; + + { + AutoCriticalSection acs(&m_baseDR); + + if (firstTime) + { + /* hold a reference to make sure the channel can't + finish before we exit this drain loop */ + ++m_outstandingHandlerSends; + } + + LogAssert(m_supportsLazyOpen); + LogAssert(m_openThrottler != NULL); + + if (m_openState == OS_Waiting) + { + LogAssert(firstTime); + LogAssert(m_drainingOpenQueue == false); + m_drainingOpenQueue = true; + } + + if (m_openWaitingList.IsEmpty()) + { + /* there's supposed to be something waiting when we + get opened */ + LogAssert(firstTime == false); + LogAssert(m_drainingOpenQueue == true); + m_drainingOpenQueue = false; + + drained = true; + /* now we will exit the while loop and decrement + m_outstandingHandlerSends again */ + } + else + { + if (firstTime == true) + { + LogAssert(m_openState == OS_Waiting); + + /* try to actually open the file now. If we try + and fail to open the file here, record the fact + so we can tell the throttler about it once we + exit the lock. 
*/ + if (LazyOpenFile()) + { + m_openState = OS_Opened; + } + else + { + m_openState = OS_OpenError; + openFailure = true; + } + + firstTime = false; + } + + LogAssert(m_openState == OS_Opened || + m_openState == OS_OpenError); + + /* copy out everything that's currently waiting */ + + while (m_openWaitingList.IsEmpty() == false) + { + WriteHandler* nextRequest = + m_openWaitingList. + CastOut(m_openWaitingList.RemoveHead()); + FillInOpenedDetails(nextRequest); + sendList.InsertAsTail(sendList.CastIn(nextRequest)); + } + } + + /* the openWaitingList is now empty but we're going to + leave the lock while remaining in the + m_drainingOpenQueue state and send off the + handlers. Then we'll go around the loop again in case + somebody added more handlers to the openWaitingList + while we were doing the send. */ + } + + while (sendList.IsEmpty() == false) + { + WriteHandler* nextRequest = + sendList.CastOut(sendList.RemoveHead()); + nextRequest->QueueWrite(m_port); + } + + if (openFailure) + { + /* this means we tried and failed to open the file. Let + the throttler know (outside the lock) so it can queue + up the next one */ + m_openThrottler->NotifyFileCompleted(); + } + } + + /* return our reference on m_outstandingHandlerSends. In the event + that a write we queued was the termination write, this may + cause the channel to give up its last reference and close */ + DecrementOutstandingHandlers(); +} + +/* this needs to be overwritten by derived classes that implement lazy + open/eager close (i.e. files but not pipes) */ +bool RChannelBufferWriterNative::LazyOpenFile() +{ + LogAssert(false); + return true; +} + +// +// this is only ever called as part of the lazy open mechanism, to +// fill in details of handlers after a deferred open has +// completed. 
Therefore if it's not overwritten, it asserts +// +void RChannelBufferWriterNative::FillInOpenedDetails(WriteHandler* handler) +{ + LogAssert(false); +} + +/* this does nothing by default: it needs to be overwritten by derived + classes that implement lazy open/eager close (i.e. files but not + pipes) */ +void RChannelBufferWriterNative::EagerCloseFile() +{ +} + +/* this does nothing by default: it needs to be overwritten by derived + classes that don't implement lazy open/eager close (i.e. pipes) */ +void RChannelBufferWriterNative::FinalCloseFile() +{ +} + +void RChannelBufferWriterNative::Close() +{ + { + AutoCriticalSection acs(&m_baseDR); + + FinalCloseFile(); + + LogAssert(m_state == S_Stopped); + LogAssert(m_openState == OS_OpenError || + m_openState == OS_Stopped); + LogAssert(m_drainingOpenQueue == false); + LogAssert(m_outstandingBuffers == 0); + LogAssert(m_outstandingIOs == 0); + LogAssert(m_outstandingHandlerSends == 0); + LogAssert(m_outstandingTermination == false); + LogAssert(m_blocking == false); + LogAssert(m_flushing == false); + LogAssert(m_blockedList.IsEmpty()); + LogAssert(m_pendingWriteSet.empty()); + LogAssert(m_pendingUnblockSet.empty()); + LogAssert(m_terminationHandler == NULL); + m_completionItem = NULL; + m_openErrorItem = NULL; + + m_state = S_Closed; + m_openState = OS_Closed; + + //todo: remove comment if not logging +// DrLogD( "Closed", +// "this=%p", this); + } +} + +void RChannelBufferWriterNative:: + FillInStatus(DryadChannelDescription* status) +{ + { + AutoCriticalSection acs (&m_baseDR); + + status->SetChannelTotalLength(0); + status->SetChannelProcessedLength(m_processedLength); + } +} + +/* called with baseDR held */ +void RChannelBufferWriterNative::SetProcessedLength(UInt64 processedLength) +{ + m_processedLength = processedLength; +} + +DryadFixedMemoryBuffer* RChannelBufferWriterNative::GetNextWriteBuffer() +{ + DryadFixedMemoryBuffer* block; + + { + AutoCriticalSection acs (&m_baseDR); + + LogAssert(m_state == 
S_Running); + LogAssert(m_terminationHandler == NULL); + LogAssert(m_outstandingTermination == false); + ++m_outstandingBuffers; + + block = GetNextWriteBufferInternal(); + } + + return block; +} + +void RChannelBufferWriterNative:: + SetLowAndHighWaterMark(UInt32 outstandingWritesLowWatermark, + UInt32 outstandingWritesHighWatermark) +{ + m_lowWatermark = outstandingWritesLowWatermark; + m_highWatermark = outstandingWritesHighWatermark; +} + +DryadFixedMemoryBuffer* + RChannelBufferWriterNative::GetCustomWriteBuffer(Size_t bufferSize) +{ + DryadFixedMemoryBuffer* block; + + { + AutoCriticalSection acs (&m_baseDR); + + LogAssert(m_state == S_Running); + LogAssert(m_terminationHandler == NULL); + LogAssert(m_outstandingTermination == false); + ++m_outstandingBuffers; + + block = GetCustomWriteBufferInternal(bufferSize); + } + + return block; +} + +void RChannelBufferWriterNative:: + ReturnUnusedBuffer(DryadFixedMemoryBuffer* block) +{ + { + AutoCriticalSection acs (&m_baseDR); + + LogAssert(m_state == S_Running); + LogAssert(m_outstandingBuffers > 0); + --m_outstandingBuffers; + + ReturnUnusedBufferInternal(block); + } +} + +/* called with baseDR held */ +bool RChannelBufferWriterNative::AddToWriteQueue(WriteHandler* writeRequest, + WriteHandlerList* writeQueue, + bool extendFile) +{ + bool shouldBlock = false; + + // todo: remove comment if not logging +// DrLogD( "Adding to write queue", +// "blocking %s flushing %s ios %u", +// (m_blocking) ? "true" : "false", (m_flushing) ? 
"true" : "false", +// m_outstandingIOs); + + if (extendFile) + { + m_blockingForFileExtension = true; + } + + if (m_blocking) + { + m_blockedList.InsertAsTail(m_blockedList.CastIn(writeRequest)); + shouldBlock = true; + } + else + { + LogAssert(m_outstandingIOs < m_highWatermark); + LogAssert(m_flushing == false); + + ++m_outstandingIOs; + if (m_outstandingIOs == m_highWatermark || writeRequest->IsFlush() || + m_blockingForFileExtension) + { + //todo: remove comment if not logging +// DrLogE( "Blocking"); + m_blocking = true; + m_flushing = writeRequest->IsFlush(); + shouldBlock = true; + } + + writeQueue->InsertAsTail(writeQueue->CastIn(writeRequest)); + + writeRequest->IncRef(); + std::pair retval; + retval = m_pendingWriteSet.insert(writeRequest); + LogAssert(retval.second == true); + } + + // todo: remove comment if not logging +// DrLogD( "Added to write queue", +// "blocking %s flushing %s shouldBlock %s ios %u", +// (m_blocking) ? "true" : "false", (m_flushing) ? "true" : "false", +// (shouldBlock) ? 
"true" : "false", m_outstandingIOs); + + return shouldBlock; +} + +bool RChannelBufferWriterNative:: + WriteBuffer(DryadFixedMemoryBuffer* buffer, + bool flushAfter, + RChannelBufferWriterHandler* handler) +{ + LogAssert(buffer->GetAvailableSize() <= buffer->GetAllocatedSize()); + + bool shouldBlock = false; + RChannelItemType returnCode = RChannelItem_Data; + WriteHandlerList processList; + + { + AutoCriticalSection acs (&m_baseDR); + + LogAssert(m_state == S_Running); + LogAssert(m_terminationHandler == NULL); + + if (m_completionItem == NULL) + { + bool lazyOpenDone = ((m_openState == OS_Opened || + m_openState == OS_OpenError) && + !m_drainingOpenQueue); + bool extendFile = false; + WriteHandler* writeHandler = + MakeWriteHandler(buffer, flushAfter, handler, lazyOpenDone, + &extendFile); + if (!lazyOpenDone && extendFile) + { + DrLogW("Not extending file since lazyOpen not done"); + extendFile = false; + } + if (extendFile) + { + DrLogW("Considering extending file"); + /* we need to extend the valid length, but if there + are outstanding IOs we'll have to drain them first, + and that process will be initiated by passing + extendFile=true to AddToWriteQueue */ + if (m_outstandingIOs == 0) + { + DrLogW("Extending file since no IOs in flight"); + ExtendFileValidLength(); + extendFile = false; + } + } + shouldBlock = AddToWriteQueue(writeHandler, &processList, + extendFile); + } + else + { + returnCode = m_completionItem->GetType(); + LogAssert(returnCode != RChannelItem_Data); + buffer->DecRef(); + } + + LogAssert(m_outstandingBuffers > 0); + --m_outstandingBuffers; + } + + if (returnCode != RChannelItem_Data) + { + LogAssert(shouldBlock == false); + LogAssert(processList.IsEmpty()); + handler->ProcessWriteCompleted(returnCode); + } + else + { + DrBListEntry* listEntry = processList.GetHead(); + while (listEntry != NULL) + { + WriteHandler* processHandler = processList.CastOut(listEntry); + listEntry = processList.GetNext(listEntry); + 
processList.Remove(processList.CastIn(processHandler)); + processHandler->QueueWrite(m_port); + } + } + + return shouldBlock; +} + +void RChannelBufferWriterNative:: + WriteTermination(RChannelItemType reasonCode, + RChannelBufferWriterHandler* handler) +{ + bool sendImmediately = false; + RChannelItemType statusCode = RChannelItem_Data; + + // Open channel if it was not opened before -- we can get errors here and should report them + if (m_openState == OS_NotOpened) + { + /* this means nothing was written before we were asked to + close, so we never did the lazy open */ + LogAssert(OpenError() == false); + + /* open now and close immediately. This writes a 0-length + output otherwise people downstream that try to read the + output will be sorely disappointed. Ignore throttling since + it makes life so much easier. */ + if (LazyOpenFile()) + { + m_openState = OS_Opened; + } + else + { + m_openState = OS_OpenError; + } + } + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_state == S_Running); + LogAssert(m_terminationHandler == NULL); + LogAssert(m_outstandingTermination == false); + + bool lazyOpenDone = ((m_openState == OS_Opened || + m_openState == OS_OpenError) && + !m_drainingOpenQueue); + bool extendFile = false; + m_terminationHandler = + MakeWriteHandler(NULL, false, handler, lazyOpenDone, &extendFile); + + m_outstandingTermination = true; + + if (m_outstandingHandlerSends == 0 && + m_outstandingIOs == 0) + { + LogAssert(m_blockedList.IsEmpty()); + LogAssert(m_outstandingBuffers == 0); + + sendImmediately = true; + + if (m_completionItem == NULL && m_openErrorItem != NULL) + { + m_completionItem = m_openErrorItem; + } + + if (m_completionItem == NULL) + { + statusCode = RChannelItem_EndOfStream; + } + else + { + statusCode = m_completionItem->GetType(); + } + } + } + + if (sendImmediately) + { + m_terminationHandler->ProcessingComplete(statusCode); + + bool performedClose; + + { + AutoCriticalSection acs(&m_baseDR); + + 
LogAssert(m_outstandingBuffers == 0); + LogAssert(m_blockedList.IsEmpty()); + LogAssert(m_outstandingIOs == 0); + LogAssert(m_outstandingHandlerSends == 0); + LogAssert(m_outstandingTermination == true); + /* hold on to the outstanding termination until after + we've set the event to avoid a race with the drain + thread */ + + if (m_completionItem == NULL) + { + LogAssert(m_openErrorItem == NULL); + + m_completionItem.Attach(RChannelMarkerItem::Create(reasonCode, + true)); + DryadMetaData* metaData = m_completionItem->GetMetaData(); + + DrError errorCode = DrError_OK; + switch (reasonCode) + { + case RChannelItem_EndOfStream: + errorCode = DrError_EndOfStream; + break; + + case RChannelItem_Restart: + errorCode = DryadError_ChannelRestart; + break; + + case RChannelItem_MarshalError: + case RChannelItem_Abort: + errorCode = DryadError_ChannelAbort; + break; + + default: + LogAssert(false); + }; + + metaData->AddErrorWithDescription(errorCode, + "Writer Sent Termination"); + } + + performedClose = FinishUsingFile(); + } + + if (performedClose && m_openThrottler != NULL) + { + /* the file was open previously, and is now closed. 
If + we're being throttled, let the throttler know so it can + open the next file in the queue */ + m_openThrottler->NotifyFileCompleted(); + } + + { + AutoCriticalSection acs(&m_baseDR); + + BOOL bRet = ::SetEvent(m_terminationEvent); + LogAssert(bRet != 0); + + LogAssert(m_outstandingTermination == true); + m_outstandingTermination = false; + } + } +} + +// +// Drain write channel +// +void RChannelBufferWriterNative::Drain(RChannelItemRef* pReturnItem) +{ + bool mustWait = false; + + //todo: remove comment if not logging +// DrLogD( "Draining", +// "this=%p", this); + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_state == S_Running); + + LogAssert(m_terminationHandler != NULL); + if (m_outstandingIOs > 0 || + m_outstandingHandlerSends > 0 || + m_outstandingTermination) + { + // + // If anything outstanding, wait for it to complete + // + mustWait = true; + BOOL bRet = ::ResetEvent(m_terminationEvent); + LogAssert(bRet != 0); + } + + m_state = S_Stopping; + } + + if (mustWait) + { + // + // wait for all writes to complete + // + + // todo: remove comment if not logging +// DrLogD( "Drain: waiting", +// "this=%p", this); + DWORD dRet = ::WaitForSingleObject(m_terminationEvent, INFINITE); + LogAssert(dRet == WAIT_OBJECT_0); + } + + { + AutoCriticalSection acs(&m_baseDR); + + // + // Verify everything is shutdown as expected + // + LogAssert(m_state == S_Stopping); + LogAssert(m_outstandingBuffers == 0); + LogAssert(m_outstandingIOs == 0); + LogAssert(m_outstandingHandlerSends == 0); + LogAssert(m_outstandingTermination == false); + LogAssert(m_blocking == false); + LogAssert(m_flushing == false); + LogAssert(m_terminationHandler != NULL); + LogAssert(m_blockedList.IsEmpty()); + LogAssert(m_openWaitingList.IsEmpty()); + LogAssert(m_pendingWriteSet.empty()); + LogAssert(m_pendingUnblockSet.empty()); + LogAssert(m_completionItem != NULL); + + // + // Clean up resources + // + m_terminationHandler->DecRef(); + m_terminationHandler = NULL; + + *pReturnItem = 
m_completionItem; + + DrainConcreteWriter(); + + m_state = S_Stopped; + +// todo: remove comment if not logging +// DrLogD( "Drained", +// "this=%p", this); + } +} + +/* called with baseDR held */ +void RChannelBufferWriterNative::Unblock(WriteHandlerList* returnList, + WriteHandlerList* processList) +{ +// todo: remove comment if not logging +// DrLogE( "UnBlocking"); + WriteHandlerSet::iterator iter; + for (iter = m_pendingUnblockSet.begin(); iter != m_pendingUnblockSet.end(); + iter = m_pendingUnblockSet.erase(iter)) + { + returnList->InsertAsTail(returnList->CastIn(*iter)); + } + + if (m_blockingForFileExtension) + { + DrLogW("Drained IOs to extend file length"); + ExtendFileValidLength(); + } + + m_blocking = false; + m_flushing = false; + m_blockingForFileExtension = false; + + while (m_blocking == false && m_blockedList.IsEmpty() == false) + { + WriteHandler* nextRequest = + m_blockedList.CastOut(m_blockedList.RemoveHead()); + AddToWriteQueue(nextRequest, processList, false); + } +} + +void RChannelBufferWriterNative:: + ConsumeErrorBuffer(DrError errorCode, + WriteHandlerList* pReturnList, + WriteHandlerList* pProcessList) +{ + if (m_completionItem == NULL) + { + if (m_openErrorItem != NULL) + { + m_completionItem = m_openErrorItem; + } + else + { + m_completionItem.Attach(RChannelMarkerItem:: + CreateErrorItem(RChannelItem_Abort, + errorCode)); + } + + pReturnList->TransitionToTail(&m_blockedList); + Unblock(pReturnList, pProcessList); + LogAssert(pProcessList->IsEmpty()); + } + else + { + LogAssert(m_completionItem->GetType() != RChannelItem_EndOfStream); + LogAssert(m_blocking == false); + LogAssert(m_flushing == false); + } + + LogAssert(m_blockedList.IsEmpty()); + LogAssert(m_pendingUnblockSet.empty()); +} + +RChannelItemType RChannelBufferWriterNative:: + ConsumeBufferCompletion(WriteHandler* writeHandler, + DrError errorCode, + WriteHandlerList* pReturnList, + WriteHandlerList* pProcessList) +{ + RChannelItemType statusCode; + + { + 
AutoCriticalSection acs(&m_baseDR); + +// todo: remove comment if not logging +// DrLogE( "Consuming buffer", +// "state %u sends %u ios %u blocking %s flushing %s", +// m_state, m_outstandingHandlerSends, m_outstandingIOs, +// (m_blocking) ? "true" : "false", (m_flushing) ? "true" : "false"); + + LogAssert(m_state == S_Running || m_state == S_Stopping); + LogAssert(m_outstandingIOs > 0); + --m_outstandingIOs; + + size_t nRemoved = m_pendingWriteSet.erase(writeHandler); + LogAssert(nRemoved == 1); + + if (errorCode != DrError_OK) + { + ConsumeErrorBuffer(errorCode, pReturnList, pProcessList); + } + + if (m_blocking == false) + { + LogAssert(m_flushing == false); + LogAssert(m_blockedList.IsEmpty()); + pReturnList->InsertAsTail(pReturnList->CastIn(writeHandler)); + } + else + { + LogAssert(m_completionItem == NULL); + + std::pair retval; + retval = m_pendingUnblockSet.insert(writeHandler); + LogAssert(retval.second == true); + + bool flushAll = (m_flushing || m_blockingForFileExtension); + if ((flushAll == false && m_outstandingIOs <= m_lowWatermark) || + (flushAll == true && m_outstandingIOs == 0)) + { + Unblock(pReturnList, pProcessList); + } + } + + if (m_outstandingIOs == 0) + { + LogAssert(m_blockedList.IsEmpty()); + } + + if (m_completionItem == NULL || + m_completionItem->GetType() == RChannelItem_EndOfStream) + { + statusCode = RChannelItem_Data; + } + else + { + statusCode = m_completionItem->GetType(); + } + + ++m_outstandingHandlerSends; + } + + return statusCode; +} + +RChannelItemType RChannelBufferWriterNative:: + DealWithReturnHandlers(WriteHandler* writeHandler, DrError errorCode) +{ + RChannelItemType statusCode; + WriteHandlerList returnList; + WriteHandlerList processList; + + statusCode = ConsumeBufferCompletion(writeHandler, errorCode, + &returnList, &processList); + + writeHandler->DecRef(); + writeHandler = NULL; + + DrBListEntry* listEntry = returnList.GetHead(); + while (listEntry != NULL) + { + WriteHandler* handler = 
returnList.CastOut(listEntry); + listEntry = returnList.GetNext(listEntry); + returnList.Remove(returnList.CastIn(handler)); + handler->ProcessingComplete(statusCode); + handler->DecRef(); + } + + listEntry = processList.GetHead(); + while (listEntry != NULL) + { + WriteHandler* handler = processList.CastOut(listEntry); + listEntry = processList.GetNext(listEntry); + processList.Remove(processList.CastIn(handler)); + handler->QueueWrite(m_port); + } + + return statusCode; +} + +void RChannelBufferWriterNative:: + ReceiveBufferInternal(WriteHandler* writeHandler, + DrError errorCode) +{ + /* the following increments m_outstandingHandlerSends */ + DealWithReturnHandlers(writeHandler, errorCode); + + DecrementOutstandingHandlers(); +} + +void RChannelBufferWriterNative::DecrementOutstandingHandlers() +{ + WriteHandler* returnTermination = NULL; + + RChannelItemType statusCode = RChannelItem_Data; + + { + AutoCriticalSection acs(&m_baseDR); + +// todo: remove comment if not logging +// DrLogD( "Decrementing outstanding", +// "ios %u outstandinghandlers %u outstandingtermination %s", +// m_outstandingIOs, m_outstandingHandlerSends, +// (m_outstandingTermination) ? 
"true" : "false"); + + LogAssert(m_outstandingHandlerSends > 0); + --m_outstandingHandlerSends; + + if (m_outstandingIOs == 0 && + m_outstandingHandlerSends == 0 && + m_outstandingTermination == true) + { + LogAssert(m_outstandingBuffers == 0); + LogAssert(m_terminationHandler != NULL); + returnTermination = m_terminationHandler; + + if (m_completionItem == NULL) + { + statusCode = RChannelItem_EndOfStream; + } + else + { + statusCode = m_completionItem->GetType(); + } + } + } + + if (returnTermination != NULL) + { +// todo: remove comment if not logging +// DrLogD( "Calling processing complete", +// "status %s", +// DRERRORSTRING(statusCode)); + + returnTermination->ProcessingComplete(statusCode); + + bool performedClose; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_outstandingBuffers == 0); + LogAssert(m_outstandingIOs == 0); + LogAssert(m_outstandingHandlerSends == 0); + LogAssert(m_blockedList.IsEmpty()); + LogAssert(m_outstandingTermination == true); + /* hold on to the outstanding termination until after + we've set the event to avoid a race with the drain + thread */ + if (m_completionItem == NULL) + { + LogAssert(m_openErrorItem == NULL); + + m_completionItem.Attach(RChannelMarkerItem:: + Create(RChannelItem_EndOfStream, + true)); + } + +// todo: remove comment if not logging +// DrLogD( "Calling finish using file"); + + performedClose = FinishUsingFile(); + } + + if (performedClose && m_openThrottler != NULL) + { + /* the file was open previously, and is now closed. 
If + we're being throttled, let the throttler know so it can + open the next file in the queue */ + m_openThrottler->NotifyFileCompleted(); + } + + { + AutoCriticalSection acs(&m_baseDR); + + //todo: remove comment if not logging +// DrLogD( "Setting event"); + + BOOL bRet = ::SetEvent(m_terminationEvent); + LogAssert(bRet != 0); + + LogAssert(m_outstandingTermination == true); + m_outstandingTermination = false; + } + } +} + + +RChannelBufferWriterNativeFile:: + RChannelBufferWriterNativeFile(UInt32 bufferSize, + size_t bufferAlignment, + UInt32 outstandingWritesLowWatermark, + UInt32 outstandingWritesHighWatermark, + DryadNativePort* port, + RChannelOpenThrottler* openThrottler) : + RChannelBufferWriterNative(outstandingWritesLowWatermark, + outstandingWritesHighWatermark, + port, openThrottler, true) +{ + m_bufferSize = bufferSize; + m_bufferAlignment = bufferAlignment; + LogAssert(m_bufferSize >= m_bufferAlignment); + + m_rawFileHandle = INVALID_HANDLE_VALUE; + m_bufferedFileHandle = INVALID_HANDLE_VALUE; + m_fileNameA = new char[MAX_PATH]; +//JC m_fileNameW = new wchar_t[MAX_PATH]; + m_fileNameLength = 0; +//JC m_wideFileName = false; + + m_tryToCreatePath = false; + + m_nextOffsetToWrite = 0; + m_realignmentSize = 0; + m_fileIsPipe = false; + m_calcFP = false; + m_fp = 0; + +} + +RChannelBufferWriterNativeFile::~RChannelBufferWriterNativeFile() +{ + LogAssert(m_rawFileHandle == INVALID_HANDLE_VALUE); + LogAssert(m_bufferedFileHandle == INVALID_HANDLE_VALUE); + delete [] m_fileNameA; +//JC delete [] m_fileNameW; + if (m_calcFP) + { + Dryad_dupelim_fprint_close(m_fpo); + m_calcFP = false; + } +} + +DrError RChannelBufferWriterNativeFile::SetMetaData(DryadMetaData* metaData) +{ + if (metaData == NULL) + { + return DrError_OK; + } + + if (metaData->LookUpVoidTag(Prop_Dryad_TryToCreateChannelPath) != NULL) + { + m_tryToCreatePath = true; + } + + UInt64 initialSize; + if (metaData->LookUpUInt64(Prop_Dryad_InitialChannelWriteSize, + &initialSize) == DrError_OK) + { 
+ SetInitialSizeHint(initialSize); + } + + return DrError_OK; +} + +DrError RChannelBufferWriterNativeFile::TryToCreatePathA() +{ + DrStr128 fileName(m_fileNameA); + size_t separator = 0; + + if (fileName.StartsWith("\\\\", 2)) + { + /* skip the machine name */ + separator = fileName.IndexOfChar('\\', 2); + if (separator == DrStr_InvalidIndex) + { + return DrErrorFromWin32(ERROR_PATH_NOT_FOUND); + } + /* skip the share name */ + separator = fileName.IndexOfChar('\\', separator+1); + } + else + { + /* skip the drive letter */ + separator = fileName.IndexOfChar('\\', 0); + } + + if (separator == DrStr_InvalidIndex) + { + /* this isn't a fully qualified path so punt */ + return DrErrorFromWin32(ERROR_PATH_NOT_FOUND); + } + + while (separator != DrStr_InvalidIndex) + { + /* find the next path component */ + separator = fileName.IndexOfChar('\\', separator+1); + if (separator != DrStr_InvalidIndex) + { + /* try to create this directory */ + DrStr128 pathName; + pathName.Set(fileName, separator); + BOOL b = ::CreateDirectoryA(pathName, NULL); + if (!b) + { + DrError err = DrGetLastError(); + if (err != DrErrorFromWin32(ERROR_ALREADY_EXISTS)) + { + return err; + } + } + } + } + + return DrError_OK; +} + +//JC +#if 0 +DrError RChannelBufferWriterNativeFile::TryToCreatePathW() +{ + DrWStr128 fileName(m_fileNameW); + size_t separator = 0; + + if (fileName.StartsWith(L"\\\\", 2)) + { + /* skip the machine name */ + separator = fileName.IndexOfChar(L'\\', 2); + if (separator == DrStr_InvalidIndex) + { + return DrErrorFromWin32(ERROR_PATH_NOT_FOUND); + } + /* skip the share name */ + separator = fileName.IndexOfChar(L'\\', separator+1); + } + else + { + /* skip the drive letter */ + separator = fileName.IndexOfChar(L'\\', 0); + } + + if (separator == DrStr_InvalidIndex) + { + /* this isn't a fully qualified path so punt */ + return DrErrorFromWin32(ERROR_PATH_NOT_FOUND); + } + + while (separator != DrStr_InvalidIndex) + { + /* find the next path component */ + separator = 
fileName.IndexOfChar(L'\\', separator+1); + if (separator != DrStr_InvalidIndex) + { + /* try to create this directory */ + DrWStr128 pathName; + pathName.Set(fileName, separator); + BOOL b = ::CreateDirectoryW(pathName, NULL); + if (!b) + { + DrError err = DrGetLastError(); + if (err != DrErrorFromWin32(ERROR_ALREADY_EXISTS)) + { + return err; + } + } + } + } + + return DrError_OK; +} +#endif + +HANDLE RChannelBufferWriterNativeFile::CreateFileAndPath(DrError* pErr, SECURITY_ATTRIBUTES *sa) +{ + HANDLE h; + DrError err = DrError_OK; + +/* JC + if (m_wideFileName) + { + h = ::CreateFileW(m_fileNameW, + GENERIC_WRITE, + FILE_SHARE_WRITE | FILE_SHARE_READ, + NULL, + CREATE_ALWAYS, + FILE_FLAG_OVERLAPPED, + NULL); + } + else + {*/ + h = ::CreateFileA(m_fileNameA, + GENERIC_WRITE, + FILE_SHARE_WRITE | FILE_SHARE_READ, + sa, + CREATE_ALWAYS, + FILE_FLAG_OVERLAPPED, + NULL); +//JC } + + if (h == INVALID_HANDLE_VALUE) + { + err = DrGetLastError(); + if (err == DrErrorFromWin32(ERROR_PATH_NOT_FOUND) && m_tryToCreatePath) + { +/*JC if (m_wideFileName) + { + err = TryToCreatePathW(); + } + else + {*/ + err = TryToCreatePathA(); +//JC } + + if (err == DrError_OK) + { +/*JC if (m_wideFileName) + { + h = ::CreateFileW(m_fileNameW, + GENERIC_WRITE, + FILE_SHARE_WRITE | FILE_SHARE_READ, + NULL, + CREATE_ALWAYS, + FILE_FLAG_OVERLAPPED, + NULL); + } + else + {*/ + h = ::CreateFileA(m_fileNameA, + GENERIC_WRITE, + FILE_SHARE_WRITE | FILE_SHARE_READ, + sa, + CREATE_ALWAYS, + FILE_FLAG_OVERLAPPED, + NULL); +//JC } + + if (h == INVALID_HANDLE_VALUE) + { + err = DrGetLastError(); + } + } + } + } + + *pErr = err; + return h; +} + +// +// Helper to amplify privileges +// +static +BOOL +SetCurrentPrivilege ( + IN LPCTSTR Privilege, // Privilege to enable/disable + IN OUT BOOL *bEnablePrivilege // to enable or disable privilege + ) +/* + + If successful, *bEnablePrivlege is set to the new state. 
+ If NOT successful, bEnablePrivlege is invalid + + Returns: + TRUE - success + FALSE - failure + */ +{ + HANDLE hToken; + TOKEN_PRIVILEGES tp; + LUID luid; + TOKEN_PRIVILEGES tpPrevious; + DWORD cbPrevious = sizeof(TOKEN_PRIVILEGES); + BOOL bSuccess; + BOOL bEnableIt; + + bEnableIt = *bEnablePrivilege; + + if (!LookupPrivilegeValue(NULL, Privilege, &luid)) { + return FALSE; + } + + if(!OpenProcessToken( + GetCurrentProcess(), + TOKEN_QUERY | TOKEN_ADJUST_PRIVILEGES, + &hToken + )) { + return FALSE; + } + + // + // first pass. get current privilege setting + // + tp.PrivilegeCount = 1; + tp.Privileges[0].Luid = luid; + tp.Privileges[0].Attributes = 0; + + AdjustTokenPrivileges( + hToken, + FALSE, + &tp, + sizeof(TOKEN_PRIVILEGES), + &tpPrevious, + &cbPrevious + ); + + bSuccess = FALSE; + + if(GetLastError() == ERROR_SUCCESS) { + // + // second pass. set privilege based on previous setting + // + tpPrevious.PrivilegeCount = 1; + tpPrevious.Privileges[0].Luid = luid; + + *bEnablePrivilege = tpPrevious.Privileges[0].Attributes | (SE_PRIVILEGE_ENABLED); + + if(bEnableIt) { + tpPrevious.Privileges[0].Attributes |= (SE_PRIVILEGE_ENABLED); + } + else { + tpPrevious.Privileges[0].Attributes ^= (SE_PRIVILEGE_ENABLED & + tpPrevious.Privileges[0].Attributes); + } + + AdjustTokenPrivileges( + hToken, + FALSE, + &tpPrevious, + cbPrevious, + NULL, + NULL + ); + + if (GetLastError() == ERROR_SUCCESS) { + bSuccess=TRUE; + } + } + + CloseHandle(hToken); + + return bSuccess; +} + +static void SetInitialFileLength(HANDLE fp, UInt64 initialLength, + bool setPrivilegeForValidLength) +{ + LARGE_INTEGER fPointer; + fPointer.QuadPart = initialLength; + DWORD status = SetFilePointerEx(fp, fPointer, NULL, FILE_BEGIN); + if (status == INVALID_SET_FILE_POINTER) + { + DrLogW( "Failed to set initial length. Length %I64u error %s", initialLength, + DRERRORSTRING(DrGetLastError())); + return; + } + else + { + DrLogI( "Set initial length. 
Length %I64u", initialLength); + } + + BOOL ok = SetEndOfFile(fp); + if (!ok) + { + DrLogW( "Failed to set end of file. Length %I64u error %s", initialLength, + DRERRORSTRING(DrGetLastError())); + } + else + { + DrLogI( "Set end of file. Length %I64u", initialLength); + } + + if (setPrivilegeForValidLength) + { + ok = SetFileValidData(fp, initialLength); + if (!ok) + { + DrLogW( "Failed to set file valid data. Length %I64u error %s", initialLength, + DRERRORSTRING(DrGetLastError())); + } + else + { + DrLogI( "Set file valid data. Length %I64u", initialLength); + } + } + else + { + DrLogW("Not setting initial valid length: didn't get privileges"); + } + + status = SetFilePointer(fp, 0, 0, FILE_BEGIN); + if (status == INVALID_SET_FILE_POINTER) + { + /* assert here, since we really do need to start writing from + byte zero... */ + DrLogA( "Failed to reset initial length. Length %I64u error %s", initialLength, + DRERRORSTRING(DrGetLastError())); + } + else + { + DrLogI( "Reset file pointer"); + } +} + +bool RChannelBufferWriterNativeFile::TryToSetPrivilege() +{ + { + AutoCriticalSection acs(&s_privilegeDR); + + if (s_triedToSetPrivilege == false) + { + LogAssert(s_setPrivilege == false); + s_triedToSetPrivilege = true; + + BOOL bEnabled = TRUE; + if (SetCurrentPrivilege(SE_MANAGE_VOLUME_NAME, &bEnabled)) + { + s_setPrivilege = true; + DrLogI("Set SE_MANAGE_VOLUME_NAME privilege"); + } + else + { + DrLogI("Failed to set SE_MANAGE_VOLUME_NAME privilege"); + } + } + + return s_setPrivilege; + } +} + +// +// Release the security descriptor if it has been created +// +void CleanUpSecurityDescriptor(SECURITY_ATTRIBUTES* sa) +{ + if (NULL != sa) + { + if(NULL != sa->lpSecurityDescriptor) + { + HeapFree(GetProcessHeap(), HEAP_ZERO_MEMORY, sa->lpSecurityDescriptor); + sa->lpSecurityDescriptor = NULL; + } + } +} + +// +// Build a security descriptor for the job owner +// +DrError GenerateSecurityDescriptor(SECURITY_ATTRIBUTES *sa) +{ + PSID pHpcReplicationSID = NULL, pAdminSID = 
NULL, pOwnerSID = NULL, pRunAsSID = NULL; + PACL pACL = NULL; + PSECURITY_DESCRIPTOR pSD = NULL; + sa->lpSecurityDescriptor = pSD; + int accessCount = 4; + EXPLICIT_ACCESS ea[4]; + SID_IDENTIFIER_AUTHORITY SIDAuthNT = SECURITY_NT_AUTHORITY; + + + DrError e = DrError_Fail; + + // Prep Explicit Access Structure + ZeroMemory(&ea, accessCount * sizeof(EXPLICIT_ACCESS)); + + // Get the SID for runas user + e = DrGetSidForUser(L"HpcReplication", &pHpcReplicationSID); + if(e != DrError_OK) + { + DrLogE("Unable to get SID for HpcReplication user: %u", GetLastError()); + goto Cleanup; + } + + // If successful, provide read access to HpcReplication user + ea[0].grfAccessPermissions = GENERIC_READ; + ea[0].grfAccessMode = SET_ACCESS; + ea[0].grfInheritance= NO_INHERITANCE; + ZeroMemory(&ea[0].Trustee, sizeof(TRUSTEE)); + ea[0].Trustee.TrusteeForm = TRUSTEE_IS_SID; + ea[0].Trustee.TrusteeType = TRUSTEE_IS_GROUP; + ea[0].Trustee.ptstrName = (LPTSTR) pHpcReplicationSID; + + // Create a SID for the administrators group. + if(! 
AllocateAndInitializeSid(&SIDAuthNT, 2, + SECURITY_BUILTIN_DOMAIN_RID, + DOMAIN_ALIAS_RID_ADMINS, + 0, 0, 0, 0, 0, 0, + &pAdminSID)) + { + DrLogE("Unable to allocate SID for Administrators group: %u", GetLastError()); + goto Cleanup; + } + + // If successful, provide full control to administrators group + ea[1].grfAccessPermissions = GENERIC_ALL; + ea[1].grfAccessMode = SET_ACCESS; + ea[1].grfInheritance= NO_INHERITANCE; + ZeroMemory(&ea[1].Trustee, sizeof(TRUSTEE)); + ea[1].Trustee.TrusteeForm = TRUSTEE_IS_SID; + ea[1].Trustee.TrusteeType = TRUSTEE_IS_GROUP; + ea[1].Trustee.ptstrName = (LPTSTR) pAdminSID; + + // Get username of job runas user from the environment + WCHAR userName[MAX_PATH+1] = {0}; + HRESULT hr = DrGetEnvironmentVariable(L"USERNAME",userName); + if(hr != 0) + { + DrLogE("Unable to get job runas user name: 0x%08x", hr); + goto Cleanup; + } + + // Get domain of job runas user from the environment + WCHAR domain[MAX_PATH+1] = {0}; + hr = DrGetEnvironmentVariable(L"USERDOMAIN",domain); + if(hr != 0) + { + DrLogE("Unable to get job runas user domain: 0x%08x", hr); + goto Cleanup; + } + + // Build fully qualified user name in the form domain\username + WCHAR domainUser[MAX_PATH] = {0}; + wsprintf(domainUser, L"%s\\%s", domain, userName); + + // Get the SID for runas user + e = DrGetSidForUser(domainUser, &pRunAsSID); + if(e != DrError_OK) + { + DrLogE("Unable to get SID for job runas user, %ls.", domainUser); + goto Cleanup; + } + + // If successful, give runas user full permissions + ea[2].grfAccessPermissions = GENERIC_ALL; + ea[2].grfAccessMode = SET_ACCESS; + ea[2].grfInheritance= NO_INHERITANCE; + ZeroMemory(&ea[2].Trustee, sizeof(TRUSTEE)); + ea[2].Trustee.TrusteeForm = TRUSTEE_IS_SID; + ea[2].Trustee.TrusteeType = TRUSTEE_IS_USER; + ea[2].Trustee.ptstrName = (LPTSTR) pRunAsSID; + + // Get SID of job owner from the environment + WCHAR sidEnv[MAX_PATH+1] = {0}; + hr = DrGetEnvironmentVariable(L"CCP_OWNER_SID",sidEnv); + if(hr != 0) + { + 
DrLogW("Unable to get job owner SID: %ld.", hr); + accessCount = 3; + } + else + { + // Converts SID string into functional SID + if(!ConvertStringSidToSidW(sidEnv, &pOwnerSID)) + { + DrLogW("Unable to convert job owner SID into functional SID: %u", GetLastError()); + accessCount = 3; + } + else + { + // If successful, give full control to job owner + ea[3].grfAccessPermissions = GENERIC_ALL; + ea[3].grfAccessMode = SET_ACCESS; + ea[3].grfInheritance= NO_INHERITANCE; + ZeroMemory(&ea[3].Trustee, sizeof(TRUSTEE)); + ea[3].Trustee.TrusteeForm = TRUSTEE_IS_SID; + ea[3].Trustee.TrusteeType = TRUSTEE_IS_USER; + ea[3].Trustee.ptstrName = (LPTSTR) pOwnerSID; + } + } + + // Create a new ACL that contains all the entries + hr = SetEntriesInAcl(accessCount, ea, NULL, &pACL); + if (ERROR_SUCCESS != hr) + { + DrLogE("Unable to create new access control list: %u", GetLastError()); + goto Cleanup; + } + + // Initialize a security descriptor. + pSD = (PSECURITY_DESCRIPTOR) HeapAlloc( + GetProcessHeap(), + HEAP_ZERO_MEMORY, + SECURITY_DESCRIPTOR_MIN_LENGTH); + if (NULL == pSD) + { + printf("LocalAlloc Error %u\n", GetLastError()); + goto Cleanup; + } + + if (!InitializeSecurityDescriptor(pSD, SECURITY_DESCRIPTOR_REVISION)) + { + DrLogE("Unable to initialize a Security Descriptor: Error %u", GetLastError()); + goto Cleanup; + } + + // Add the ACL to the security descriptor. 
+ if (!SetSecurityDescriptorDacl(pSD, TRUE, pACL, FALSE)) + { + DrLogE("Unable to set the DACL in the Security Descriptor: Error %u", GetLastError()); + goto Cleanup; + } + + // Initialize the security attributes structure with created security descriptor + sa->nLength = sizeof (SECURITY_ATTRIBUTES); + sa->lpSecurityDescriptor = pSD; + sa->bInheritHandle = FALSE; + + DrLogI("Successfully generated security descriptor"); + // Successful if everything worked as planned + return DrError_OK; + +Cleanup: + // Free security descriptor pointer + sa->lpSecurityDescriptor = pSD; + CleanUpSecurityDescriptor(sa); + + return DrError_Fail; +} + +// YARN Skip security descriptor version for now but leave the code in place in case we want it in the future +#if 0 +bool RChannelBufferWriterNativeFile::LazyOpenFile() +{ + HANDLE hBuffered; + SECURITY_ATTRIBUTES sa; + + DrLogI( "Opening native file. Filename %s", m_fileNameA); + + /* always try to extend the length of the file even if we didn't + get a length hint, since it helps write speeds so much */ + if (m_fileIsPipe) + { + m_canExtendFileLength = false; + } + else + { + m_canExtendFileLength = TryToSetPrivilege(); + } + + + + // Generate the security descriptor for output files + // YARN - Ignore the security descriptor for now, + //but leave the code in place in case we want it in the future + DrError err = GenerateSecurityDescriptor(&sa); + if(DrError_OK == err) + { + // If successful, create the file + hBuffered = CreateFileAndPath(&err, &sa); + + if (hBuffered == INVALID_HANDLE_VALUE) + { + // If the file cannot be opened, report failure + DrLogI( "Buffered native file open failed. 
Filename %s", m_fileNameA); + + RChannelItemRef errorItem; + DrStr64 description; + description.SetF("Can't open buffered native file '%s' to write", + m_fileNameA); + errorItem.Attach(RChannelMarkerItem:: + CreateErrorItemWithDescription(RChannelItem_Abort, + err, + description)); + SetOpenErrorItem(errorItem); + } + else + { + LogAssert(err == DrError_OK); + + DrLogI( "Buffered native file open succeeded. Filename %s", m_fileNameA); + + HANDLE h = INVALID_HANDLE_VALUE; + if (!m_fileIsPipe) + { + /*JC if (m_wideFileName) + { + h = ::CreateFileW(m_fileNameW, + GENERIC_WRITE, + FILE_SHARE_WRITE | FILE_SHARE_READ, + NULL, + OPEN_EXISTING, + FILE_FLAG_NO_BUFFERING | + FILE_FLAG_OVERLAPPED, + NULL); + } + else + {*/ + h = ::CreateFileA(m_fileNameA, + GENERIC_WRITE, + FILE_SHARE_WRITE | FILE_SHARE_READ, + &sa, + OPEN_EXISTING, + FILE_FLAG_NO_BUFFERING | + FILE_FLAG_OVERLAPPED, + NULL); + //JC } + + if (h == INVALID_HANDLE_VALUE) + { + // If file cannot be opened, report failure + DrLogI( "Native file open failed. Filename %s", m_fileNameA); + + RChannelItemRef errorItem; + DrError err = DrGetLastError(); + DrStr64 description; + description.SetF("Can't open native file '%s' to write", + m_fileNameA); + + errorItem.Attach(RChannelMarkerItem:: + CreateErrorItemWithDescription(RChannelItem_Abort, + err, + description)); + SetOpenErrorItem(errorItem); + + BOOL bRet = ::CloseHandle(hBuffered); + LogAssert(bRet != 0); + hBuffered = INVALID_HANDLE_VALUE; + } + else + { + DrLogI( + "Native file open succeeded. 
Filename %s, rawHandle=%u, bufferedHandle=%u", + m_fileNameA, + h, hBuffered); + } + } + + if (hBuffered != INVALID_HANDLE_VALUE) + { + m_rawFileHandle = h; + m_bufferedFileHandle = hBuffered; + + if (h != INVALID_HANDLE_VALUE) + { + GetPort()->AssociateHandle(h); + } + GetPort()->AssociateHandle(hBuffered); + + m_fileLengthSet = GetInitialSizeHint(); + if (m_fileLengthSet == 0 && m_canExtendFileLength) + { + DrLogW("No initial size hint: extending to 0x%I64x", + m_fileLengthSet); + m_fileLengthSet = s_fileExtendChunk; + } + if (m_fileLengthSet != 0) + { + LogAssert(!m_fileIsPipe); + + DrLogW("Setting initial size hint: 0x%I64x", + m_fileLengthSet); + SetInitialFileLength(h, m_fileLengthSet, + m_canExtendFileLength); + } + + // Release the security descriptor + CleanUpSecurityDescriptor(&sa); + + return true; + } + } + } + else + { + // Failed to build security descriptor. Report failure + RChannelItemRef errorItem; + DrStr64 description; + description.SetF("Can't create security descriptor for %s", m_fileNameA); + errorItem.Attach(RChannelMarkerItem::CreateErrorItemWithDescription(RChannelItem_Abort, err, description)); + SetOpenErrorItem(errorItem); + } + + + // Release the security descriptor + CleanUpSecurityDescriptor(&sa); + + // we get here if the open failed + LogAssert(OpenError()); + + return false; +} +#endif + +bool RChannelBufferWriterNativeFile::LazyOpenFile() +{ + HANDLE hBuffered; + DrLogI( "Opening native file. Filename %s", m_fileNameA); + + /* always try to extend the length of the file even if we didn't + get a length hint, since it helps write speeds so much */ + if (m_fileIsPipe) + { + m_canExtendFileLength = false; + } + else + { + m_canExtendFileLength = TryToSetPrivilege(); + } + + DrError err = DrError_OK; + // If successful, create the file + hBuffered = CreateFileAndPath(&err, NULL); + + if (hBuffered == INVALID_HANDLE_VALUE) + { + // If the file cannot be opened, report failure + DrLogI( "Buffered native file open failed. 
Filename %s", m_fileNameA); + + RChannelItemRef errorItem; + DrStr64 description; + description.SetF("Can't open buffered native file '%s' to write", + m_fileNameA); + errorItem.Attach(RChannelMarkerItem:: + CreateErrorItemWithDescription(RChannelItem_Abort, + err, + description)); + SetOpenErrorItem(errorItem); + } + else + { + LogAssert(err == DrError_OK); + + DrLogI( "Buffered native file open succeeded. Filename %s", m_fileNameA); + + HANDLE h = INVALID_HANDLE_VALUE; + if (!m_fileIsPipe) + { + /*JC if (m_wideFileName) + { + h = ::CreateFileW(m_fileNameW, + GENERIC_WRITE, + FILE_SHARE_WRITE | FILE_SHARE_READ, + NULL, + OPEN_EXISTING, + FILE_FLAG_NO_BUFFERING | + FILE_FLAG_OVERLAPPED, + NULL); + } + else + {*/ + h = ::CreateFileA(m_fileNameA, + GENERIC_WRITE, + FILE_SHARE_WRITE | FILE_SHARE_READ, + NULL, + OPEN_EXISTING, + FILE_FLAG_NO_BUFFERING | + FILE_FLAG_OVERLAPPED, + NULL); + //JC } + + if (h == INVALID_HANDLE_VALUE) + { + // If file cannot be opened, report failure + DrLogI( "Native file open failed. Filename %s", m_fileNameA); + + RChannelItemRef errorItem; + DrError err = DrGetLastError(); + DrStr64 description; + description.SetF("Can't open native file '%s' to write", + m_fileNameA); + + errorItem.Attach(RChannelMarkerItem:: + CreateErrorItemWithDescription(RChannelItem_Abort, + err, + description)); + SetOpenErrorItem(errorItem); + + BOOL bRet = ::CloseHandle(hBuffered); + LogAssert(bRet != 0); + hBuffered = INVALID_HANDLE_VALUE; + } + else + { + DrLogI( + "Native file open succeeded. 
Filename %s, rawHandle=%u, bufferedHandle=%u", + m_fileNameA, + h, hBuffered); + } + } + + if (hBuffered != INVALID_HANDLE_VALUE) + { + m_rawFileHandle = h; + m_bufferedFileHandle = hBuffered; + + if (h != INVALID_HANDLE_VALUE) + { + GetPort()->AssociateHandle(h); + } + GetPort()->AssociateHandle(hBuffered); + + m_fileLengthSet = GetInitialSizeHint(); + if (m_fileLengthSet == 0 && m_canExtendFileLength) + { + DrLogW("No initial size hint: extending to 0x%I64x", + m_fileLengthSet); + m_fileLengthSet = s_fileExtendChunk; + } + if (m_fileLengthSet != 0) + { + LogAssert(!m_fileIsPipe); + + DrLogW("Setting initial size hint: 0x%I64x", + m_fileLengthSet); + SetInitialFileLength(h, m_fileLengthSet, + m_canExtendFileLength); + } + + return true; + } + } + // we get here if the open failed + LogAssert(OpenError()); + + return false; +} + +void RChannelBufferWriterNativeFile::EagerCloseFile() +{ + LogAssert(m_bufferedFileHandle != INVALID_HANDLE_VALUE); + + DrLogI( "Closing native file. File %s", m_fileNameA); + + m_fpDataLength = m_nextOffsetToWrite; + + if (!m_fileIsPipe) + { + LARGE_INTEGER finalLength; + finalLength.QuadPart = m_nextOffsetToWrite; + BOOL ok = SetFilePointerEx(m_bufferedFileHandle, finalLength, + NULL, FILE_BEGIN); + if (!ok) + { + DrLogA( + "Couldn't set file pointer to end. Length %I64u Error: %s", + finalLength.QuadPart, DRERRORSTRING(DrGetLastError())); + } + else + { + DrLogI( "Set final file pointer"); + } + + ok = SetEndOfFile(m_bufferedFileHandle); + if (!ok) + { + DrLogA( "Couldn't truncate file. 
Error: %s", DRERRORSTRING(DrGetLastError()));
+        }
+        else
+        {
+            DrLogI( "Truncated file");
+        }
+    }
+
+    BOOL bRet;
+
+    /* a raw handle is expected exactly when the target is not a pipe;
+       close it first, then the buffered handle, and mark both members
+       as invalid again */
+    if (m_rawFileHandle == INVALID_HANDLE_VALUE)
+    {
+        LogAssert(m_fileIsPipe);
+    }
+    else
+    {
+        LogAssert(!m_fileIsPipe);
+        bRet = ::CloseHandle(m_rawFileHandle);
+        LogAssert(bRet != 0);
+        m_rawFileHandle = INVALID_HANDLE_VALUE;
+    }
+
+    bRet = ::CloseHandle(m_bufferedFileHandle);
+    LogAssert(bRet != 0);
+    m_bufferedFileHandle = INVALID_HANDLE_VALUE;
+}
+
+/* Records pathName as the file to write and calls OpenInternal(), all
+   while holding the base lock. The writer must not be open yet and
+   nothing may have been written (asserted). pathName is copied into
+   m_fileNameA with a MAX_PATH limit; a failed copy trips LogAssert.
+   When the name is a named pipe the writer switches to pipe settings.
+   Always returns true. */
+bool RChannelBufferWriterNativeFile::OpenA(const char* pathName)
+{
+    {
+        AutoCriticalSection acs(GetBaseDR());
+        /* NOTE(review): mappedPath is only referenced by the disabled
+           DRYADONLY call below */
+        DrStr128 mappedPath;
+
+        LogAssert(m_rawFileHandle == INVALID_HANDLE_VALUE);
+        LogAssert(m_bufferedFileHandle == INVALID_HANDLE_VALUE);
+        LogAssert(m_nextOffsetToWrite == 0);
+
+        /* DRYADONLY DrNetworkToLocal(0, pathName, mappedPath); */
+        HRESULT hr = ::StringCbCopyA(m_fileNameA, MAX_PATH, pathName);
+        LogAssert(SUCCEEDED(hr));
+//JC        m_wideFileName = false;
+        hr = ::StringCbLengthA(m_fileNameA, MAX_PATH, &m_fileNameLength);
+        LogAssert(SUCCEEDED(hr));
+
+        if (ConcreteRChannel::IsNamedPipe(m_fileNameA))
+        {
+            /* for pipes cap the buffer at 64 blocks of 1024 bytes and
+               use 1024 as the alignment */
+            UInt32 bufferAlignment = 1024;
+            UInt32 buffSize = 64*bufferAlignment;
+            if (m_bufferSize > buffSize)
+            {
+                m_bufferSize = buffSize;
+                m_bufferAlignment = bufferAlignment;
+                DrLogI(
+                    "Reduced output buffer size for pipe. 
Size now %u", m_bufferSize);
+            }
+            SetLowAndHighWaterMark(0,1); // make the pipe writes sequential
+            m_fileIsPipe = true;
+        }
+
+        OpenInternal();
+    }
+
+    return true;
+}
+
+/* JC
+bool RChannelBufferWriterNativeFile::OpenW(const wchar_t* pathName)
+{
+    {
+        AutoCriticalSection acs(GetBaseDR());
+        DrStr128 strPathName;
+
+        LogAssert(m_rawFileHandle == INVALID_HANDLE_VALUE);
+        LogAssert(m_bufferedFileHandle == INVALID_HANDLE_VALUE);
+        LogAssert(m_nextOffsetToWrite == 0);
+
+        // DRYADONLY DrNetworkToLocal(0, DRWSTRINGTOUTF8(pathName), strPathName);
+        HRESULT hr = ::StringCbCopyW(m_fileNameW, MAX_PATH, DRUTF8TOWSTRING(strPathName));
+        LogAssert(SUCCEEDED(hr));
+        m_wideFileName = true;
+
+        LogAssert(strPathName.GetString() != NULL);
+        LogAssert(strPathName.GetLength() < MAX_PATH-1);
+        hr = ::StringCbCopyA(m_fileNameA, MAX_PATH, strPathName.GetString());
+        LogAssert(SUCCEEDED(hr));
+        hr = ::StringCbLengthA(m_fileNameA, MAX_PATH, &m_fileNameLength);
+        LogAssert(SUCCEEDED(hr));
+
+        if (ConcreteRChannel::IsNamedPipe(m_fileNameA))
+        {
+            UInt32 bufferAlignment = 1024;
+            UInt32 buffSize = 64*bufferAlignment;
+            if (m_bufferSize > buffSize)
+            {
+                m_bufferSize = buffSize;
+                m_bufferAlignment = bufferAlignment;
+                DrLogI(
+                    "Reduced output buffer size for pipe",
+                    "Size now %u", m_bufferSize);
+            }
+            SetLowAndHighWaterMark(0,1); // make the pipe writes sequential
+            m_fileIsPipe = true;
+        }
+
+        OpenInternal();
+    }
+
+    return true;
+}
+*/
+
+/* called with baseDR held.
+   Start hook of the file writer. No realignment may be pending
+   (asserted). If *pCompletionItem is an end-of-stream item, it is
+   replaced by an abort item carrying DryadError_ChannelRestartError;
+   otherwise it is left alone. */
+void RChannelBufferWriterNativeFile::
+    StartConcreteWriter(RChannelItemRef* pCompletionItem)
+{
+    LogAssert(m_realignmentSize == 0);
+    if (*pCompletionItem != NULL &&
+        (*pCompletionItem)->GetType() == RChannelItem_EndOfStream)
+    {
+        RChannelItem* error =
+            RChannelMarkerItem::
+            CreateErrorItemWithDescription(RChannelItem_Abort,
+                                           DryadError_ChannelRestartError,
+                                           "Can't restart channel after "
+                                           "sending EOF");
+        pCompletionItem->Attach(error);
+    }
+}
+
+/* called with baseDR held */
+void 
RChannelBufferWriterNativeFile::DrainConcreteWriter()
+{
+    /* forget the write position and any pending realignment */
+    m_realignmentSize = 0;
+    m_nextOffsetToWrite = 0;
+}
+
+/* called with baseDR held.
+   Returns the next buffer to fill: a full m_bufferSize buffer when no
+   realignment is pending, otherwise a buffer of exactly
+   m_realignmentSize bytes. */
+DryadFixedMemoryBuffer* RChannelBufferWriterNativeFile::
+    GetNextWriteBufferInternal()
+{
+    if (m_realignmentSize == 0)
+    {
+        return GetCustomWriteBufferInternal(m_bufferSize);
+    }
+    else
+    {
+        return GetCustomWriteBufferInternal(m_realignmentSize);
+    }
+}
+
+/* called with baseDR held.
+   Allocates a write block of bufferSize bytes and updates
+   m_realignmentSize for the buffers that follow. The block is created
+   with m_bufferAlignment only when no realignment is pending and
+   bufferSize itself has no alignment gap; otherwise it is created with
+   an alignment argument of 0. */
+DryadFixedMemoryBuffer* RChannelBufferWriterNativeFile::
+    GetCustomWriteBufferInternal(Size_t bufferSize)
+{
+    DryadFixedMemoryBuffer* buffer;
+
+    if (m_realignmentSize == 0 && AlignmentGap(bufferSize) == 0)
+    {
+        buffer =
+            new DryadAlignedWriteBlock(bufferSize, m_bufferAlignment);
+    }
+    else
+    {
+        LogAssert(m_realignmentSize < m_bufferSize);
+        buffer = new DryadAlignedWriteBlock(bufferSize, 0);
+    }
+
+    if (bufferSize <= m_realignmentSize)
+    {
+        /* this buffer only uses up part (or all) of the pending
+           realignment */
+        m_realignmentSize -= (UInt32) bufferSize;
+    }
+    else
+    {
+        /* figure out new realignment offset. First re-base bufferSize
+           to the start of an alignment block */
+        bufferSize -= m_realignmentSize;
+        /* then figure out the unaligned overhang we wrote */
+        UInt32 downAlignment = AlignmentGap(bufferSize);
+        if (downAlignment != 0)
+        {
+            /* and if there is any, adjust to get back to the regular
+               buffer size later */
+            LogAssert(downAlignment < m_bufferSize);
+            m_realignmentSize = m_bufferSize - downAlignment;
+        }
+        /* NOTE(review): when downAlignment is 0 the previous
+           m_realignmentSize is kept as it is - confirm this is
+           intended */
+    }
+
+    return buffer;
+}
+
+/* called with baseDR held.
+   Drops the reference held on a buffer that was handed out but not
+   written. */
+void RChannelBufferWriterNativeFile::
+    ReturnUnusedBufferInternal(DryadFixedMemoryBuffer* block)
+{
+    block->DecRef();
+}
+
+/* Forwards a finished write and its error code to
+   ReceiveBufferInternal(). */
+void RChannelBufferWriterNativeFile::ReceiveBuffer(WriteHandler* writeHandler,
+                                                   DrError errorCode)
+{
+    ReceiveBufferInternal(writeHandler, errorCode);
+}
+
+/* True when the bits of offset selected by (m_bufferAlignment - 1) are
+   all zero. This means "multiple of m_bufferAlignment" only if the
+   alignment is a power of two - presumably it always is, TODO
+   confirm. */
+bool RChannelBufferWriterNativeFile::IsAligned(UInt64 offset)
+{
+    return ((offset & ((UInt64) m_bufferAlignment - 1)) == 0);
+}
+
+/* Returns the bits of offset selected by (m_bufferAlignment - 1), i.e.
+   the distance past the previous alignment boundary under the same
+   power-of-two assumption as IsAligned(). */
+UInt32 RChannelBufferWriterNativeFile::AlignmentGap(UInt64 offset)
+{
+    UInt64 gap = (offset & ((UInt64) 
(m_bufferAlignment - 1)));
+    LogAssert(gap < 0x100000000);
+    return (UInt32) gap;
+}
+
+/* called with BaseDR set.
+   Grows the recorded file length by one s_fileExtendChunk and applies
+   it to the raw handle through SetInitialFileLength(). */
+void RChannelBufferWriterNativeFile::ExtendFileValidLength()
+{
+    m_fileLengthSet += s_fileExtendChunk;
+    SetInitialFileLength(m_rawFileHandle, m_fileLengthSet,
+                         m_canExtendFileLength);
+}
+
+/* called with baseDR held.
+   Creates the FileWriteHandler for one write at the current offset and
+   advances the writer's position.
+     block        - data to write; NULL makes a zero-length handler with
+                    no file handle
+     flushAfter   - passed through to the handler
+     handler      - passed through to the handler
+     lazyOpenDone - passed to the handler as its "details present" flag
+     extendFile   - set to true when the new offset reaches
+                    m_fileLengthSet on an extendable non-pipe file; it
+                    is never written otherwise, so the caller has to
+                    initialise it
+   Returns the new handler. */
+RChannelBufferWriterNative::WriteHandler* RChannelBufferWriterNativeFile::
+    MakeWriteHandler(DryadFixedMemoryBuffer* block,
+                     bool flushAfter,
+                     RChannelBufferWriterHandler* handler,
+                     bool lazyOpenDone,
+                     bool* extendFile)
+{
+    UInt32 writeLength;
+    HANDLE handleToUse;
+    bool isBuffered = false;
+
+    if (block == NULL)
+    {
+        writeLength = 0;
+        handleToUse = INVALID_HANDLE_VALUE;
+    }
+    else
+    {
+        Size_t availableLength = block->GetAvailableSize();
+        LogAssert(availableLength < 0x100000000);
+        writeLength = (UInt32) availableLength;
+
+        /* the raw handle is only used when both the file offset and
+           the length are aligned; everything else, and every pipe
+           write, goes through the buffered handle */
+        if (!m_fileIsPipe &&
+            IsAligned(m_nextOffsetToWrite) && IsAligned(writeLength))
+        {
+            handleToUse = m_rawFileHandle;
+        }
+        else
+        {
+            handleToUse = m_bufferedFileHandle;
+            isBuffered = true;
+        }
+        if (m_calcFP)
+        {
+            /* extend the running fingerprint with this block's data.
+               NOTE(review): writeLength bytes are read from dataAddr
+               although dataSize is not checked (the assert below is
+               disabled) - confirm the block is always contiguous */
+            size_t dataSize;
+            void *dataAddr = block->GetDataAddress(0, &dataSize, NULL);
+            //LogAssert(dataSize == availableLength);
+            m_fp = Dryad_dupelim_fprint_extend (m_fpo, m_fp,
+                (const unsigned char *) dataAddr, writeLength);
+        }
+
+    }
+
+    FileWriteHandler* writeHandler =
+        new FileWriteHandler(handleToUse,
+                             lazyOpenDone,
+                             isBuffered,
+                             block,
+                             m_nextOffsetToWrite,
+                             flushAfter,
+                             handler, this);
+
+    /* advance the offset and, if it is now unaligned, remember how many
+       bytes are needed to reach a regular buffer boundary again */
+    m_nextOffsetToWrite += writeLength;
+    UInt32 downAlignment = AlignmentGap(m_nextOffsetToWrite);
+    if (downAlignment != 0)
+    {
+        LogAssert(downAlignment < m_bufferSize);
+        m_realignmentSize = m_bufferSize - downAlignment;
+    }
+
+    SetProcessedLength(m_nextOffsetToWrite);
+
+    if (!m_fileIsPipe && m_canExtendFileLength &&
+        m_nextOffsetToWrite >= m_fileLengthSet)
+    {
+        *extendFile = true;
+    }
+
+    return writeHandler;
+}
+
+void 
RChannelBufferWriterNativeFile::FillInOpenedDetails(WriteHandler* h)
+{
+    /* Gives a handler that was created without a file handle the handle
+       it has to write to: the buffered handle for buffered writes, the
+       raw handle otherwise. After an open error the chosen handle must
+       be invalid, after a successful open it must be valid (both
+       asserted). */
+    HANDLE handleToUse;
+
+    /* every handler made by this writer is a FileWriteHandler; the cast
+       target type has to be spelled out, and a NULL result is reported
+       by the assert instead of being dereferenced */
+    FileWriteHandler* handler = dynamic_cast<FileWriteHandler*>(h);
+    LogAssert(handler != NULL);
+
+    if (handler->IsBuffered())
+    {
+        handleToUse = m_bufferedFileHandle;
+    }
+    else
+    {
+        handleToUse = m_rawFileHandle;
+    }
+
+    if (OpenError())
+    {
+        LogAssert(handleToUse == INVALID_HANDLE_VALUE);
+    }
+    else
+    {
+        LogAssert(handleToUse != INVALID_HANDLE_VALUE);
+    }
+
+    handler->SetFileHandle(handleToUse);
+}
+
+/* Builds the handler for one file write.
+     handle         - file handle to write to; must be
+                      INVALID_HANDLE_VALUE when detailsPresent is false
+     detailsPresent - false when the handle will be supplied later
+                      through SetFileHandle()
+     isBuffered     - whether the write uses the buffered handle
+     block          - data to write, may be NULL; when present its
+                      available size (below 4GB, asserted) and
+                      streamOffset are passed to InitializeInternal()
+   The remaining arguments are passed to the WriteHandler base. */
+RChannelBufferWriterNativeFile::FileWriteHandler::
+    FileWriteHandler(HANDLE handle,
+                     bool detailsPresent,
+                     bool isBuffered,
+                     DryadFixedMemoryBuffer* block,
+                     UInt64 streamOffset,
+                     bool flushAfter,
+                     RChannelBufferWriterHandler* handler,
+                     RChannelBufferWriterNativeFile* parent) :
+    RChannelBufferWriterNative::WriteHandler(block, streamOffset,
+                                             flushAfter,
+                                             handler, parent)
+{
+    m_fileHandle = handle;
+    m_detailsPresent = detailsPresent;
+    if (!m_detailsPresent)
+    {
+        LogAssert(m_fileHandle == INVALID_HANDLE_VALUE);
+    }
+
+    m_isBuffered = isBuffered;
+    if (block != NULL)
+    {
+        Size_t availableLength = block->GetAvailableSize();
+        LogAssert(availableLength < 0x100000000);
+        InitializeInternal((UInt32) availableLength, streamOffset);
+    }
+}
+
+/* Supplies the file handle after construction. May only be called on a
+   handler whose details are not present yet (asserted). */
+void RChannelBufferWriterNativeFile::FileWriteHandler::SetFileHandle(HANDLE h)
+{
+    m_fileHandle = h;
+    LogAssert(m_detailsPresent == false);
+    m_detailsPresent = true;
+}
+
+/* Returns the handle this write goes to; INVALID_HANDLE_VALUE when none
+   has been set. */
+HANDLE RChannelBufferWriterNativeFile::FileWriteHandler::GetFileHandle()
+{
+    return m_fileHandle;
+}
+
+/* Returns whether this write uses the buffered handle. */
+bool RChannelBufferWriterNativeFile::FileWriteHandler::IsBuffered()
+{
+    return m_isBuffered;
+}
+
+/* Sends this write to the port, first making sure the file is open
+   when the handler was created without a handle. */
+void RChannelBufferWriterNativeFile::FileWriteHandler::
+    QueueWrite(DryadNativePort* port)
+{
+    if (m_detailsPresent == false)
+    {
+        /* we get here if the write was queued before the file was
+           opened. This happens in the case of lazy open. 
Either this
+           is the first write (that will trigger the lazy open) or
+           opens have been throttled and we will queue this up until
+           the open gets to the front of the queue */
+        LogAssert(GetFileHandle() == INVALID_HANDLE_VALUE);
+
+        RChannelBufferWriterNative* parent = GetParent();
+
+        /* a true result means the open has been deferred and this
+           handler must not be queued now; after a false result the
+           handler's details have to be present (asserted below) */
+        bool waitForThrottledOpen = parent->EnsureOpenForWrite(this);
+        if (waitForThrottledOpen)
+        {
+            /* do nothing right now. This will be sent to the port
+               eventually (perhaps on another thread) when the file is
+               finally opened. */
+            return;
+        }
+        else
+        {
+            LogAssert(m_detailsPresent);
+        }
+    }
+
+    if (GetFileHandle() == INVALID_HANDLE_VALUE)
+    {
+        /* we will get here if there was an open error, so just pass
+           it straight back into the machinery without (obviously)
+           trying to actually write anything. The writer class will
+           fill in the appropriate error details. */
+        ProcessIO(DrError_EndOfStream, 0);
+    }
+    else
+    {
+        // todo: remove comment if not logging
+//        DrLogE( "queue native write",
+//            "offset %I64u length %u", GetStreamOffset(), GetWriteLength());
+        port->QueueNativeWrite(GetFileHandle(), this);
+    }
+}
+
+/* Completion callback for this write. On success numBytes has to equal
+   the number of bytes that were requested; on failure it has to be 0
+   and the error is logged. In both cases the handler is then passed to
+   the owning writer's ReceiveBuffer() together with errorCode. */
+void RChannelBufferWriterNativeFile::FileWriteHandler::
+    ProcessIO(DrError errorCode, UInt32 numBytes)
+{
+    LogAssert(m_detailsPresent);
+
+    if (errorCode == DrError_OK)
+    {
+        UInt32 requested = (UInt32) (*GetNumberOfBytesToTransferPtr());
+        LogAssert(numBytes == requested);
+    }
+    else
+    {
+        LogAssert(numBytes == 0);
+        DrLogE(
+            "Native file write failed. 
handle=%u, err=%s", + GetFileHandle(), DRERRORSTRING(errorCode)); + } + + RChannelBufferWriterNative* baseParent = GetParent(); + RChannelBufferWriterNativeFile* parent = + (RChannelBufferWriterNativeFile *) baseParent; + parent->ReceiveBuffer(this, errorCode); +} + +#ifdef TIDYFS +RChannelBufferWriterNativeTidyFSStream:: + RChannelBufferWriterNativeTidyFSStream(UInt32 bufferSize, + size_t bufferAlignment, + UInt32 outstandingWritesLowWatermark, + UInt32 outstandingWritesHighWatermark, + DryadNativePort* port, + RChannelOpenThrottler* openThrottler) : + RChannelBufferWriterNativeFile(bufferSize, bufferAlignment, + outstandingWritesLowWatermark, + outstandingWritesHighWatermark, + port, openThrottler) +{ + DrLogD( "RChannelBufferWriterNativeTidyFSStream"); + m_client = new MDClient(); + LogAssert(m_client != NULL); + m_partId = 0; + + m_fpo = Dryad_dupelim_fprint_new (0x911498ae0e66bad6, 0); + m_fp = Dryad_dupelim_fprint_empty(m_fpo); + m_calcFP = true; +} + +RChannelBufferWriterNativeTidyFSStream::~RChannelBufferWriterNativeTidyFSStream() +{ + DrLogD( "~RChannelBufferWriterNativeTidyFSStream"); + // check if stream properly closed + delete m_client; + m_client = NULL; + m_partId = 0; + m_hostname = NULL; + DrLogD( "~RChannelBufferWriterNativeTidyFSStream Done"); +} + +DrError RChannelBufferWriterNativeTidyFSStream::OpenA(const char* streamName, DryadMetaData *metaData) +{ + + DrError result = m_client->Initialize("rsl.ini"); + if (result != DrError_OK) + { + DrLogE( + "Error initializing TidyFS client", "ErrorCode: %u ErrorString: %s", + result, GetDrErrorDescription(result)); + + return result; + } + + m_client->SetKeepAlive(true); + + const char *hostname = Configuration::GetRawMachineName(); + DrLogI( "CreateTidyFSStreamWriter", "Stream: %s, Host: %s", streamName, hostname); + + FILETIME currFileTime; + GetSystemTimeAsFileTime(&currFileTime); + UINT64 currTime = FiletimeToUInt64(currFileTime); + DrTimeInterval leaseInterval; + result = 
metaData->LookUpTimeInterval(Prop_Dryad_StreamExpireTimeWhileClosed, &leaseInterval); + if (result != DrError_OK) + { + const char* text = metaData->GetText(); + DrLogE( + "Can't read stream metadata" + "%s --- Stream: %s ErrorString: %s", + text, streamName, GetDrErrorDescription(result)); + + delete [] text; + return result; + } + + result= m_client->CreateStream(streamName, currTime + leaseInterval, 1, &m_partId, true); + if (result != DrError_OK) + { + DrLogE( + "Can't create temporary TidyFS Stream", "Stream: '%s' ErrorCode: %u ErrorString: %s", + streamName, result, GetDrErrorDescription(result)); + + return result; + } + + char path[2048]; + int hostLen = 1024; + m_hostname = new char[hostLen]; + LogAssert(m_hostname != NULL); + result = m_client->GetWritePath(path, 2048, m_hostname, hostLen, m_partId, hostname); + if (result != DrError_OK) + { + DrLogE( "Error in GetWritePath", + "Partition: %I64x Node: %s ErrorCode: %u ErrorString: %s", + m_partId, hostname, result, GetDrErrorDescription(result)); + return result; + } + m_client->SetKeepAlive(false); + + bool ok = RChannelBufferWriterNativeFile::OpenA(path); + if (!ok) + { + return DrError_Fail; + } + return DrError_OK; +} + +void RChannelBufferWriterNativeTidyFSStream::Close() +{ + // properly close stream + const char *hostname = Configuration::GetRawMachineName(); + DrLogI( "Closing TidyFS Stream", "Id: %I64x Size: %I64d FP: %I64x", + m_partId, m_fpDataLength, m_fp); + PartitionInfo *pi = new PartitionInfo(m_partId, m_fpDataLength, m_fp, m_hostname); + DrError result = m_client->AddPartitionInformation(&pi, 1); + if (result != DrError_OK) + { + DrLogE( "Error in AddPartitionInformation", + "Partition: %I64x Node: %s ErrorCode: %u ErrorString: %s", + m_partId, hostname, result, GetDrErrorDescription(result)); + } + delete m_hostname; + m_hostname = NULL; + delete pi; + DrLogI( "TidyFS Writer Close Calling parent Close"); + RChannelBufferWriterNativeFile::Close(); + +} +#endif + +//JC +#if 0 
+RChannelBufferWriterDryadStream:: + RChannelBufferWriterDryadStream(UInt32 bufferSize, + UInt32 outstandingWritesLowWatermark, + UInt32 outstandingWritesHighWatermark, + DryadNativePort* port, + RChannelOpenThrottler* openThrottler) : + RChannelBufferWriterNative(outstandingWritesLowWatermark, + outstandingWritesHighWatermark, + port, openThrottler, true) +{ + m_bufferSize = bufferSize; + + m_streamHandle = NULL; + m_streamName = new char[MAX_PATH]; + m_streamNameLength = 0; + + m_nextOffsetToWrite = 0; + + m_expireLengthWhileOpen = DR_INFINITE; + m_expireLengthWhileClosed = DR_INFINITE; +} + +RChannelBufferWriterDryadStream::~RChannelBufferWriterDryadStream() +{ + LogAssert(m_streamHandle == NULL); + delete [] m_streamName; +} + +DrError RChannelBufferWriterDryadStream::SetMetaData(DryadMetaData* metaData) +{ + if (metaData == NULL) + { + return DrError_OK; + } + + DrTimeInterval interval; + + if (metaData->LookUpTimeInterval(Prop_Dryad_StreamExpireTimeWhileOpen, + &interval) == DrError_OK) + { + if (interval == DrTimeInterval_Infinite) + { + m_expireLengthWhileOpen = DR_INFINITE; + } + else + { + m_expireLengthWhileOpen = interval / DrTimeInterval_100ns; + } + } + + if (metaData->LookUpTimeInterval(Prop_Dryad_StreamExpireTimeWhileClosed, + &interval) == DrError_OK) + { + if (interval == DrTimeInterval_Infinite) + { + m_expireLengthWhileClosed = DR_INFINITE; + } + else + { + m_expireLengthWhileClosed = interval / DrTimeInterval_100ns; + } + } + + return DrError_OK; +} + +bool RChannelBufferWriterDryadStream::LazyOpenFile() +{ + DrLogI( "Opening cosmos stream", + "Name %s", m_streamName); + + bool handleOpened = false; + + DR_STREAM_PROPERTIES properties; + memset(&properties, 0, sizeof(properties)); + properties.cbSize = sizeof(properties); + /* we want to extend this initial expire period to add the refresh + interval with which we will be updating it, in case it would + otherwise expire before we get a chance to update it */ + properties.ExpirePeriod = + 
g_streamUpdater.GetExpirePeriodWithSlop(m_expireLengthWhileOpen); + + DrError err = ::DrOpenStream(m_streamName, + DR_APPEND | DR_CREATE, + &properties, + &m_streamHandle, + NULL); + + // Delete output stream if it already exists + if (err == DrError_StreamAlreadyExists) + { + err = ::DrDelete(m_streamName, DR_DELETE_FORCE, NULL); + if (err == DrError_OK) + err = ::DrOpenStream(m_streamName, + DR_APPEND | DR_CREATE, + &properties, + &m_streamHandle, + NULL); + } + + if (err == DrError_OK) + { + UINT64 appendSize = DR_STREAM_OPTION_MAX_APPEND_SIZE_MAX; + err = DrSetStreamOption(m_streamHandle, DR_STREAM_OPTION_MAX_APPEND_SIZE, &appendSize, sizeof(appendSize)); + + handleOpened = true; + } + + if (err == DrError_OK) + { + DrLogI( "Dryad stream open succeeded", + "Name %s", m_streamName); + + if (m_expireLengthWhileOpen != DR_INFINITE) + { + /* put it in a queue to be regularly extended while we are + running */ + g_streamUpdater.AddStream(m_streamName, + m_expireLengthWhileOpen); + } + + return true; + } + else + { + DrLogI( "Dryad stream open failed", + "Name %s", m_streamName); + + if (handleOpened) + { + ::DrCloseHandle(m_streamHandle); + m_streamHandle = NULL; + } + LogAssert(m_streamHandle == NULL); + + RChannelItemRef errorItem; + DrStr64 description; + description.SetF("Can't open cosmos stream '%s' to write", + m_streamName); + errorItem.Attach(RChannelMarkerItem:: + CreateErrorItemWithDescription(RChannelItem_Abort, + err, + description)); + SetOpenErrorItem(errorItem); + + return false; + } +} + +void RChannelBufferWriterDryadStream::EagerCloseFile() +{ + LogAssert(m_streamHandle != NULL); + + DR_STREAM_POSITION appendPosition; + appendPosition.ExtentIndex = 0; + appendPosition.Offset = m_nextOffsetToWrite; + SIZE_T numberOfBytesAppended; + + DrError cse; + DrLogI( + "Closing output stream --- appending zero bytes to updated DRM", + "streamhandle=%p", m_streamHandle); + cse = DrAppendStream(m_streamHandle, + NULL, + 0, + DR_SEAL | DR_FIXED_OFFSET_APPEND | + 
DR_UPDATE_DRM, + &appendPosition, + &numberOfBytesAppended, + NULL); + if (cse != DrError_OK) + { + DrLogW( + "Stream seal at close failed", + "stream: %s error %s", + m_streamName, DRERRORSTRING(cse)); + } + + if (m_expireLengthWhileOpen != DR_INFINITE) + { + /* we have been extending it. Take it off the queue */ + bool ret = g_streamUpdater.RemoveStream(m_streamName); + LogAssert(ret == true); + } + + if (m_expireLengthWhileOpen != DR_INFINITE || + m_expireLengthWhileClosed != DR_INFINITE) + { + DR_STREAM_PROPERTIES properties; + memset(&properties, 0, sizeof(properties)); + properties.cbSize = sizeof(properties); + properties.ExpirePeriod = m_expireLengthWhileClosed; + + cse = ::DrSetStreamPropertiesByHandle(m_streamHandle, + &properties, + NULL); + if (cse != DrError_OK) + { + DrLogW( + "Stream set properties at close failed", + "stream: %s error %s", + m_streamName, DRERRORSTRING(cse)); + } + } + + cse = DrCloseHandle(m_streamHandle); + LogAssert(cse == DrError_OK); + m_streamHandle = NULL; +} + +DrError RChannelBufferWriterDryadStream::OpenA(const char* pathName) +{ + { + AutoCriticalSection acs(GetBaseDR()); + + LogAssert(m_streamHandle == NULL); + LogAssert(m_nextOffsetToWrite == 0); + + HRESULT hr = ::StringCbCopyA(m_streamName, MAX_PATH, pathName); + LogAssert(SUCCEEDED(hr)); + hr = ::StringCbLengthA(m_streamName, MAX_PATH, + &m_streamNameLength); + LogAssert(SUCCEEDED(hr)); + + OpenInternal(); + } + + return DrError_OK; +} + +/* called with baseDR held */ +void RChannelBufferWriterDryadStream:: + StartConcreteWriter(RChannelItemRef* pCompletionItem) +{ + if (*pCompletionItem != NULL && + (*pCompletionItem)->GetType() == RChannelItem_EndOfStream) + { + RChannelItem* error = + RChannelMarkerItem:: + CreateErrorItemWithDescription(RChannelItem_Abort, + DryadError_ChannelRestartError, + "Can't restart channel after " + "sending EOF"); + pCompletionItem->Attach(error); + } +} + +/* called with baseDR held */ +void 
RChannelBufferWriterDryadStream::DrainConcreteWriter() +{ +} + +/* called with baseDR held */ +DryadFixedMemoryBuffer* RChannelBufferWriterDryadStream:: + GetNextWriteBufferInternal() +{ + return GetCustomWriteBufferInternal(m_bufferSize); +} + +/* called with baseDR held */ +DryadFixedMemoryBuffer* RChannelBufferWriterDryadStream:: + GetCustomWriteBufferInternal(Size_t bufferSize) +{ + return new DryadAlignedWriteBlock(bufferSize, 0); +} + +/* called with baseDR held */ +void RChannelBufferWriterDryadStream:: + ReturnUnusedBufferInternal(DryadFixedMemoryBuffer* block) +{ + block->DecRef(); +} + +void RChannelBufferWriterDryadStream:: + ReceiveBuffer(WriteHandler* writeHandler, + DrError errorCode) +{ + ReceiveBufferInternal(writeHandler, errorCode); +} + +/* called with baseDR held */ +RChannelBufferWriterNative::WriteHandler* RChannelBufferWriterDryadStream:: + MakeWriteHandler(DryadFixedMemoryBuffer* block, + bool flushAfter, + RChannelBufferWriterHandler* handler, + bool lazyOpenDone) +{ + UInt32 writeLength; + if (block == NULL) + { + writeLength = 0; + } + else + { + Size_t availableLength = block->GetAvailableSize(); + LogAssert(availableLength < 0x100000000); + writeLength = (UInt32) availableLength; + } + + DryadWriteHandler* writeHandler = + new DryadWriteHandler(m_streamHandle, + lazyOpenDone, + block, + m_nextOffsetToWrite, + flushAfter, + handler, this); + + m_nextOffsetToWrite += writeLength; + + SetProcessedLength(m_nextOffsetToWrite); + + return writeHandler; +} + +void RChannelBufferWriterDryadStream::FillInOpenedDetails(WriteHandler* h) +{ + DryadWriteHandler* handler = dynamic_cast(h); + + if (OpenError()) + { + LogAssert(m_streamHandle == NULL); + } + else + { + LogAssert(m_streamHandle != NULL); + } + + handler->SetStreamHandle(m_streamHandle); +} + +RChannelBufferWriterDryadStream::DryadWriteHandler:: + DryadWriteHandler(DRHANDLE handle, + bool detailsPresent, + DryadFixedMemoryBuffer* block, + UInt64 streamOffset, + bool flushAfter, + 
RChannelBufferWriterHandler* handler, + RChannelBufferWriterDryadStream* parent) : + RChannelBufferWriterNative::WriteHandler(block, streamOffset, + flushAfter, + handler, parent) +{ + m_streamHandle = handle; + m_detailsPresent = detailsPresent; + if (!m_detailsPresent) + { + LogAssert(m_streamHandle == NULL); + } + + if (block != NULL) + { + Size_t availableLength = block->GetAvailableSize(); + LogAssert(availableLength < 0x100000000); + InitializeInternal((UInt32) availableLength, streamOffset); + } +} + +void RChannelBufferWriterDryadStream::DryadWriteHandler:: + SetStreamHandle(DRHANDLE h) +{ + LogAssert(m_detailsPresent == false); + m_streamHandle = h; + m_detailsPresent = true; +} + +DRHANDLE RChannelBufferWriterDryadStream::DryadWriteHandler:: + GetStreamHandle() +{ + return m_streamHandle; +} + +DrError* RChannelBufferWriterDryadStream::DryadWriteHandler:: + GetPendingStatePtr() +{ + return &m_pendingState; +} + +void RChannelBufferWriterDryadStream::DryadWriteHandler:: + QueueWrite(DryadNativePort* port) +{ + if (!m_detailsPresent) + { + LogAssert(GetStreamHandle() == NULL); + + RChannelBufferWriterNative* baseParent = GetParent(); + RChannelBufferWriterDryadStream* parent = + (RChannelBufferWriterDryadStream *) baseParent; + + bool waitForThrottledOpen = parent->EnsureOpenForWrite(this); + if (waitForThrottledOpen) + { + /* do nothing right now. This will be sent to the port + eventually (perhaps on another thread) when the file is + finally opened. */ + return; + } + else + { + LogAssert(m_detailsPresent); + } + } + + if (GetStreamHandle() == NULL) + { + /* we will get here if there was an open error, so just pass + it straight back into the machinery without (obviously) + trying to actually write anything. The writer class will + fill in the appropriate error details. 
*/ + m_pendingState = DrError_EndOfStream; + ProcessIO(DrError_OK, 0); + } + else + { + port->QueueDryadWrite(GetStreamHandle(), + GetPendingStatePtr(), + GetStreamOffset(), + this); + } +} + +void RChannelBufferWriterDryadStream::DryadWriteHandler:: + ProcessIO(DrError cse, UInt32 numBytes) +{ + LogAssert(m_detailsPresent); + + LogAssert(cse == DrError_OK); + cse = m_pendingState; + if (cse == DrError_OK) + { + numBytes = (UInt32) *(GetNumberOfBytesToTransferPtr()); + LogAssert(numBytes == GetWriteLength()); + DrLogI( + "Dryad append completed", + "streamhandle=%p, numBytes=%u", + m_streamHandle, numBytes); + + } + else + { + numBytes = 0; + DrLogE( + "Dryad append failed", + "streamhandle=%p, err=%s", + m_streamHandle, DRERRORSTRING(cse)); + + } + + RChannelBufferWriterNative* baseParent = GetParent(); + RChannelBufferWriterDryadStream* parent = + (RChannelBufferWriterDryadStream *) baseParent; + parent->ReceiveBuffer(this, cse); +} +#endif // if 0 diff --git a/DryadVertex/VertexHost/system/channel/src/channelbuffernativewriter.h b/DryadVertex/VertexHost/system/channel/src/channelbuffernativewriter.h new file mode 100644 index 0000000..a6cc717 --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/src/channelbuffernativewriter.h @@ -0,0 +1,336 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include "dryadnativeport.h" +#include "channelwriter.h" +#include "concreterchannelhelpers.h" +#ifdef TIDYFS +#include +#endif +#include "DrFPrint.h" + +#pragma warning(disable:4995) +#include + +class RChannelBufferWriterNative : + public RChannelBufferWriter, public RChannelThrottledStream +{ +public: + class WriteHandler : public DryadNativePort::Handler, public DrRefCounter + { + public: + WriteHandler(DryadFixedMemoryBuffer* block, + UInt64 streamOffset, + bool flushAfter, + RChannelBufferWriterHandler* handler, + RChannelBufferWriterNative* parent); + ~WriteHandler(); + + void ProcessingComplete(RChannelItemType statusCode); + + UInt32 GetWriteLength(); + UInt64 GetStreamOffset(); + void* GetData(); + DryadFixedMemoryBuffer* GetBlock(); + bool IsFlush(); + + virtual void QueueWrite(DryadNativePort* port) = 0; + + protected: + RChannelBufferWriterNative* GetParent(); + + private: + DryadFixedMemoryBuffer* m_block; + UInt32 m_writeLength; + bool m_flush; + RChannelBufferWriterHandler* m_handler; + RChannelBufferWriterNative* m_parent; + UInt64 m_streamOffset; + void* m_data; + DrBListEntry m_listPtr; + friend class DryadBList; + }; + + typedef DryadBList WriteHandlerList; + + RChannelBufferWriterNative(UInt32 outstandingWritesLowWatermark, + UInt32 outstandingWritesHighWatermark, + DryadNativePort* port, + RChannelOpenThrottler* openThrottler, + bool supportsLazyOpen); + ~RChannelBufferWriterNative(); + + void Start(); + + DryadFixedMemoryBuffer* GetNextWriteBuffer(); + DryadFixedMemoryBuffer* GetCustomWriteBuffer(Size_t bufferSize); + + bool WriteBuffer(DryadFixedMemoryBuffer* buffer, + bool flushAfter, + RChannelBufferWriterHandler* handler); + + void ReturnUnusedBuffer(DryadFixedMemoryBuffer* buffer); + + void WriteTermination(RChannelItemType reasonCode, + RChannelBufferWriterHandler* handler); + + void FillInStatus(DryadChannelDescription* status); + + bool EnsureOpenForWrite(WriteHandler* handler); + + void Drain(RChannelItemRef* 
returnItem);
+
+    void Close();
+
+    /* Get/set a hint about the total length the channel is expected
+       to be. Some channel implementations can use this to improve
+       write performance and decrease disk fragmentation. A value of 0
+       (the default) means that the size is unknown. */
+    UInt64 GetInitialSizeHint();
+    void SetInitialSizeHint(UInt64 hint);
+
+    void OpenAfterThrottle();
+
+protected:
+    /* helpers for the concrete writers */
+    CRITSEC* GetBaseDR();
+    DryadNativePort* GetPort();
+    void OpenInternal();
+    void ReceiveBufferInternal(WriteHandler* writeHandler, DrError errorCode);
+    void DecrementOutstandingHandlers();
+    bool FinishUsingFile();
+    /* called with baseDR held */
+    void SetProcessedLength(UInt64 processedLength);
+    void SetLowAndHighWaterMark(UInt32 outstandingWritesLowWatermark,
+                                UInt32 outstandingWritesHighWatermark);
+    void SetOpenErrorItem(RChannelItem* errorItem);
+    bool OpenError();
+
+private:
+    /* overall state of the writer */
+    enum State {
+        S_Closed,
+        S_Running,
+        S_Stopping,
+        S_Stopped
+    };
+
+    /* state of the underlying open */
+    enum OpenState {
+        OS_Closed,
+        OS_NotOpened,
+        OS_Waiting,
+        OS_Opened,
+        OS_OpenError,
+        OS_Stopped
+    };
+
+    /* set of write handlers, held by pointer: WriteHandler is abstract,
+       so it cannot be stored by value */
+    typedef std::set<WriteHandler*> WriteHandlerSet;
+
+    /* target-specific parts supplied by the concrete writers */
+    virtual bool LazyOpenFile();
+    virtual void FillInOpenedDetails(WriteHandler* handler);
+    virtual void EagerCloseFile();
+    virtual void FinalCloseFile();
+    virtual DryadFixedMemoryBuffer* GetNextWriteBufferInternal() = 0;
+    virtual DryadFixedMemoryBuffer*
+        GetCustomWriteBufferInternal(Size_t bufferSize) = 0;
+    virtual void ReturnUnusedBufferInternal(DryadFixedMemoryBuffer*
+                                            buffer) = 0;
+    virtual WriteHandler* MakeWriteHandler(DryadFixedMemoryBuffer* block,
+                                           bool flushAfter,
+                                           RChannelBufferWriterHandler*
+                                           handler,
+                                           bool detailsPresent,
+                                           bool* extendFile) = 0;
+    virtual void ExtendFileValidLength() = 0;
+    virtual void StartConcreteWriter(RChannelItemRef* pCompletionItem) = 0;
+    virtual void DrainConcreteWriter() = 0;
+
+    bool AddToWriteQueue(WriteHandler* writeRequest,
+                         WriteHandlerList* writeQueue,
+                         bool extendFile);
+    void Unblock(WriteHandlerList* returnList,
+                 
WriteHandlerList* processList);
+    void ConsumeErrorBuffer(DrError errorCode,
+                            WriteHandlerList* pReturnList,
+                            WriteHandlerList* pProcessList);
+    RChannelItemType ConsumeBufferCompletion(WriteHandler* writeHandler,
+                                             DrError errorCode,
+                                             WriteHandlerList* pReturnList,
+                                             WriteHandlerList* pProcessList);
+    RChannelItemType DealWithReturnHandlers(WriteHandler* writeHandler,
+                                            DrError errorCode);
+
+    /* limits on outstanding writes, as given to the constructor or to
+       SetLowAndHighWaterMark() */
+    UInt32              m_lowWatermark;
+    UInt32              m_highWatermark;
+
+    /* bookkeeping for buffers, IOs and handler callbacks in flight */
+    UInt32              m_outstandingBuffers;
+    UInt32              m_outstandingIOs;
+    UInt32              m_outstandingHandlerSends;
+    bool                m_outstandingTermination;
+    bool                m_blocking;
+    bool                m_flushing;
+    bool                m_blockingForFileExtension;
+    WriteHandler*       m_terminationHandler;
+    WriteHandlerList    m_blockedList;
+    WriteHandlerSet     m_pendingWriteSet;
+    WriteHandlerSet     m_pendingUnblockSet;
+    RChannelItemRef     m_completionItem;
+    RChannelItemRef     m_openErrorItem;
+    WriteHandlerList    m_openWaitingList;
+    RChannelOpenThrottler* m_openThrottler;
+
+    UInt64              m_processedLength;
+    UInt64              m_initialSizeHint;
+
+    HANDLE              m_terminationEvent;
+
+    DryadNativePort*    m_port;
+
+    OpenState           m_openState;
+    State               m_state;
+    bool                m_supportsLazyOpen;
+    bool                m_drainingOpenQueue;
+
+    /* the lock referred to as "baseDR" in the "called with baseDR
+       held" comments */
+    CRITSEC             m_baseDR;
+};
+
+/* Writer for a native file or named pipe, built on
+   RChannelBufferWriterNative. */
+class RChannelBufferWriterNativeFile : public RChannelBufferWriterNative
+{
+public:
+    /* Write request for a file handle. It records which handle to use,
+       whether that handle is already known (detailsPresent) and
+       whether the write goes through the buffered handle. */
+    class FileWriteHandler :
+        public RChannelBufferWriterNative::WriteHandler
+    {
+    public:
+        FileWriteHandler(HANDLE fileHandle,
+                         bool detailsPresent,
+                         bool isBuffered,
+                         DryadFixedMemoryBuffer* block,
+                         UInt64 streamOffset,
+                         bool flushAfter,
+                         RChannelBufferWriterHandler* handler,
+                         RChannelBufferWriterNativeFile* parent);
+
+        void SetFileHandle(HANDLE h);
+        HANDLE GetFileHandle();
+        bool IsBuffered();
+
+        void ProcessIO(DrError errorCode, UInt32 numBytes);
+
+        void QueueWrite(DryadNativePort* port);
+
+    private:
+        HANDLE     m_fileHandle;
+        bool       m_detailsPresent;
+        bool       m_isBuffered;
+    };
+
+    RChannelBufferWriterNativeFile(UInt32 bufferSize,
+                                   size_t bufferAlignment,
+                                   UInt32 outstandingWritesLowWatermark, 
+ UInt32 outstandingWritesHighWatermark, + DryadNativePort* port, + RChannelOpenThrottler* openThrottler); + ~RChannelBufferWriterNativeFile(); + + DrError SetMetaData(DryadMetaData* metaData); + + bool OpenA(const char* pathName); /* narrow-character open; the wide-character variant below is commented out */ +//JC bool OpenW(const wchar_t* pathName); + +protected: /* fingerprint state visible to subclasses - NOTE(review): presumably updated while writing when m_calcFP is set; confirm in the .cpp */ + UInt64 m_fpDataLength; + bool m_calcFP; + Dryad_dupelim_fprint_data_t m_fpo; + Dryad_dupelim_fprint_uint64_t m_fp; + + +private: + DrError TryToCreatePathA(); +//JC DrError TryToCreatePathW(); + HANDLE CreateFileAndPath(DrError* pErr, SECURITY_ATTRIBUTES *sa); + bool LazyOpenFile(); + void FillInOpenedDetails(WriteHandler* handler); + void EagerCloseFile(); + DryadFixedMemoryBuffer* GetNextWriteBufferInternal(); + DryadFixedMemoryBuffer* GetCustomWriteBufferInternal(Size_t bufferSize); + void ReturnUnusedBufferInternal(DryadFixedMemoryBuffer* buffer); + WriteHandler* MakeWriteHandler(DryadFixedMemoryBuffer* block, + bool flushAfter, + RChannelBufferWriterHandler* + handler, + bool detailsPresent, + bool* extendFile); + void StartConcreteWriter(RChannelItemRef* pCompletionItem); + void DrainConcreteWriter(); + void ExtendFileValidLength(); + + bool IsAligned(UInt64 offset); + UInt32 AlignmentGap(UInt64 offset); + void ReceiveBuffer(WriteHandler* writeHandler, DrError errorCode); + + static bool TryToSetPrivilege(); /* static; NOTE(review): the s_* statics at the end of the class look like its process-wide state, with s_privilegeDR guarding them - confirm */ + + UInt32 m_bufferSize; + size_t m_bufferAlignment; + + HANDLE m_rawFileHandle; + HANDLE m_bufferedFileHandle; /* NOTE(review): looks like a raw and a buffered handle onto the same output (cf. FileWriteHandler's isBuffered flag) - confirm */ + + bool m_fileIsPipe; + + UInt64 m_nextOffsetToWrite; + UInt32 m_realignmentSize; + + char* m_fileNameA; +//JC wchar_t* m_fileNameW; + size_t m_fileNameLength; +//JC bool m_wideFileName; + + bool m_tryToCreatePath; + bool m_canExtendFileLength; + UInt64 m_fileLengthSet; + + static bool s_triedToSetPrivilege; + static bool s_setPrivilege; + static CRITSEC s_privilegeDR; + + friend class FileWriteHandler; +}; + +#ifdef TIDYFS +class RChannelBufferWriterNativeTidyFSStream : public RChannelBufferWriterNativeFile +{ +public: + 
RChannelBufferWriterNativeTidyFSStream(UInt32 bufferSize, + size_t bufferAlignment, + UInt32 outstandingWritesLowWatermark, + UInt32 outstandingWritesHighWatermark, + DryadNativePort* port, + RChannelOpenThrottler* openThrottler); /* same parameter list as the RChannelBufferWriterNativeFile constructor */ + ~RChannelBufferWriterNativeTidyFSStream(); + DrError OpenA(const char* streamName, DryadMetaData *metaData); /* note: differs from the base class's bool OpenA(const char*) in both parameters and return type, so it hides that overload rather than overriding it */ +private: + void Close(); + MDClient *m_client; + UInt64 m_partId; + char *m_hostname; + +}; +#endif diff --git a/DryadVertex/VertexHost/system/channel/src/channelbufferqueue.cpp b/DryadVertex/VertexHost/system/channel/src/channelbufferqueue.cpp new file mode 100644 index 0000000..36c562c --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/src/channelbufferqueue.cpp @@ -0,0 +1,777 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include "channelbufferqueue.h" +#include "channelhelpers.h" +#include "channelreader.h" + +#pragma unmanaged + + +RChannelBufferQueue::RChannelBufferQueue(RChannelReaderImpl* parent, + RChannelBufferReader* bufferReader, + RChannelItemParserBase* parser, + UInt32 maxParseBatchSize, + UInt32 maxOutstandingUnits, + WorkQueue* workQueue) /* builds a stopped queue (BQ_Stopped); nothing runs until StartSupplier is called */ +{ + m_parent = parent; + m_bufferReader = bufferReader; + m_parser = parser; + m_workQueue = workQueue; + + m_maxParseBatchSize = m_parser->GetMaxParseBatchSize(); /* the parser's own batch size wins; the constructor argument is only the fallback when the parser reports 0 */ + if (m_maxParseBatchSize == 0) + { + m_maxParseBatchSize = maxParseBatchSize; + } + + m_maxOutstandingUnits = maxOutstandingUnits; + m_outstandingUnits = 0; + + m_state = BQ_Stopped; + m_shutDownRequested = false; + m_shutDownEvent = ::CreateEvent(NULL, TRUE, FALSE, NULL); /* manual-reset event, initially non-signaled */ + LogAssert(m_shutDownEvent != NULL); + /* alertCompleteEvent starts out signaled, and only gets reset + transiently during the period that ProcessAfterParsing may be + calling AlertApplication */ + m_alertCompleteEvent = ::CreateEvent(NULL, TRUE, TRUE, NULL); + LogAssert(m_alertCompleteEvent != NULL); + + m_currentBuffer = NULL; + m_firstBuffer = true; + m_nextDataSequenceNumber = 0; + m_nextDeliverySequenceNumber = 0; +} + +RChannelBufferQueue::~RChannelBufferQueue() /* asserts the queue is back in BQ_Stopped with no buffers held, then closes both event handles */ +{ + LogAssert(m_state == BQ_Stopped); + LogAssert(m_currentBuffer == NULL); + LogAssert(m_pendingList.IsEmpty()); + BOOL bRet = ::CloseHandle(m_shutDownEvent); + LogAssert(bRet != 0); + bRet = ::CloseHandle(m_alertCompleteEvent); + LogAssert(bRet != 0); +} + +void RChannelBufferQueue:: + StartSupplier(RChannelBufferPrefetchInfo* prefetchCookie) /* resets per-run state under the lock (BQ_Stopped -> BQ_Empty), then starts the buffer reader with this queue as its handler once the lock scope has closed */ +{ + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_bufferReader != NULL); + LogAssert(m_parser != NULL); + LogAssert(m_state == BQ_Stopped); + LogAssert(m_currentBuffer == NULL); + LogAssert(m_pendingList.IsEmpty()); + m_sendLatch.Start(); + + m_shutDownRequested = false; + m_firstBuffer = true; + m_state = BQ_Empty; + + m_nextDataSequenceNumber = 0; + 
m_nextDeliverySequenceNumber = 0; + } + + m_bufferReader->Start(prefetchCookie, this); /* issued after the lock scope above has closed */ +} + +// +// Return whether queue is shutting down or has been requested +// to shut down +// +bool RChannelBufferQueue::ShutDownRequested() +{ + bool retval; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_state != BQ_Stopped); + retval = (m_shutDownRequested || m_state == BQ_Stopping); + } + + return retval; +} + +// +// Return all buffers in list and queue a work request if supplied +// +void RChannelBufferQueue::DispatchEvents(ChannelBufferList* bufferList, + WorkRequest* workRequest) /* takes no lock itself; the callers in this file invoke it after releasing m_baseDR */ +{ + DrBListEntry* listEntry = bufferList->GetHead(); + while (listEntry != NULL) + { + // + // Return all buffers + // + RChannelBuffer* buffer = bufferList->CastOut(listEntry); + listEntry = bufferList->GetNext(listEntry); /* advance before Remove() so the cursor never refers to an unlinked entry */ + bufferList->Remove(bufferList->CastIn(buffer)); +// DrLogD( +// "RChannelBufferQueue::RequestNextParse returning buffer"); + buffer->ProcessingComplete(NULL); + } + + if (workRequest != NULL) + { + // + // Queue up work request if specified + // + bool bRet = m_workQueue->EnQueue(workRequest); + LogAssert(bRet == true); + } +} + +// +// Add current and pending buffers to buffer list and forget about them +// called with RChannelBufferQueue::m_baseDR locked +// +void RChannelBufferQueue::ShutDownBufferQueue(ChannelBufferList* bufferList) +{ + LogAssert(m_currentBuffer != NULL); + + bufferList->InsertAsTail(bufferList->CastIn(m_currentBuffer)); + m_currentBuffer = NULL; + bufferList->TransitionToTail(&m_pendingList); /* current buffer first, then everything that was pending */ +} + +// +// Queue for processing or clean up if buffer queue is shutting down +// +void RChannelBufferQueue::ProcessBuffer(RChannelBuffer* buffer) +{ + WorkRequest* workRequest = NULL; + bool returnBuffer = false; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_shutDownRequested == false); /* InterruptSupplier interrupts the reader before it sets this flag, so no buffer should arrive while it is set */ + + if (m_state == BQ_Empty) + { + // + // If queue is currently empty, create a new work request for this buffer + // + m_state = BQ_InWorkQueue; + 
LogAssert(m_currentBuffer == NULL); + m_currentBuffer = buffer; + // todo: remove comment if not logging +// DrLogD( "scheduling parse"); + workRequest = new RChannelParseRequest(this, true); /* NOTE(review): the bool is presumably forwarded as ParseRequest's useNewBuffer - confirm in RChannelParseRequest */ + } + else if (m_state == BQ_Stopping) + { + // + // If queue is currently stopping, log and return + // + LogAssert(m_currentBuffer == NULL); + LogAssert(m_pendingList.IsEmpty()); + returnBuffer = true; + DrLogD( "stopping: returning buffer"); + } + else + { + // + // If queue is neither empty nor stopping, add buffer to tail of queue + // to be processed later + // + LogAssert(m_state == BQ_BlockingItem || + m_state == BQ_InWorkQueue || + m_state == BQ_Locked); + LogAssert(m_currentBuffer != NULL); + m_pendingList.InsertAsTail(m_pendingList.CastIn(buffer)); + // todo: remove comment if not logging +// DrLogD( "queueing buffer"); + } + } + + // + // make all calls into other components with no locks held + // + + // + // If stopping and buffer not being processed, return the buffer to the pool + // + if (returnBuffer) + { + buffer->ProcessingComplete(NULL); + } + + // + // If new work request was created, enqueue it in work queue + // + if (workRequest != NULL) + { + bool bRet = m_workQueue->EnQueue(workRequest); + LogAssert(bRet == true); + } +} + +// +// Shut down queue if already requested, otherwise, mark it as locked +// +RChannelBuffer* RChannelBufferQueue::GetAndLockCurrentBuffer(bool* outReset) /* returns the buffer to parse, or NULL when a pending shutdown was carried out instead; *outReset is always written */ +{ + RChannelBuffer* buffer; + ChannelBufferList bufferList; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_currentBuffer != NULL); + LogAssert(m_state == BQ_InWorkQueue); + + if (m_shutDownRequested) + { + // + // If shut down is requested, do it now + // + ShutDownBufferQueue(&bufferList); + m_state = BQ_Stopping; + BOOL bRet = ::SetEvent(m_shutDownEvent); /* releases InterruptSupplier, which waits on this event when it found the queue in BQ_InWorkQueue or BQ_Locked */ + LogAssert(bRet != 0); + *outReset = false; + LogAssert(m_currentBuffer == NULL); + } + else + { + // + // If not shutting down, mark queue as locked + // + m_state = BQ_Locked; + *outReset = m_firstBuffer; /* true only for the first buffer locked after StartSupplier; the flag is cleared right after this */ + 
m_firstBuffer = false; + LogAssert(m_currentBuffer != NULL); + } + + // + // m_currentBuffer is NULL if and only if m_shutDownRequested + // was true + // + buffer = m_currentBuffer; + } + + // + // Return all buffers in list if queue is shutting down + // + DispatchEvents(&bufferList, NULL); + + return buffer; +} + +void RChannelBufferQueue::NotifyUnitConsumption() /* one outstanding unit has been consumed; if parsing was throttled (BQ_BlockingItem) and the count is back under the cap, queue a new parse request */ +{ + WorkRequest* workRequest = NULL; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_outstandingUnits > 0); + --m_outstandingUnits; + + if (m_state == BQ_BlockingItem && + m_outstandingUnits < m_maxOutstandingUnits) + { + workRequest = new RChannelParseRequest(this, false); + m_state = BQ_InWorkQueue; + } + } + + if (workRequest != NULL) /* enqueue outside the lock */ + { + bool bRet = m_workQueue->EnQueue(workRequest); + LogAssert(bRet == true); + } +} + +void RChannelBufferQueue:: + ProcessAfterParsing(NextParseAction nextParseAction, + RChannelItemArray* itemArray, + UInt64 numberOfSubItemsRead, + UInt64 dataSizeRead, + RChannelBufferPrefetchInfo* prefetchCookie) /* runs after each parse batch with the queue in BQ_Locked: wraps the parsed items in units, picks the next state from nextParseAction, then delivers buffers, units and alerts with the lock released */ +{ + ChannelUnitList unitList; + ChannelBufferList bufferList; + WorkRequest* workRequest = NULL; + RChannelItemType lastItemType = RChannelItem_ItemHole; /* placeholder value used when the batch produced no items */ + bool mustWakeUp = false; + RChannelItem* lastItem = NULL; + RChannelItemRef alertItem = NULL; + + if (itemArray->GetNumberOfItems() > 0) + { + lastItem = + itemArray->GetItemArray()[itemArray->GetNumberOfItems()-1]; + lastItemType = lastItem->GetType(); + } + + if (lastItemType == RChannelItem_ParseError) + { + /* call into other components with no locks held. + + Make sure no more buffers will be delivered since we are + going to stop parsing. 
When this returns we are sure no more + calls to ProcessBuffer will come in so it's safe to call + ShutDownBufferQueue below */ + m_bufferReader->Interrupt(); + } + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_state == BQ_Locked); + LogAssert(m_currentBuffer != NULL); + + if (m_shutDownRequested) + { + /* a shutdown request came in from the application while + we were in the parser, so we don't want to do anything + except throw out all pending buffers and return. */ + ShutDownBufferQueue(&bufferList); + m_state = BQ_Stopping; + mustWakeUp = true; /* m_shutDownEvent is signalled at the very end of this function, after the lock has been released */ + } + else + { + if (lastItemType == RChannelItem_ParseError) + { + /* ditch any buffers we haven't parsed yet along with + the current buffer */ + ShutDownBufferQueue(&bufferList); + m_state = BQ_Stopping; + mustWakeUp = true; + } + else if (nextParseAction != NPA_RequestItem) + { + /* we're done with the current buffer; put it in the + unit queue for completion when the items in it have + been processed */ + RChannelUnit* unit = + new RChannelBufferBoundaryUnit(m_currentBuffer, + prefetchCookie); + unitList.InsertAsTail(unitList.CastIn(unit)); + m_currentBuffer = NULL; + } + + if (itemArray->GetNumberOfItems() > 0) + { + RChannelItemRef terminationItem; + if (RChannelItem::IsTerminationItem(lastItemType)) + { + terminationItem = lastItem; + } + + RChannelItemUnit* unit = + new RChannelSerializedUnit(this, + itemArray, terminationItem); + unit->SetSizes(numberOfSubItemsRead, dataSizeRead); + ++m_outstandingUnits; /* counted down again in NotifyUnitConsumption */ + + unitList.InsertAsTail(unitList.CastIn(unit)); /* the item unit is appended after the buffer-boundary unit, when one was queued above */ + if (nextParseAction == NPA_StopParsing) + { + alertItem = lastItem; + if (alertItem != NULL) + { + /* reset the event so InterruptSupplier won't + complete until the application has been + alerted */ + BOOL bRet = ::ResetEvent(m_alertCompleteEvent); + LogAssert(bRet != 0); + } + } + } + + switch (nextParseAction) + { + case NPA_RequestBuffer: + LogAssert(m_currentBuffer == NULL); + /* we would like to start processing the next buffer, + so put it 
in the queue if there is one */ + if (m_pendingList.IsEmpty()) + { + m_state = BQ_Empty; + } + else + { + m_currentBuffer = + m_pendingList.CastOut(m_pendingList.RemoveHead()); + workRequest = new RChannelParseRequest(this, true); + m_state = BQ_InWorkQueue; + } + break; + + case NPA_RequestItem: + LogAssert(m_currentBuffer != NULL); + if (m_outstandingUnits >= m_maxOutstandingUnits) + { + m_state = BQ_BlockingItem; /* throttled: no work request is queued until NotifyUnitConsumption brings the count back under the cap */ + } + else + { + workRequest = new RChannelParseRequest(this, false); + m_state = BQ_InWorkQueue; + } + break; + + case NPA_EndOfStream: + LogAssert(m_pendingList.IsEmpty()); + LogAssert(m_currentBuffer == NULL); + LogAssert(lastItemType == RChannelItem_EndOfStream); + m_state = BQ_Stopping; + break; + + case NPA_StopParsing: + LogAssert(m_pendingList.IsEmpty()); + LogAssert(m_currentBuffer == NULL); + LogAssert(lastItemType == RChannelItem_ParseError || + lastItemType == RChannelItem_Abort || + lastItemType == RChannelItem_Restart); + m_state = BQ_Stopping; + break; + + default: + LogAssert(false); + } + } + + m_sendLatch.AcceptList(&unitList); /* NOTE(review): the latch presumably keeps delivery to the parent ordered across concurrent callers; units it holds back come out of TransferList in the loop below - confirm against DryadOrderedSendLatch */ + } + + /* make all calls into other components with no locks held */ + + DispatchEvents(&bufferList, workRequest); + + while (unitList.IsEmpty() == false) + { + m_parent->AddUnitList(&unitList); + + { + AutoCriticalSection acs(&m_baseDR); + + m_sendLatch.TransferList(&unitList); + } + } + + if (nextParseAction == NPA_StopParsing) + { + if (alertItem != NULL) + { + m_parent->AlertApplication(alertItem); + /* set the event so InterruptSupplier can complete now + that the application has been alerted */ + BOOL bRet = ::SetEvent(m_alertCompleteEvent); + LogAssert(bRet != 0); + } + } + else + { + LogAssert(alertItem == NULL); + } + + if (mustWakeUp) + { + BOOL bRet = ::SetEvent(m_shutDownEvent); + LogAssert(bRet != 0); + } +} + +// +// Called when RChannelParseRequest.Process is called in work queue +// +void RChannelBufferQueue::ParseRequest(bool useNewBuffer) /* useNewBuffer: true hands the current buffer to the parser on the first call of the batch; false passes NULL so the parser continues without a new buffer */ +{ +// DrLogD( +// "RChannelBufferQueue::ParseRequest 
entered", +// "useNewBuffer: %s", (useNewBuffer) ? "true" : "false"); + + RChannelBuffer* buffer; + bool resetParser = false; + UInt64 numberOfSubItemsRead = 0; + UInt64 dataSizeRead = 0; + + // + // Handle shutdown logic if necessary or lock the queue + // + buffer = GetAndLockCurrentBuffer(&resetParser); + if (buffer == NULL) + { + // + // there was a shutdown request while this work item was in + // the queue, which has now been dealt with, so just exit + // + return; + } + + RChannelBufferType bufferType = buffer->GetType(); + RChannelItemArrayRef itemArray; + itemArray.Attach(new RChannelItemArray()); + RChannelBufferPrefetchInfo* prefetchCookie = NULL; + NextParseAction nextParseAction; /* assigned on every path below before it is read */ + + if (bufferType == RChannelBuffer_Data || + bufferType == RChannelBuffer_Hole || + bufferType == RChannelBuffer_EndOfStream) + { + /* get the parser to give us the next item, if any */ + RChannelBuffer* parseBuffer = (useNewBuffer) ? buffer : NULL; + + UInt32 parsedItemCount = 0; + itemArray->SetNumberOfItems(m_maxParseBatchSize); /* sized for a full batch; truncated to the parsed count after the loop */ + RChannelItemRef* items = itemArray->GetItemArray(); + + do + { + RChannelItem* item = + m_parser->RawParseItem(resetParser, parseBuffer, + &prefetchCookie); + resetParser = false; + parseBuffer = NULL; /* the reset flag and the buffer are passed to the parser only on the first call of this batch */ + + if (item == NULL) + { + /* the parser should only ask for another buffer if + there's more data and it didn't give back an item */ + LogAssert(bufferType != RChannelBuffer_EndOfStream); + nextParseAction = NPA_RequestBuffer; + } + else + { + LogAssert(prefetchCookie == NULL); + + RChannelItemType itemType = item->GetType(); + if (itemType == RChannelItem_EndOfStream) + { + LogAssert(bufferType == RChannelBuffer_EndOfStream); + nextParseAction = NPA_EndOfStream; + } + else if (itemType == RChannelItem_ParseError) + { + nextParseAction = NPA_StopParsing; + } + else + { + LogAssert(itemType == RChannelItem_Data || + itemType == RChannelItem_ItemHole || + itemType == RChannelItem_BufferHole); + nextParseAction = NPA_RequestItem; + } + + 
item->SetDeliverySequenceNumber(m_nextDeliverySequenceNumber); + ++m_nextDeliverySequenceNumber; /* NOTE(review): the sequence counters are updated without m_baseDR here; presumably safe because the queue is in BQ_Locked and only one parse request runs at a time - confirm */ + /* Only data items merit a new sequence number; everything + else gets the sequence number of the next data item. */ + item->SetDataSequenceNumber(m_nextDataSequenceNumber); + if (itemType == RChannelItem_Data) + { + ++m_nextDataSequenceNumber; + } + + numberOfSubItemsRead += item->GetNumberOfSubItems(); + dataSizeRead += item->GetItemSize(); + + items[parsedItemCount].Attach(item); + ++parsedItemCount; + } + } while (nextParseAction == NPA_RequestItem && + parsedItemCount < m_maxParseBatchSize); /* stop at the first outcome other than an ordinary item, or when the batch is full */ + + itemArray->TruncateToSize(parsedItemCount); + } + else + { + LogAssert(useNewBuffer == true); + RChannelBufferMarker* markerBuffer = (RChannelBufferMarker *) buffer; /* NOTE(review): unchecked downcast; the buffer type is only asserted to be Restart/Abort a few lines further down */ + RChannelItem* item = markerBuffer->GetItem(); + + item->SetDeliverySequenceNumber(m_nextDeliverySequenceNumber); + ++m_nextDeliverySequenceNumber; + /* Only data items merit a new sequence number; everything + else gets the sequence number of the next data item. */ + item->SetDataSequenceNumber(m_nextDataSequenceNumber); + + itemArray->SetNumberOfItems(1); + itemArray->GetItemArray()[0] = item; /* NOTE(review): plain assignment here versus Attach() in the loop above - presumably because GetItem() does not hand over a reference; confirm */ + LogAssert(bufferType == RChannelBuffer_Restart || + bufferType == RChannelBuffer_Abort); + nextParseAction = NPA_StopParsing; + } + + ProcessAfterParsing(nextParseAction, + itemArray, numberOfSubItemsRead, dataSizeRead, + prefetchCookie); +} + +// +// Stop creation of new buffers +// +void RChannelBufferQueue::InterruptSupplier() /* blocks until the queue has reached BQ_Stopping with no buffers held (asserted at the end of the function) */ +{ + bool waitForSend = false; + bool waitForLock = false; + + // + // Interrupt the reader associated with this queue + // call into other components with no locks held. 
When this + // returns we are sure no more calls to ProcessBuffer will come in + // + m_bufferReader->Interrupt(); + + WorkRequest* workRequest = NULL; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_shutDownRequested == false); /* a second interrupt while one is still pending is treated as a programming error */ + + if (m_state == BQ_BlockingItem) + { + // + // the easiest thing to do is to transition into the + // InWorkQueue state and continue interrupting using the + // normal path + // + workRequest = new RChannelParseRequest(this, false); + m_state = BQ_InWorkQueue; /* so the BQ_InWorkQueue test just below also fires for this case */ + } + + if (m_state == BQ_Locked || + m_state == BQ_InWorkQueue) + { + // + // If queue is locked or working, we're going to + // request a shutdown but wait for the lock + // + m_shutDownRequested = true; + waitForLock = true; + BOOL bRet = ::ResetEvent(m_shutDownEvent); /* set again by GetAndLockCurrentBuffer or ProcessAfterParsing once the buffers have been handed back */ + LogAssert(bRet != 0); + } + else + { + if (m_state == BQ_Empty) + { + m_state = BQ_Stopping; + } + + LogAssert(m_state == BQ_Stopping); + LogAssert(m_currentBuffer == NULL); + LogAssert(m_pendingList.IsEmpty()); + } + + // + // Interrupt the send latch. 
waitForSend = true if currently sending + // + waitForSend = m_sendLatch.Interrupt(); + } + + if (workRequest != NULL) + { + // + // If blocking, add request to queue to call this.ParseRequest, which will shut down the queue + // + bool bRet = m_workQueue->EnQueue(workRequest); + LogAssert(bRet == true); + } + + if (waitForSend) + { + // + // Wait on send latch to stop sending if necessary + // + m_sendLatch.Wait(); + } + + // + // Wait until any alerts have been delivered to the application + // + DWORD dRet = ::WaitForSingleObject(m_alertCompleteEvent, INFINITE); /* WaitForSingleObject returns a DWORD wait status, not a BOOL */ + LogAssert(dRet == WAIT_OBJECT_0); + + if (waitForLock) + { + // + // Wait for shut down if work queue is still processing + // + dRet = ::WaitForSingleObject(m_shutDownEvent, INFINITE); + LogAssert(dRet == WAIT_OBJECT_0); + } + + { + // + // Ensure that everything is shut down correctly + // + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_state == BQ_Stopping); + LogAssert(m_currentBuffer == NULL); + LogAssert(m_pendingList.IsEmpty()); + m_shutDownRequested = false; /* re-armed so a later InterruptSupplier passes its entry assertion */ + } +} + +void RChannelBufferQueue::DrainSupplier(RChannelItem* drainItem) /* interrupt, drain the reader, then move the queue from BQ_Stopping to BQ_Stopped; drainItem must be a termination item */ +{ + LogAssert(RChannelItem::IsTerminationItem(drainItem->GetType())); + + InterruptSupplier(); + + /* this won't return until all outstanding buffers have had their + completion handlers called. 
By the time this happens everything + must have drained out of the work queue and the parser */ + m_bufferReader->Drain(drainItem); + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_state == BQ_Stopping); + LogAssert(m_outstandingUnits == 0); /* every unit handed out must have been reported through NotifyUnitConsumption by now */ + LogAssert(m_currentBuffer == NULL); + LogAssert(m_pendingList.IsEmpty()); + m_state = BQ_Stopped; + m_sendLatch.Stop(); + } +} + +void RChannelBufferQueue::CloseSupplier() /* closes the reader, then drops the reader and parser pointers; asserts the queue is already in BQ_Stopped */ +{ + m_bufferReader->Close(); + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_state == BQ_Stopped); + m_bufferReader = NULL; + m_parser = NULL; + } +} + +bool RChannelBufferQueue::GetTotalLength(UInt64* pLen) +{ + return m_bufferReader->GetTotalLength(pLen); /* NOTE(review): m_bufferReader is NULL after CloseSupplier and is read here without the lock; callers presumably must not ask after close - confirm */ +} diff --git a/DryadVertex/VertexHost/system/channel/src/channelbufferqueue.h b/DryadVertex/VertexHost/system/channel/src/channelbufferqueue.h new file mode 100644 index 0000000..77904ad --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/src/channelbufferqueue.h @@ -0,0 +1,116 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +class RChannelReaderImpl; +class RChannelBuffer; +class RChannelBufferReader; +class RChannelBufferPrefetchInfo; +class RChannelUnit; +class RChannelParseRequest; + +#include "channelbuffer.h" +#include +#include +#include +#include + +typedef DryadBList ChannelUnitList; + +class RChannelBufferQueue : + public RChannelReaderSupplier, + public RChannelBufferReaderHandler +{ +public: + RChannelBufferQueue(RChannelReaderImpl* parent, + RChannelBufferReader* bufferReader, + RChannelItemParserBase* parser, + UInt32 maxParseBatchSize, + UInt32 maxOutstandingUnits, + WorkQueue* workQueue); + ~RChannelBufferQueue(); + + void StartSupplier(RChannelBufferPrefetchInfo* prefetchCookie); + void InterruptSupplier(); + void DrainSupplier(RChannelItem* drainItem); + + void ProcessBuffer(RChannelBuffer* buffer); + void NotifyUnitConsumption(); + + void CloseSupplier(); + + bool GetTotalLength(UInt64* pLen); + +private: + enum QueueState { + BQ_Stopped, + BQ_Empty, + BQ_BlockingItem, + BQ_InWorkQueue, + BQ_Locked, + BQ_Stopping + }; + + enum NextParseAction { + NPA_RequestBuffer, + NPA_RequestItem, + NPA_StopParsing, + NPA_EndOfStream + }; + + void DispatchEvents(ChannelBufferList* bufferList, + WorkRequest* workRequest); + void ShutDownBufferQueue(ChannelBufferList* bufferList); + RChannelBuffer* GetAndLockCurrentBuffer(bool* outReset); + void ProcessAfterParsing(NextParseAction nextParseAction, + RChannelItemArray* itemArray, + UInt64 numberOfSubItemsRead, + UInt64 dataSizeRead, + RChannelBufferPrefetchInfo* prefetchCookie); + + void ParseRequest(bool useNewBuffer); + bool ShutDownRequested(); + + RChannelReaderImpl* m_parent; + RChannelBufferReader* m_bufferReader; + RChannelItemParserRef m_parser; + UInt32 m_maxParseBatchSize; + UInt32 m_maxOutstandingUnits; + WorkQueue* m_workQueue; + + QueueState m_state; + DryadOrderedSendLatch m_sendLatch; + bool m_shutDownRequested; + HANDLE m_shutDownEvent; + HANDLE m_alertCompleteEvent; + + RChannelBuffer* 
m_currentBuffer; + bool m_firstBuffer; /* true from StartSupplier until the first buffer is locked for parsing; reported through GetAndLockCurrentBuffer's outReset */ + ChannelBufferList m_pendingList; /* buffers that arrived while another buffer was current */ + UInt32 m_outstandingUnits; /* item units handed out and not yet reported through NotifyUnitConsumption */ + + UInt64 m_nextDataSequenceNumber; + UInt64 m_nextDeliverySequenceNumber; + + CRITSEC m_baseDR; /* taken through AutoCriticalSection in the .cpp around state, buffer and counter updates */ + + friend class RChannelParseRequest; /* NOTE(review): presumably so the work request can call the private ParseRequest() - confirm */ +}; diff --git a/DryadVertex/VertexHost/system/channel/src/channelfifo.cpp b/DryadVertex/VertexHost/system/channel/src/channelfifo.cpp new file mode 100644 index 0000000..b03a110 --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/src/channelfifo.cpp @@ -0,0 +1,1421 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include +#include +#include + +#pragma unmanaged /* NOTE(review): the three #include lines above have lost their angle-bracket header names in this copy of the patch; restore them from the upstream repository */ + + +RChannelFifoWriterBase::SyncHandler::SyncHandler() +{ + m_event = NULL; + m_statusCode = RChannelItem_Data; /* placeholder until ProcessWriteArrayCompleted stores the real status */ + m_usingEvent = 0; +} + +RChannelFifoWriterBase::SyncHandler::~SyncHandler() +{ + LogAssert(m_event == NULL); /* any attached event must have been taken back with GetEvent() first */ +} + +void RChannelFifoWriterBase::SyncHandler::UseEvent(DryadHandleListEntry* event) /* attaches an event to wait on; asserts that it runs before the completion callback */ +{ + LONG postIncrement = ::InterlockedIncrement(&m_usingEvent); + LogAssert(postIncrement == 1); + LogAssert(m_event == NULL); + m_event = event; +} + +void RChannelFifoWriterBase::SyncHandler:: + ProcessWriteArrayCompleted(RChannelItemType statusCode, + RChannelItemArray* failureArray) +{ + LogAssert(failureArray == NULL); + m_statusCode = statusCode; + LONG postIncrement = ::InterlockedIncrement(&m_usingEvent); /* 2 means UseEvent ran first, so a waiter has to be woken; 1 means the write completed without an event being attached */ + if (postIncrement == 2) + { + LogAssert(m_event != NULL); + BOOL bRet = ::SetEvent(m_event->GetHandle()); + LogAssert(bRet != 0); + } + else + { + LogAssert(postIncrement == 1); + LogAssert(m_event == NULL); + } +} + +bool RChannelFifoWriterBase::SyncHandler::UsingEvent() +{ + return (m_event != NULL); +} + +RChannelItemType RChannelFifoWriterBase::SyncHandler::GetStatusCode() +{ + return m_statusCode; +} + +void RChannelFifoWriterBase::SyncHandler::Wait() /* blocks without timeout until the completion callback signals the attached event */ +{ + LogAssert(m_event != NULL); + DWORD dRet = ::WaitForSingleObject(m_event->GetHandle(), INFINITE); + LogAssert(dRet == WAIT_OBJECT_0); +} + +DryadHandleListEntry* RChannelFifoWriterBase::SyncHandler::GetEvent() /* returns the attached event entry to the caller and clears the member, which the destructor asserts */ +{ + LogAssert(m_event != NULL); + DryadHandleListEntry* retVal = m_event; + m_event = NULL; + return retVal; +} + +RChannelFifoWriterBase::RChannelFifoWriterBase(RChannelFifo* parent) /* both state machines start Closed; SetReader moves them to Stopped */ +{ + m_parent = parent; + m_reader = NULL; + m_writerState = WS_Closed; + m_supplierState = SS_Closed; + m_outstandingUnits = 0; + m_nextDataSequenceNumber = 0; + m_nextDeliverySequenceNumber = 0; + m_numberOfSubItemsWritten = 0; + m_dataSizeWritten = 0; + m_supplierDrainEvent = ::CreateEvent(NULL, TRUE, FALSE, NULL); + 
LogAssert(m_supplierDrainEvent != NULL); + m_supplierEpoch = 0; + m_writerEpoch = 0; +} + +RChannelFifoWriterBase::~RChannelFifoWriterBase() +{ + BOOL bRet = ::CloseHandle(m_supplierDrainEvent); + LogAssert(bRet != 0); +} + +void RChannelFifoWriterBase::SetURI(const char* uri) +{ + m_uri = uri; /* NOTE(review): whether this copies the string depends on m_uri's type, which is not visible here */ +} + +const char* RChannelFifoWriterBase::GetURI() +{ + return m_uri; +} + +UInt64 RChannelFifoWriterBase::GetInitialSizeHint() +{ + return 0; +} + +void RChannelFifoWriterBase::SetInitialSizeHint(UInt64 /*hint*/) /* deliberately a no-op: GetInitialSizeHint always returns 0 */ +{ +} + +void RChannelFifoWriterBase::SetReader(RChannelReaderImpl* reader) /* one-time wiring: asserts no reader is set yet and moves both state machines from Closed to Stopped */ +{ + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_reader == NULL); + LogAssert(reader != NULL); + m_reader = reader; + LogAssert(m_writerState == WS_Closed); + LogAssert(m_supplierState == SS_Closed); + LogAssert(m_readerTerminationItem == NULL); + LogAssert(m_writerTerminationItem == NULL); + m_writerState = WS_Stopped; + m_supplierState = SS_Stopped; + } +} + +RChannelFifo* RChannelFifoWriterBase::GetParent() +{ + return m_parent; +} + +void RChannelFifoWriterBase::ReturnHandlers(ChannelFifoUnitList* returnList) /* completes and deletes every unit in the list with no lock held, then collects further units from the return latch under the lock until none remain */ +{ + while (returnList->IsEmpty() == false) + { + RChannelFifoUnit* returnUnit = + returnList->CastOut(returnList->RemoveHead()); + while (returnUnit != NULL) + { + RChannelItemArrayWriterHandler* handler; + RChannelItemType statusCode; + returnUnit->Disgorge(&handler, &statusCode); + if (handler != NULL) + { + handler->ProcessWriteArrayCompleted(statusCode, NULL); + } + delete returnUnit; + returnUnit = + returnList->CastOut(returnList->RemoveHead()); /* NOTE(review): the inner loop relies on RemoveHead()/CastOut() yielding NULL for an empty list - confirm in DryadBList */ + } + + { + AutoCriticalSection acs(&m_baseDR); + + m_returnLatch.TransferList(returnList); + } + } +} + +void RChannelFifoWriterBase::SendUnits(ChannelUnitList* unitList) /* passes units to the reader with no lock held, then collects further units from the send latch under the lock until none remain */ +{ + while (unitList->IsEmpty() == false) + { + m_reader->AddUnitList(unitList); + + { + AutoCriticalSection acs(&m_baseDR); + + m_sendLatch.TransferList(unitList); + } + } +} + +RChannelFifoUnit* + 
RChannelFifoWriterBase::DuplicateUnit(RChannelFifoUnit* unit, + RChannelItemArray* itemArray) /* builds a new unit with unit's sizes and termination item but the supplied itemArray and no handler, then discards unit's own items */ +{ + RChannelFifoUnit* sendUnit = new RChannelFifoUnit(this); + sendUnit->SetSizes(unit->GetNumberOfSubItems(), + unit->GetDataSize()); + sendUnit->SetPayload(itemArray, NULL, RChannelItem_Data); + RChannelItem* t = unit->GetTerminationItem(); + if (t != NULL) + { + sendUnit->SetTerminationItem(t); + } + unit->DiscardItems(); + return sendUnit; +} + +void RChannelFifoWriterBase:: + WriteItemArray(RChannelItemArrayRef& itemArray, + bool flushAfter, + RChannelItemArrayWriterHandler* handler) +{ + EnqueueItemArray(itemArray, handler, NULL); /* flushAfter is not used; NULL means there is no SyncHandler to attach an event to */ +} + +RChannelItemType RChannelFifoWriterBase:: + WriteItemArraySync(RChannelItemArrayRef& itemArray, + bool flush, + RChannelItemArrayRef* pFailureArray) /* blocking write: returns the status code recorded by the completion callback; flush is not used and *pFailureArray is only ever set to NULL */ +{ + SyncHandler syncHandler; + + if (pFailureArray != NULL) + { + *pFailureArray = NULL; + } + + EnqueueItemArray(itemArray, &syncHandler, &syncHandler); /* the same object acts as completion handler and as sync handler */ + + if (syncHandler.UsingEvent()) /* an event is attached only when the write had to block */ + { + syncHandler.Wait(); + + { + AutoCriticalSection acs(&m_baseDR); + + m_eventCache.ReturnEvent(syncHandler.GetEvent()); + } + } + + return syncHandler.GetStatusCode(); +} + +UInt64 RChannelFifoWriterBase::GetDataSizeWritten() +{ + return m_dataSizeWritten; /* read without taking m_baseDR */ +} + +void RChannelFifoWriterBase::CloseSupplier() +{ + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_supplierState == SS_Stopped); + m_supplierState = SS_Closed; + } +} + +void RChannelFifoWriterBase::Close() +{ + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_writerState == WS_Stopped); + m_writerState = WS_Closed; + m_writerTerminationItem = NULL; + m_readerTerminationItem = NULL; + } +} + + +RChannelFifoWriter::RChannelFifoWriter(RChannelFifo* parent, + UInt32 fifoLength) : + RChannelFifoWriterBase(parent) +{ + m_fifoLength = fifoLength; + m_availableUnits = 0; /* capacity is granted only once writer and supplier are both started */ + m_terminationUnit = NULL; + m_writerDrainEvent = ::CreateEvent(NULL, TRUE, FALSE, NULL); + LogAssert(m_writerDrainEvent != NULL); +} + 
+RChannelFifoWriter::~RChannelFifoWriter() +{ + BOOL bRet = ::CloseHandle(m_writerDrainEvent); + LogAssert(bRet != 0); +} + +void RChannelFifoWriter::Start() /* opens the writer side; FIFO capacity is granted here only if the supplier side is already running */ +{ + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_writerState == WS_Stopped); + LogAssert(m_availableUnits == 0); + LogAssert(m_blockedList.IsEmpty()); + LogAssert(m_reader != NULL); + LogAssert(m_nextDataSequenceNumber == 0); + LogAssert(m_nextDeliverySequenceNumber == 0); + LogAssert(m_outstandingUnits == 0); + LogAssert(m_terminationUnit == NULL); + + m_sendLatch.Start(); + m_returnLatch.Start(); + + m_writerTerminationItem = NULL; + m_readerTerminationItem = NULL; + m_numberOfSubItemsWritten = 0; + m_dataSizeWritten = 0; + + LogAssert(m_supplierEpoch == m_writerEpoch); + + BOOL bRet = ::ResetEvent(m_writerDrainEvent); + LogAssert(bRet != 0); + + if (m_supplierState == SS_Running) + { + m_availableUnits = m_fifoLength; + } + m_writerState = WS_Running; + } +} + +void RChannelFifoWriter:: + StartSupplier(RChannelBufferPrefetchInfo* prefetchCookie) /* opens the reader side and releases as many blocked units as the FIFO has room for; prefetchCookie is not used in this function */ +{ + ChannelUnitList unblockedList; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_supplierState == SS_Stopped); + + LogAssert(m_writerState != WS_Closed); + LogAssert(m_availableUnits == 0); + LogAssert(m_outstandingUnits == 0); + + if (m_writerState == WS_Running) + { + m_availableUnits = m_fifoLength; + } + else + { + if (m_writerState == WS_Draining) + { + LogAssert(m_blockedList.IsEmpty() == false); + m_availableUnits = m_fifoLength; + } + else + { + LogAssert(m_writerState == WS_Stopped); + LogAssert(m_blockedList.IsEmpty()); + } + } + + BOOL bRet = ::ResetEvent(m_supplierDrainEvent); + LogAssert(bRet != 0); + + m_supplierState = SS_Running; + + while (m_availableUnits > 0 && m_blockedList.IsEmpty() == false) + { + --m_availableUnits; + ++m_outstandingUnits; + + RChannelFifoUnit* blockedUnit = + m_blockedList.CastOut(m_blockedList.GetHead()); /* NOTE(review): read with GetHead(), not RemoveHead(); TransitionToTail presumably unlinks the unit from m_blockedList, otherwise this loop would not advance - confirm in DryadBList */ + + unblockedList.TransitionToTail(unblockedList. 
+ CastIn(blockedUnit)); + } + + /* once a termination item has been written no further slots are handed out */ if (m_writerTerminationItem != NULL) + { + if (m_terminationUnit == NULL) + { + LogAssert(m_writerState == WS_Stopped); + LogAssert(m_writerEpoch == m_supplierEpoch); + } + m_availableUnits = 0; + } + + m_sendLatch.AcceptList(&unblockedList); + } + + SendUnits(&unblockedList); +} + +/* called with m_baseDR held */ +/* Returns true, after discarding the unit's items, if a termination item has already been written or the writer is stopped. Otherwise walks the unit's items: stamps delivery and data sequence numbers, records the first termination item (truncating the array just after it and remembering the unit as m_terminationUnit), adds to the written totals, stores the per-unit sizes, and returns false. */ bool RChannelFifoWriter::CheckForTerminationItem(RChannelFifoUnit* unit) +{ + if (m_writerTerminationItem != NULL || m_writerState == WS_Stopped) + { + LogAssert(m_availableUnits == 0); + unit->DiscardItems(); + return true; + } + else + { + /* we can't be in the draining state if we haven't seen a + termination item */ + LogAssert(m_writerState != WS_Draining); + + RChannelItemArray* itemArray = unit->GetItemArray(); + UInt32 nItems = itemArray->GetNumberOfItems(); + LogAssert(nItems > 0); + RChannelItemRef* items = itemArray->GetItemArray(); + + UInt64 unitSubItems = 0; + UInt64 unitDataSize = 0; + UInt32 i; + /* NOTE(review): loop header and the item/itemType declarations were restored after the text between '<' and '>' was lost from this copy; confirm against upstream. nItems is lowered to i+1 on a termination item, which ends the loop after that item. */ for (i=0; i<nItems; ++i) + { + RChannelItem* item = items[i]; + RChannelItemType itemType = item->GetType(); + if (RChannelItem::IsTerminationItem(itemType)) + { + m_writerTerminationItem = item; + LogAssert(m_terminationUnit == NULL); + m_terminationUnit = unit; + itemArray->TruncateToSize(i+1); + nItems = i+1; + unit->SetTerminationItem(item); + } + + item->SetDeliverySequenceNumber(m_nextDeliverySequenceNumber); + ++m_nextDeliverySequenceNumber; + /* Only data items merit a new sequence number; + everything else gets the sequence number of the + next data item. 
*/ + item->SetDataSequenceNumber(m_nextDataSequenceNumber); + if (itemType == RChannelItem_Data) + { + ++m_nextDataSequenceNumber; + } + m_numberOfSubItemsWritten += item->GetNumberOfSubItems(); + m_dataSizeWritten += item->GetItemSize(); + unitSubItems += item->GetNumberOfSubItems(); + unitDataSize += item->GetItemSize(); + } + + unit->SetSizes(unitSubItems, unitDataSize); + + return false; + } +} + +/* Appends unit to the tail of m_blockedList. When a SyncHandler is supplied it is first given an event taken from m_eventCache. */ void RChannelFifoWriter::StartBlocking(RChannelFifoUnit* unit, + SyncHandler* handler) +{ + if (handler != NULL) + { + /* if we have been called from WriteItemSync then we want to + use an event inside the handler. Since resetting an event + is expensive we don't want to do it on every call to + WriteItemSync, and here inside this lock is the only place + we can get it on demand. */ + handler->UseEvent(m_eventCache.GetEvent(true)); + } + + m_blockedList.InsertAsTail(m_blockedList.CastIn(unit)); +} + +/* Takes the array away from the caller's reference, wraps it in a new unit and, under m_baseDR, decides from the writer and supplier states whether the unit is sent on, parked on the blocked list, or returned straight away. SendUnits and ReturnHandlers are called after the lock has been released. */ void RChannelFifoWriter:: + EnqueueItemArray(RChannelItemArrayRef& itemArrayIn, + RChannelItemArrayWriterHandler* handler, + SyncHandler* syncHandler) +{ + ChannelUnitList sendList; + ChannelFifoUnitList returnList; + bool writerHasTerminated = false; + RChannelFifoUnit* unit = new RChannelFifoUnit(this); + + RChannelItemArrayRef itemArray; + /* ensure that the caller no longer holds a reference to this array */ + itemArray.TransferFrom(itemArrayIn); + + LogAssert(itemArray->GetNumberOfItems() > 0); + + { + AutoCriticalSection acs(&m_baseDR); + + /* MarshalError is a placeholder status; every path below must overwrite it (asserted before the unit is returned) */ unit->SetPayload(itemArray, handler, RChannelItem_MarshalError); + + LogAssert(m_writerState != WS_Closed); + + writerHasTerminated = CheckForTerminationItem(unit); + + if (writerHasTerminated) + { + /* some previous invocation (perhaps in another thread) + has already sent a termination item so signal that we + aren't accepting any more items by just returning the + handler immediately with this status code */ + unit->SetStatusCode(RChannelItem_EndOfStream); + } + else + { + LogAssert(m_writerState == WS_Running); + + if 
(m_supplierState == SS_Interrupted) + { + /* the writer has not terminated, but the reader is + about to request a drain. Queue up the handler to + be returned with the appropriate code once the + drain actually arrives. */ + LogAssert(m_readerTerminationItem == NULL); + LogAssert(m_availableUnits == 0); + LogAssert(unit->GetItemArray()->GetNumberOfItems() > 0); + + StartBlocking(unit, syncHandler); + /* NULL marks the unit as parked: it is not returned at the end of this call */ unit = NULL; + } + else if (m_supplierState == SS_Draining) + { + /* the writer has not terminated, but the reader has + requested a drain. */ + LogAssert(m_readerTerminationItem != NULL); + RChannelItemType drainType = + m_readerTerminationItem->GetType(); + LogAssert(drainType != RChannelItem_MarshalError); + LogAssert(m_availableUnits == 0); + + if (m_outstandingUnits > 0) + { + /* queue up this item to be returned after the + in-flight items have come back. At that point + the extra item reference will be removed, so + leave it there for now. */ + LogAssert(unit->GetItemArray()->GetNumberOfItems() > 0); + StartBlocking(unit, syncHandler); + unit = NULL; + } + else + { + /* if there are no in-flight items we should have + returned the handlers from any previous items + already so we can return this handler + immediately */ + LogAssert(m_blockedList.IsEmpty()); + unit->SetStatusCode(drainType); + unit->DiscardItems(); + + /* this unit will never reach the reader, so it must not stay registered as the termination unit */ if (m_terminationUnit == unit) + { + m_terminationUnit = NULL; + } + } + } + else if (m_supplierState == SS_Stopped) + { + LogAssert(m_availableUnits == 0); + + /* equal epochs: supplier not started yet; supplier one ahead: it has already drained */ if (m_supplierEpoch == m_writerEpoch) + { + /* the supplier hasn't yet been started so queue + up the unit for when it is */ + unit->SetStatusCode(RChannelItem_Data); + LogAssert(unit->GetItemArray()->GetNumberOfItems() > 0); + + StartBlocking(unit, syncHandler); + unit = NULL; + } + else + { + /* the supplier has already shut down so just + return this immediately */ + LogAssert(m_supplierEpoch == m_writerEpoch+1); + LogAssert(m_readerTerminationItem != NULL); + RChannelItemType drainType 
= + m_readerTerminationItem->GetType(); + LogAssert(drainType != RChannelItem_MarshalError); + LogAssert(m_blockedList.IsEmpty()); + unit->DiscardItems(); + unit->SetStatusCode(drainType); + + if (m_terminationUnit == unit) + { + m_terminationUnit = NULL; + } + } + } + else + { + /* remaining case: the supplier is running */ LogAssert(m_supplierState == SS_Running); + + LogAssert(unit->GetItemArray()->GetNumberOfItems() > 0); + unit->SetStatusCode(RChannelItem_Data); + + if (m_availableUnits == 0) + { + /* we have no spare slots in the FIFO so add the + item to the blocked list: this has the + side-effect of preventing the handler being + returned immediately. */ + StartBlocking(unit, syncHandler); + unit = NULL; + } + else + { + LogAssert(m_blockedList.IsEmpty()); + + /* we're ready to send this on to the reader: use + the send latch to ensure the ordering is + preserved. We will return the handler + immediately since the FIFO was not full and it + is OK to send more items. */ + + /* take one FIFO slot for the unit about to be sent */ --m_availableUnits; + + RChannelFifoUnit* sendUnit; + if (unit == m_terminationUnit) + { + /* always block on the termination item to + ensure its handler gets delivered last */ + StartBlocking(unit, syncHandler); + unit = NULL; + + /* cheat by stealing this back off the end of + the blocking queue and send it + immediately */ + sendUnit = m_blockedList.CastOut(m_blockedList. 
+ RemoveTail()); + + /* throw away any spare FIFO units since we + won't be writing any more items */ + m_availableUnits = 0; + } + else + { + /* return the handler to the user immediately + while also forwarding the unit to the + reader */ + sendUnit = DuplicateUnit(unit, itemArray); + } + + sendList.InsertAsTail(sendList.CastIn(sendUnit)); + m_sendLatch.AcceptList(&sendList); + + ++m_outstandingUnits; + } + } + } + + /* unit is still non-NULL only if it was not parked or sent itself, so its handler goes back to the caller now */ if (unit != NULL) + { + LogAssert(unit->GetStatusCode() != RChannelItem_MarshalError); + returnList.InsertAsTail(returnList.CastIn(unit)); + m_returnLatch.AcceptList(&returnList); + if (returnList.IsEmpty() && syncHandler != NULL) + { + /* another thread is returning the handler for us, so + we need to wait for its event */ + syncHandler->UseEvent(m_eventCache.GetEvent(true)); + } + } + } + + if (writerHasTerminated) + { + LogAssert(sendList.IsEmpty()); + LogAssert(unit != NULL); + } + + SendUnits(&sendList); + + ReturnHandlers(&returnList); +} + +/* called with m_baseDR held */ +/* Returns true when the termination unit was on the blocked list while the writer is in WS_Draining; callers in this file use that to signal m_writerDrainEvent. The blocked list itself is left intact for the caller to move. */ bool RChannelFifoWriter::ReWriteBlockedListForEarlyReturn(RChannelItemType + drainType) +{ + /* rewrite any blocked items to have the termination code correct + before sending them back to the writer after a reader + Drain. Also, discard the items, since the reader is never going + to see them. 
*/ + bool wakeUpWriterDrain = false; + DrBListEntry* listEntry = m_blockedList.GetHead(); + while (listEntry != NULL) + { + RChannelFifoUnit* unit = m_blockedList.CastOut(listEntry); + /* advance first; the assert below relies on the termination unit being the last entry */ listEntry = m_blockedList.GetNext(listEntry); + unit->SetStatusCode(drainType); + unit->DiscardItems(); + if (unit == m_terminationUnit) + { + LogAssert(listEntry == NULL); + m_terminationUnit = NULL; + if (m_writerState == WS_Draining) + { + wakeUpWriterDrain = true; + } + } + } + + return wakeUpWriterDrain; +} + +/* Called for a unit coming back from the reader (RChannelFifoUnit::ReturnToSupplier calls this on its parent). Decrements the outstanding count and queues the unit on the return list. Then either completes a pending supplier drain, frees a slot, or forwards the next blocked unit. Events are signalled and lists processed after m_baseDR is released. */ void RChannelFifoWriter::AcceptReturningUnit(RChannelFifoUnit* unit) +{ + ChannelUnitList unblockedList; + ChannelFifoUnitList returnList; + bool wakeUpWriterDrain = false; + bool wakeUpSupplierDrain = false; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_outstandingUnits > 0); + --m_outstandingUnits; + + if (unit == m_terminationUnit) + { + LogAssert(m_outstandingUnits == 0); + LogAssert(m_blockedList.IsEmpty()); + m_terminationUnit = NULL; + + if (m_writerState == WS_Draining) + { + wakeUpWriterDrain = true; + } + } + + returnList.InsertAsTail(returnList.CastIn(unit)); + unit = NULL; + + if (m_supplierState == SS_Draining) + { + /* the supplier has requested a drain */ + if (m_outstandingUnits == 0) + { + /* we have been building up written items in the + blocked list while waiting for the outstanding + units to be returned. We can send them all back + now. 
*/ + LogAssert(m_readerTerminationItem != NULL); + RChannelItemType drainType = + m_readerTerminationItem->GetType(); + wakeUpWriterDrain = + ReWriteBlockedListForEarlyReturn(drainType) || + wakeUpWriterDrain; + returnList.TransitionToTail(&m_blockedList); + /* supplier drain is complete: stop it and advance its epoch */ m_supplierState = SS_Stopped; + ++m_supplierEpoch; + wakeUpSupplierDrain = true; + } + + LogAssert(m_availableUnits == 0); + } + /* while interrupted nothing is forwarded and no slot is freed */ else if (m_supplierState != SS_Interrupted) + { + LogAssert(m_supplierState == SS_Running); + + if (m_blockedList.IsEmpty()) + { + /* no slot is given back once a termination item has been written */ if (m_writerTerminationItem == NULL) + { + /* there's nothing blocked, so just add a unit to + the available list and return without doing + anything else */ + ++m_availableUnits; + } + } + else + { + /* there was a blocked request waiting to be sent to + the FIFO: put it in now that we have a space + available. */ + LogAssert(m_supplierState == SS_Running); + + RChannelFifoUnit* blockedUnit = + m_blockedList.CastOut(m_blockedList.GetHead()); + + unblockedList.TransitionToTail(unblockedList. 
+ CastIn(blockedUnit)); + ++m_outstandingUnits; + + m_sendLatch.AcceptList(&unblockedList); + } + } + + m_returnLatch.AcceptList(&returnList); + } + + /* when the writer drain is woken nothing may be pending to send (asserted) */ if (wakeUpWriterDrain) + { + LogAssert(unblockedList.IsEmpty()); + BOOL bRet = ::SetEvent(m_writerDrainEvent); + LogAssert(bRet != 0); + } + else + { + SendUnits(&unblockedList); + } + + ReturnHandlers(&returnList); + + if (wakeUpSupplierDrain) + { + BOOL bRet = ::SetEvent(m_supplierDrainEvent); + LogAssert(bRet != 0); + } +} + +/* Moves the supplier from SS_Running (asserted) to SS_Interrupted, drops the unused slots and interrupts the send latch. If the latch asks for it, waits outside the lock for in-progress sends to finish. */ void RChannelFifoWriter::InterruptSupplier() +{ + bool waitForLatch; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_supplierState == SS_Running); + + /* remove any spare FIFO slots since we aren't going to send + anything more to the supplier */ + m_availableUnits = 0; + + m_supplierState = SS_Interrupted; + + waitForLatch = m_sendLatch.Interrupt(); + } + + if (waitForLatch) + { + /* make sure all sends which were proceeding outside a lock + have completed */ + m_sendLatch.Wait(); + } +} + +/* Must follow InterruptSupplier (SS_Interrupted is asserted). Records drainItem as the reader termination item. With units outstanding it enters SS_Draining and blocks until m_supplierDrainEvent is signalled. Otherwise it rewrites and returns every blocked unit with the drain status, stops the supplier, advances its epoch and signals the event itself. */ void RChannelFifoWriter::DrainSupplier(RChannelItem* drainItem) +{ + ChannelFifoUnitList returnList; + + bool wakeUpWriterDrain = false; + bool wakeUpSupplierDrain = false; + bool waitForSupplierDrain = false; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_supplierState == SS_Interrupted); + LogAssert(m_availableUnits == 0); + + LogAssert(drainItem != NULL); + RChannelItemType drainType = drainItem->GetType(); + + m_readerTerminationItem = drainItem; + + if (m_outstandingUnits > 0) + { + m_supplierState = SS_Draining; + waitForSupplierDrain = true; + } + else + { + wakeUpWriterDrain = ReWriteBlockedListForEarlyReturn(drainType); + returnList.TransitionToTail(&m_blockedList); + m_returnLatch.AcceptList(&returnList); + m_supplierState = SS_Stopped; + ++m_supplierEpoch; + wakeUpSupplierDrain = true; + } + } + + if (wakeUpWriterDrain) + { + BOOL bRet = ::SetEvent(m_writerDrainEvent); + LogAssert(bRet != 0); + } + + if (wakeUpSupplierDrain) + { + LogAssert(waitForSupplierDrain == 
false); + BOOL bRet = ::SetEvent(m_supplierDrainEvent); + LogAssert(bRet != 0); + } + else if (waitForSupplierDrain) + { + DWORD dRet = ::WaitForSingleObject(m_supplierDrainEvent, INFINITE); + LogAssert(dRet == WAIT_OBJECT_0); + } + + ReturnHandlers(&returnList); +} + +/* Writer-side drain; a termination item must already have been written (asserted). If units are outstanding or blocked it enters WS_Draining and waits, with no timeout, on m_writerDrainEvent. It then stops the writer, advances the writer epoch and zeroes both sequence numbers. Finally it waits on m_supplierDrainEvent for whatever remains of csTimeOut; a timeout there is accepted. pReturnItem, when non-NULL, receives m_readerTerminationItem. */ void RChannelFifoWriter::Drain(DrTimeInterval csTimeOut, + RChannelItemRef* pReturnItem) +{ + DrTimeStamp startTime = DrGetCurrentTimeStamp(); + + bool waitForWriterDrain = false; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_writerState == WS_Running); + LogAssert(m_writerTerminationItem != NULL); + LogAssert(m_availableUnits == 0); + + if (m_supplierState == SS_Running && m_outstandingUnits == 0) + { + LogAssert(m_blockedList.IsEmpty()); + } + + if (m_outstandingUnits > 0 || m_blockedList.IsEmpty() == false) + { + LogAssert(m_terminationUnit != NULL); + m_writerState = WS_Draining; + waitForWriterDrain = true; + } + } + + if (waitForWriterDrain) + { + DWORD dRet = ::WaitForSingleObject(m_writerDrainEvent, INFINITE); + LogAssert(dRet == WAIT_OBJECT_0); + } + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_outstandingUnits == 0); + LogAssert(m_blockedList.IsEmpty()); + LogAssert(m_availableUnits == 0); + LogAssert(m_reader != NULL); + LogAssert(m_outstandingUnits == 0); + LogAssert(m_terminationUnit == NULL); + + m_writerState = WS_Stopped; + ++m_writerEpoch; + m_nextDeliverySequenceNumber = 0; + m_nextDataSequenceNumber = 0; + } + + /* only the part of csTimeOut not already spent above is used for the supplier wait; if none is left the wait is skipped */ DrTimeStamp currentTime = DrGetCurrentTimeStamp(); + DrTimeInterval elapsed = DrGetElapsedTime(startTime, currentTime); + if (elapsed < csTimeOut) + { + DWORD timeOut = DrGetTimerMsFromInterval(csTimeOut - elapsed); + DWORD dRet = ::WaitForSingleObject(m_supplierDrainEvent, timeOut); + LogAssert(dRet == WAIT_TIMEOUT || dRet == WAIT_OBJECT_0); + } + + if (pReturnItem != NULL) + { + AutoCriticalSection acs(&m_baseDR); + + *pReturnItem = m_readerTerminationItem; + } +} + +/* Copies both termination items out under m_baseDR; both output pointers are dereferenced unconditionally. */ void RChannelFifoWriter::GetTerminationItems(RChannelItemRef* pWriterDrainItem, 
+ RChannelItemRef* pReaderDrainItem) +{ + { + AutoCriticalSection acs(&m_baseDR); + + *pWriterDrainItem = m_writerTerminationItem; + *pReaderDrainItem = m_readerTerminationItem; + } +} + + +/* Writer variant without slot accounting; it adds no state of its own beyond the base class. */ RChannelFifoNBWriter::RChannelFifoNBWriter(RChannelFifo* parent) : + RChannelFifoWriterBase(parent) +{ +} + +/* Starts the writer side: asserts a quiescent WS_Stopped state, starts both latches, clears the termination items and written totals, and enters WS_Running. */ void RChannelFifoNBWriter::Start() +{ + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_writerState == WS_Stopped); + LogAssert(m_outstandingUnits == 0); + LogAssert(m_blockedList.IsEmpty()); + LogAssert(m_reader != NULL); + LogAssert(m_nextDataSequenceNumber == 0); + LogAssert(m_nextDeliverySequenceNumber == 0); + + m_sendLatch.Start(); + m_returnLatch.Start(); + + m_writerTerminationItem = NULL; + m_readerTerminationItem = NULL; + m_numberOfSubItemsWritten = 0; + m_dataSizeWritten = 0; + + LogAssert(m_supplierEpoch == m_writerEpoch); + + m_writerState = WS_Running; + } +} + +/* Starts the supplier side: resets m_supplierDrainEvent, enters SS_Running and moves every unit queued on the blocked list onto a local list, which is sent after the lock is released. prefetchCookie is not used. */ void RChannelFifoNBWriter:: + StartSupplier(RChannelBufferPrefetchInfo* prefetchCookie) +{ + ChannelUnitList unblockedList; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_supplierState == SS_Stopped); + + LogAssert(m_writerState != WS_Closed); + LogAssert(m_outstandingUnits == 0); + + BOOL bRet = ::ResetEvent(m_supplierDrainEvent); + LogAssert(bRet != 0); + + m_supplierState = SS_Running; + + /* there is no slot limit here, so the whole backlog is released at once */ while (m_blockedList.IsEmpty() == false) + { + ++m_outstandingUnits; + + RChannelFifoUnit* blockedUnit = + m_blockedList.CastOut(m_blockedList.GetHead()); + + unblockedList.TransitionToTail(unblockedList. 
+ CastIn(blockedUnit)); + } + + m_sendLatch.AcceptList(&unblockedList); + } + + SendUnits(&unblockedList); +} + +/* called with m_baseDR held */ +/* Returns true, after discarding the unit's items, if a termination item has already been written or the writer is stopped. Otherwise walks the unit's items: stamps delivery and data sequence numbers, records the first termination item (truncating the array just after it), adds to the written totals, stores the per-unit sizes, and returns false. */ bool RChannelFifoNBWriter::CheckForTerminationItem(RChannelFifoUnit* unit) +{ + if (m_writerTerminationItem != NULL || m_writerState == WS_Stopped) + { + unit->DiscardItems(); + return true; + } + else + { + RChannelItemArray* itemArray = unit->GetItemArray(); + UInt32 nItems = itemArray->GetNumberOfItems(); + LogAssert(nItems > 0); + RChannelItemRef* items = itemArray->GetItemArray(); + + UInt64 unitSubItems = 0; + UInt64 unitDataSize = 0; + UInt32 i; + /* NOTE(review): loop header and the item/itemType declarations were restored after the text between '<' and '>' was lost from this copy; confirm against upstream. nItems is lowered to i+1 on a termination item, which ends the loop after that item. */ for (i=0; i<nItems; ++i) + { + RChannelItem* item = items[i]; + RChannelItemType itemType = item->GetType(); + if (RChannelItem::IsTerminationItem(itemType)) + { + m_writerTerminationItem = item; + itemArray->TruncateToSize(i+1); + nItems = i+1; + unit->SetTerminationItem(item); + } + + item->SetDeliverySequenceNumber(m_nextDeliverySequenceNumber); + ++m_nextDeliverySequenceNumber; + /* Only data items merit a new sequence number; everything + else gets the sequence number of the next data item. 
*/ + item->SetDataSequenceNumber(m_nextDataSequenceNumber); + if (itemType == RChannelItem_Data) + { + ++m_nextDataSequenceNumber; + } + m_numberOfSubItemsWritten += item->GetNumberOfSubItems(); + m_dataSizeWritten += item->GetItemSize(); + unitSubItems += item->GetNumberOfSubItems(); + unitDataSize += item->GetItemSize(); + } + + unit->SetSizes(unitSubItems, unitDataSize); + + return false; + } +} + +/* Enqueue for the non-blocking writer. The unit holding the caller's handler is always placed on the return list by this call. The data is either forwarded to the reader, queued until the supplier starts, or discarded, depending on the supplier state. SendUnits and ReturnHandlers are called after m_baseDR is released. */ void RChannelFifoNBWriter:: + EnqueueItemArray(RChannelItemArrayRef& itemArrayIn, + RChannelItemArrayWriterHandler* handler, + SyncHandler* syncHandler) +{ + ChannelUnitList sendList; + ChannelFifoUnitList returnList; + bool writerHasTerminated = false; + RChannelFifoUnit* unit = new RChannelFifoUnit(this); + + RChannelItemArrayRef itemArray; + /* ensure that the caller no longer holds a reference to this array */ + itemArray.TransferFrom(itemArrayIn); + + LogAssert(itemArray->GetNumberOfItems() > 0); + + { + AutoCriticalSection acs(&m_baseDR); + + /* MarshalError is a placeholder status; every path below must overwrite it (asserted before the unit is returned) */ unit->SetPayload(itemArray, handler, RChannelItem_MarshalError); + + LogAssert(m_writerState != WS_Closed); + + writerHasTerminated = CheckForTerminationItem(unit); + + if (writerHasTerminated) + { + /* some previous invocation (perhaps in another thread) + has already sent a termination item so signal that we + aren't accepting any more items by just returning the + handler immediately with this status code */ + unit->SetStatusCode(RChannelItem_EndOfStream); + } + else + { + LogAssert(m_writerState == WS_Running); + + if (m_supplierState == SS_Stopped) + { + if (m_supplierEpoch == m_writerEpoch) + { + /* the supplier hasn't started yet so return the + handler to the user and queue up the data */ + unit->SetStatusCode(RChannelItem_Data); + + RChannelFifoUnit* blockedUnit = + DuplicateUnit(unit, itemArray); + m_blockedList. 
+ InsertAsTail(m_blockedList.CastIn(blockedUnit)); + } + else + { + /* the supplier has already shut down so just + return this immediately */ + LogAssert(m_supplierEpoch == m_writerEpoch+1); + LogAssert(m_readerTerminationItem != NULL); + RChannelItemType drainType = + m_readerTerminationItem->GetType(); + LogAssert(drainType != RChannelItem_MarshalError); + LogAssert(m_blockedList.IsEmpty()); + unit->DiscardItems(); + unit->SetStatusCode(drainType); + } + } + else if (m_supplierState == SS_Draining) + { + /* the supplier has sent a drain item, so return the + handler to the user with the appropriate code and + throw the data away */ + LogAssert(m_supplierEpoch == m_writerEpoch); + LogAssert(m_readerTerminationItem != NULL); + RChannelItemType drainType = + m_readerTerminationItem->GetType(); + LogAssert(drainType != RChannelItem_MarshalError); + LogAssert(m_blockedList.IsEmpty()); + unit->DiscardItems(); + unit->SetStatusCode(drainType); + } + else if (m_supplierState == SS_Interrupted) + { + /* the supplier wants to drain but hasn't yet, so just + return the handler to the user and throw the data + away */ + LogAssert(m_supplierEpoch == m_writerEpoch); + LogAssert(m_readerTerminationItem == NULL); + LogAssert(m_blockedList.IsEmpty()); + unit->DiscardItems(); + /* the write is reported as RChannelItem_Data even though the items were discarded */ unit->SetStatusCode(RChannelItem_Data); + } + else + { + LogAssert(m_supplierState == SS_Running); + + unit->SetStatusCode(RChannelItem_Data); + + LogAssert(m_blockedList.IsEmpty()); + + /* return the handler to the user immediately + while also forwarding the unit to the + reader */ + RChannelFifoUnit* sendUnit = + DuplicateUnit(unit, itemArray); + sendList.InsertAsTail(sendList.CastIn(sendUnit)); + m_sendLatch.AcceptList(&sendList); + + /* decremented again in AcceptReturningUnit when the unit comes back */ ++m_outstandingUnits; + } + } + + LogAssert(unit->GetStatusCode() != RChannelItem_MarshalError); + returnList.InsertAsTail(returnList.CastIn(unit)); + m_returnLatch.AcceptList(&returnList); + if (returnList.IsEmpty() && syncHandler != NULL) + { + /* another thread is 
returning the handler for us, so + we need to wait for its event */ + syncHandler->UseEvent(m_eventCache.GetEvent(true)); + } + } + + SendUnits(&sendList); + + ReturnHandlers(&returnList); +} + +/* Called for a unit coming back from the reader. Decrements the outstanding count, completes a pending supplier drain when the count reaches zero, and deletes the unit. Every unit this writer queues or sends is created by DuplicateUnit, which attaches no handler, so no handler is returned here. */ void RChannelFifoNBWriter::AcceptReturningUnit(RChannelFifoUnit* unit) +{ + bool wakeUpSupplierDrain = false; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_outstandingUnits > 0); + --m_outstandingUnits; + + if (m_supplierState == SS_Draining) + { + /* the supplier has requested a drain */ + if (m_outstandingUnits == 0) + { + m_supplierState = SS_Stopped; + ++m_supplierEpoch; + wakeUpSupplierDrain = true; + } + } + } + + if (wakeUpSupplierDrain) + { + BOOL bRet = ::SetEvent(m_supplierDrainEvent); + LogAssert(bRet != 0); + } + + delete unit; +} + +/* Moves the supplier from SS_Running (asserted) to SS_Interrupted and interrupts the send latch. If the latch asks for it, waits outside the lock for in-progress sends to finish. */ void RChannelFifoNBWriter::InterruptSupplier() +{ + bool waitForLatch; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_supplierState == SS_Running); + + m_supplierState = SS_Interrupted; + + waitForLatch = m_sendLatch.Interrupt(); + } + + if (waitForLatch) + { + /* make sure all sends which were proceeding outside a lock + have completed */ + m_sendLatch.Wait(); + } +} + +/* Must follow InterruptSupplier (SS_Interrupted is asserted). Records drainItem as the reader termination item. With units outstanding it enters SS_Draining and blocks until m_supplierDrainEvent is signalled; otherwise it stops the supplier, advances its epoch and signals the event itself. Nothing is ever added to returnList in this function, so the closing ReturnHandlers call is given an empty list. */ void RChannelFifoNBWriter::DrainSupplier(RChannelItem* drainItem) +{ + ChannelFifoUnitList returnList; + + bool wakeUpSupplierDrain = false; + bool waitForSupplierDrain = false; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_supplierState == SS_Interrupted); + + LogAssert(drainItem != NULL); + + m_readerTerminationItem = drainItem; + + if (m_outstandingUnits > 0) + { + m_supplierState = SS_Draining; + waitForSupplierDrain = true; + } + else + { + m_supplierState = SS_Stopped; + ++m_supplierEpoch; + wakeUpSupplierDrain = true; + } + } + + if (wakeUpSupplierDrain) + { + LogAssert(waitForSupplierDrain == false); + BOOL bRet = ::SetEvent(m_supplierDrainEvent); + LogAssert(bRet != 0); + } + else if (waitForSupplierDrain) + { + DWORD dRet = ::WaitForSingleObject(m_supplierDrainEvent, INFINITE); + LogAssert(dRet == WAIT_OBJECT_0); + 
} + + ReturnHandlers(&returnList); +} + +/* Writer-side drain; a termination item must already have been written (asserted). Stops the writer immediately, advances the writer epoch and zeroes both sequence numbers. csTimeOut is not used and nothing is waited on. pReturnItem, when non-NULL, receives m_readerTerminationItem. */ void RChannelFifoNBWriter::Drain(DrTimeInterval csTimeOut, + RChannelItemRef* pReturnItem) +{ + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_writerState == WS_Running); + LogAssert(m_writerTerminationItem != NULL); + + m_writerState = WS_Stopped; + ++m_writerEpoch; + m_nextDeliverySequenceNumber = 0; + m_nextDataSequenceNumber = 0; + + if (pReturnItem != NULL) + { + *pReturnItem = m_readerTerminationItem; + } + } +} + +/* Copies both termination items out under m_baseDR; both output pointers are dereferenced unconditionally. */ void RChannelFifoNBWriter:: + GetTerminationItems(RChannelItemRef* pWriterDrainItem, + RChannelItemRef* pReaderDrainItem) +{ + { + AutoCriticalSection acs(&m_baseDR); + + *pWriterDrainItem = m_writerTerminationItem; + *pReaderDrainItem = m_readerTerminationItem; + } +} + + +/* Stores parent and passes the parent's writer and workQueue to Initialize; the parent's writer must therefore exist before the reader is constructed. */ RChannelFifoReader::RChannelFifoReader(RChannelFifo* parent, + WorkQueue* workQueue) +{ + m_parent = parent; + Initialize(m_parent->GetWriter(), workQueue, false); +} + +RChannelFifoReader::~RChannelFifoReader() +{ +} + +RChannelFifo* RChannelFifoReader::GetParent() +{ + return m_parent; +} + +/* Always sets *pLen to 0 and returns false. */ bool RChannelFifoReader::GetTotalLength(UInt64* pLen) +{ + *pLen = 0; + return false; +} + +/* Always sets *pLen to 0 and returns false. */ bool RChannelFifoReader::GetExpectedLength(UInt64* pLen) +{ + *pLen = 0; + return false; +} + +/* Intentionally does nothing; the value is ignored. */ void RChannelFifoReader::SetExpectedLength(UInt64 expectedLength) +{ +} + +/* Creates the writer first (RChannelFifoNBWriter when fifoLength == s_infiniteBuffer, otherwise RChannelFifoWriter with that length), then the reader, sets name as the URI of both, and finally gives the writer its reader. */ RChannelFifo::RChannelFifo(const char* name, + UInt32 fifoLength, WorkQueue* workQueue) +{ + m_name = name; + if (fifoLength == s_infiniteBuffer) + { + m_writer = new RChannelFifoNBWriter(this); + } + else + { + m_writer = new RChannelFifoWriter(this, fifoLength); + } + m_writer->SetURI(name); + + m_reader = new RChannelFifoReader(this, workQueue); + m_reader->SetURI(name); + + m_writer->SetReader(m_reader); +} + +/* Logs the fifo name, then deletes the reader followed by the writer. */ RChannelFifo::~RChannelFifo() +{ + DrLogD( "Deleting fifo. 
Name %s", m_name.GetString()); + delete m_reader; + delete m_writer; +} + +/* Returns the stored name (m_name converted to const char*). */ const char* RChannelFifo::GetName() +{ + return m_name; +} + +RChannelFifoReader* RChannelFifo::GetReader() +{ + return m_reader; +} + +RChannelFifoWriterBase* RChannelFifo::GetWriter() +{ + return m_writer; +} diff --git a/DryadVertex/VertexHost/system/channel/src/channelfifo.h b/DryadVertex/VertexHost/system/channel/src/channelfifo.h new file mode 100644 index 0000000..f52cf08 --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/src/channelfifo.h @@ -0,0 +1,241 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +/* NOTE(review): the three #include directives below have no target, and the typedef further down names DryadBListDerived without arguments. Text between '<' and '>' appears to have been lost from this copy; restore it from the upstream header. */ #include +#include +#include + +class RChannelFifo; +class RChannelFifoUnit; +typedef DryadBListDerived ChannelFifoUnitList; + +/* Common base of the two FIFO writers. It implements part of both RChannelWriter and RChannelReaderSupplier and holds the shared state, all guarded by m_baseDR in the implementation. */ class RChannelFifoWriterBase : + public RChannelWriter, + public RChannelReaderSupplier +{ +public: + RChannelFifo* GetParent(); + + /* partial RChannelWriter interface */ + void WriteItemArray(RChannelItemArrayRef& itemArray, + bool flushAfter, + RChannelItemArrayWriterHandler* handler); + RChannelItemType WriteItemArraySync(RChannelItemArrayRef& itemArray, + bool flush, + RChannelItemArrayRef* failureArray); + void Close(); + + /* partial RChannelReaderSupplier interface */ + void CloseSupplier(); + + UInt64 GetDataSizeWritten(); + + /* Get/Set the URI of the channel. 
*/ + const char* GetURI(); + void SetURI(const char* uri); + + /* these are just dummies for a fifo: the hint is always returned + as zero */ + UInt64 GetInitialSizeHint(); + void SetInitialSizeHint(UInt64 hint); + +protected: + /* writer-side lifecycle; the implementation moves Stopped -> Running -> (Draining) -> Stopped -> Closed */ enum WriterState { + WS_Closed, + WS_Running, + WS_Draining, + WS_Stopped + }; + + /* supplier-side lifecycle; the implementation moves Stopped -> Running -> Interrupted -> (Draining) -> Stopped -> Closed */ enum SupplierState { + SS_Closed, + SS_Running, + SS_Interrupted, + SS_Draining, + SS_Stopped + }; + + /* Completion handler used by WriteItemArraySync. It can be given an event to wait on and records the status code of the write. */ class SyncHandler : public RChannelItemArrayWriterHandler + { + public: + SyncHandler(); + ~SyncHandler(); + + void ProcessWriteArrayCompleted(RChannelItemType status, + RChannelItemArray* failureArray); + + void UseEvent(DryadHandleListEntry* event); + bool UsingEvent(); + RChannelItemType GetStatusCode(); + void Wait(); + DryadHandleListEntry* GetEvent(); + + private: + DryadHandleListEntry* m_event; + LONG volatile m_usingEvent; + RChannelItemType m_statusCode; + }; + + RChannelFifoWriterBase(RChannelFifo* parent); + virtual ~RChannelFifoWriterBase(); + + /* called when a unit sent to the reader is handed back */ virtual void AcceptReturningUnit(RChannelFifoUnit* unit) = 0; + /* SyncHandler is NULL for asynchronous writes */ virtual void EnqueueItemArray(RChannelItemArrayRef& itemArray, + RChannelItemArrayWriterHandler* handler, + SyncHandler* SyncHandler) = 0; + + void SetReader(RChannelReaderImpl* reader); + + void ReturnHandlers(ChannelFifoUnitList* returnList); + void SendUnits(ChannelUnitList* unitList); + /* makes a handler-less copy of unit carrying itemArray, and discards unit's own items */ RChannelFifoUnit* DuplicateUnit(RChannelFifoUnit* unit, + RChannelItemArray* itemArray); + + RChannelFifo* m_parent; + WriterState m_writerState; + SupplierState m_supplierState; + /* units waiting to be sent to the reader */ ChannelFifoUnitList m_blockedList; + DryadOrderedSendLatch m_sendLatch; + DryadOrderedSendLatch m_returnLatch; + RChannelReaderImpl* m_reader; + UInt64 m_nextDataSequenceNumber; + UInt64 m_nextDeliverySequenceNumber; + UInt64 m_numberOfSubItemsWritten; + UInt64 m_dataSizeWritten; + /* units sent to the reader and not yet returned */ UInt32 m_outstandingUnits; + RChannelItemRef m_writerTerminationItem; + RChannelItemRef m_readerTerminationItem; + /* each epoch is incremented when its side finishes a drain; comparing them distinguishes "supplier not started yet" from "supplier already drained" */ UInt32 m_writerEpoch; + UInt32 m_supplierEpoch; + HANDLE m_supplierDrainEvent; + 
DryadEventCache m_eventCache; + + DrStr128 m_uri; + + /* taken around every state change in the writer implementations */ CRITSEC m_baseDR; + + friend class RChannelFifoUnit; + friend class RChannelFifo; +}; + +/* Bounded FIFO writer: at most m_fifoLength units are handed to the reader at a time; further writes are held on the blocked list until a slot is returned. Constructed only by RChannelFifo (private constructor, friend). */ class RChannelFifoWriter : + public RChannelFifoWriterBase +{ +public: + /* RChannelWriter interface */ + void Start(); + void Drain(DrTimeInterval timeOut, RChannelItemRef* pReturnItem); + void GetTerminationItems(RChannelItemRef* pWriterDrainItem, + RChannelItemRef* pReaderDrainItem); + + /* RChannelReaderSupplier interface */ + void StartSupplier(RChannelBufferPrefetchInfo* prefetchCookie); + void InterruptSupplier(); + void DrainSupplier(RChannelItem* drainItem); + +private: + RChannelFifoWriter(RChannelFifo* parent, UInt32 fifoLength); + ~RChannelFifoWriter(); + + bool ReWriteBlockedListForEarlyReturn(RChannelItemType drainType); + bool CheckForTerminationItem(RChannelFifoUnit* unit); + void StartBlocking(RChannelFifoUnit* unit, + SyncHandler* SyncHandler); + void EnqueueItemArray(RChannelItemArrayRef& itemArray, + RChannelItemArrayWriterHandler* handler, + SyncHandler* SyncHandler); + void AcceptReturningUnit(RChannelFifoUnit* unit); + + /* free slots remaining out of m_fifoLength */ UInt32 m_availableUnits; + UInt32 m_fifoLength; + /* the unit carrying the writer's termination item while it is blocked or in flight; NULL otherwise */ RChannelFifoUnit* m_terminationUnit; + /* signalled when a writer in WS_Draining may proceed */ HANDLE m_writerDrainEvent; + + friend class RChannelFifoUnit; + friend class RChannelFifo; +}; + +/* FIFO writer with no slot limit: write handlers are returned from the enqueue call itself rather than being held back. Constructed only by RChannelFifo (private constructor, friend). */ class RChannelFifoNBWriter : + public RChannelFifoWriterBase +{ +public: + /* RChannelWriter interface */ + void Start(); + void Drain(DrTimeInterval timeOut, RChannelItemRef* pReturnItem); + void GetTerminationItems(RChannelItemRef* pWriterDrainItem, + RChannelItemRef* pReaderDrainItem); + + /* RChannelReaderSupplier interface */ + void StartSupplier(RChannelBufferPrefetchInfo* prefetchCookie); + void InterruptSupplier(); + void DrainSupplier(RChannelItem* drainItem); + +private: + RChannelFifoNBWriter(RChannelFifo* parent); + + bool CheckForTerminationItem(RChannelFifoUnit* unit); + void EnqueueItemArray(RChannelItemArrayRef& itemArray, + RChannelItemArrayWriterHandler* handler, + SyncHandler* 
SyncHandler); + void AcceptReturningUnit(RChannelFifoUnit* unit); + + friend class RChannelFifoUnit; + friend class RChannelFifo; +}; + +/* Reader end of a FIFO. Its length queries always report 0 and return false, and SetExpectedLength does nothing. Constructed only by RChannelFifo (private constructor, friend). */ class RChannelFifoReader : public RChannelReaderImpl +{ +public: + RChannelFifo* GetParent(); + + bool GetTotalLength(UInt64* pLen); + bool GetExpectedLength(UInt64* pLen); + void SetExpectedLength(UInt64 expectedLength); + +private: + RChannelFifoReader(RChannelFifo* parent, + WorkQueue* workQueue); + ~RChannelFifoReader(); + + RChannelFifo* m_parent; + + friend class RChannelFifo; +}; + +/* Owns one writer and one reader connected to each other; both are created in the constructor and deleted in the destructor. Passing s_infiniteBuffer as fifoLength selects RChannelFifoNBWriter; any other value selects RChannelFifoWriter with that length. */ class RChannelFifo +{ +public: + RChannelFifo(const char* name, + UInt32 fifoLength, WorkQueue* workQueue); + ~RChannelFifo(); + + const char* GetName(); + RChannelFifoReader* GetReader(); + RChannelFifoWriterBase* GetWriter(); + + static const UInt32 s_infiniteBuffer = (UInt32) -1; + +private: + DrStr64 m_name; + RChannelFifoReader* m_reader; + RChannelFifoWriterBase* m_writer; +}; diff --git a/DryadVertex/VertexHost/system/channel/src/channelhelpers.cpp b/DryadVertex/VertexHost/system/channel/src/channelhelpers.cpp new file mode 100644 index 0000000..65605a3 --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/src/channelhelpers.cpp @@ -0,0 +1,316 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include "channelparser.h" +#include "channelhelpers.h" +#include "channelreader.h" +#include "channelbuffer.h" +#include "channelbufferqueue.h" +#include "channelfifo.h" + +#pragma unmanaged + + +RChannelUnit::RChannelUnit(RChannelUnitType type) +{ + m_type = type; +} + +RChannelUnit::~RChannelUnit() +{ +} + +RChannelUnitType RChannelUnit::GetType() +{ + return m_type; +} + +RChannelItemUnit::RChannelItemUnit(RChannelItemArray* payload, + RChannelItem* terminationItem) : + RChannelUnit(RChannelUnit_Item) +{ + m_payload = payload; + m_terminationItem = terminationItem; + m_numberOfSubItems = 0; + m_dataSize = 0; +} + +RChannelItemUnit::~RChannelItemUnit() +{ +} + +RChannelItemArray* RChannelItemUnit::GetItemArray() +{ + return m_payload; +} + +RChannelItem* RChannelItemUnit::GetTerminationItem() +{ + return m_terminationItem; +} + +void RChannelItemUnit::DiscardItems() +{ + m_payload = NULL; + m_terminationItem = NULL; +} + +void RChannelItemUnit::SetSizes(UInt64 numberOfSubItems, + UInt64 dataSize) +{ + m_numberOfSubItems = numberOfSubItems; + m_dataSize = dataSize; +} + +UInt64 RChannelItemUnit::GetNumberOfSubItems() +{ + return m_numberOfSubItems; +} + +UInt64 RChannelItemUnit::GetDataSize() +{ + return m_dataSize; +} + + +RChannelSerializedUnit::RChannelSerializedUnit(RChannelBufferQueue* parent, + RChannelItemArray* payload, + RChannelItem* terminationItem) : + RChannelItemUnit(payload, terminationItem) +{ + LogAssert(m_payload != NULL); + m_parent = parent; +} + +void RChannelSerializedUnit::ReturnToSupplier() +{ + /* we should have already transferred away the array before getting + here, and don't need to do anything more */ + LogAssert(m_payload == NULL); + m_parent->NotifyUnitConsumption(); + delete this; +} + + +RChannelFifoUnit::RChannelFifoUnit(RChannelFifoWriterBase* parent) : + RChannelItemUnit(NULL, NULL) +{ + m_parent = parent; + m_handler = NULL; + m_statusCode = RChannelItem_Data; +} + +RChannelFifoUnit::~RChannelFifoUnit() +{ + 
LogAssert(m_handler == NULL); +} + +void RChannelFifoUnit::SetPayload(RChannelItemArray* itemArray, + RChannelItemArrayWriterHandler* handler, + RChannelItemType statusCode) +{ + LogAssert(itemArray != NULL); + LogAssert(m_payload == NULL); + LogAssert(m_terminationItem == NULL); + LogAssert(m_handler == NULL); + + m_payload = itemArray; + m_handler = handler; + m_statusCode = statusCode; +} + +void RChannelFifoUnit::SetTerminationItem(RChannelItem* terminationItem) +{ + LogAssert(RChannelItem::IsTerminationItem(terminationItem->GetType())); + m_terminationItem = terminationItem; +} + +void RChannelFifoUnit::Disgorge(RChannelItemArrayWriterHandler** pHandler, + RChannelItemType* pStatusCode) +{ + LogAssert(m_payload == NULL); + + *pHandler = m_handler; + m_handler = NULL; + + *pStatusCode = m_statusCode; + m_statusCode = RChannelItem_Data; +} + +void RChannelFifoUnit::SetStatusCode(RChannelItemType statusCode) +{ + LogAssert(m_handler != NULL); + m_statusCode = statusCode; +} + +RChannelItemType RChannelFifoUnit::GetStatusCode() +{ + return m_statusCode; +} + +void RChannelFifoUnit::ReturnToSupplier() +{ + LogAssert(m_parent != NULL); + LogAssert(m_payload == NULL); + m_parent->AcceptReturningUnit(this); +} + +RChannelBufferBoundaryUnit:: + RChannelBufferBoundaryUnit(RChannelBuffer* buffer, + RChannelBufferPrefetchInfo* prefetchCookie) : + RChannelUnit(RChannelUnit_BufferBoundary) +{ + LogAssert(buffer != NULL); + m_buffer = buffer; + m_prefetchCookie = prefetchCookie; +} + +RChannelBufferBoundaryUnit::~RChannelBufferBoundaryUnit() +{ + LogAssert(m_buffer == NULL); +} + +void RChannelBufferBoundaryUnit::ReturnToSupplier() +{ + LogAssert(m_buffer != NULL); + m_buffer->ProcessingComplete(m_prefetchCookie); + m_buffer = NULL; + m_prefetchCookie = NULL; + delete this; +} + + +RChannelProcessRequest:: + RChannelProcessRequest(RChannelReaderImpl* parent, + RChannelItemArrayReaderHandler* handler, + void* cancelCookie) +{ + m_aborted = 0; + m_parent = parent; + m_handler 
= handler; + m_cookie = cancelCookie; +} + +RChannelProcessRequest::~RChannelProcessRequest() +{ +} + +// +// Have request parent process the item +// +void RChannelProcessRequest::Process() +{ + m_parent->ProcessItemArrayRequest(this); + m_itemArray = NULL; +} + +bool RChannelProcessRequest::ShouldAbort() +{ + return (::InterlockedExchangeAdd(&m_aborted, 0) != 0); +} + +void RChannelProcessRequest::SetItemArray(RChannelItemArray* itemArray) +{ + m_itemArray = itemArray; +} + +RChannelItemArray* RChannelProcessRequest::GetItemArray() +{ + return m_itemArray; +} + +RChannelItemArrayReaderHandler* RChannelProcessRequest::GetHandler() +{ + LogAssert(m_handler != NULL); + return m_handler; +} + +void* RChannelProcessRequest::GetCookie() +{ + return m_cookie; +} + + +void RChannelProcessRequest::Cancel() +{ + ::InterlockedIncrement(&m_aborted); +} + + +RChannelParseRequest:: + RChannelParseRequest(RChannelBufferQueue* parent, + bool useNewBuffer) +{ + m_parent = parent; + m_useNewBuffer = useNewBuffer; +} + +void RChannelParseRequest::Process() +{ + m_parent->ParseRequest(m_useNewBuffer); +} + +bool RChannelParseRequest::ShouldAbort() +{ + return m_parent->ShutDownRequested(); +} + + +RChannelMarshalRequest:: + RChannelMarshalRequest(RChannelSerializedWriter* parent) +{ + m_parent = parent; +} + +void RChannelMarshalRequest::Process() +{ + m_parent->MarshalItems(); +} + +bool RChannelMarshalRequest::ShouldAbort() +{ + return false; +} + + +RChannelReaderSyncWaiter:: + RChannelReaderSyncWaiter(RChannelReaderImpl* parent, + HANDLE event, + RChannelItemArrayRef* itemDstArray) +{ + m_parent = parent; + m_event = event; + m_itemDstArray = itemDstArray; +} + +RChannelReaderSyncWaiter::~RChannelReaderSyncWaiter() +{ + LogAssert(m_event == INVALID_HANDLE_VALUE); +} + +void RChannelReaderSyncWaiter::ProcessItemArray(RChannelItemArray* itemArray) +{ + LogAssert(m_event != INVALID_HANDLE_VALUE); + m_parent->ThreadSafeSetItemArray(m_itemDstArray, itemArray); + HANDLE event = 
m_event; + m_event = INVALID_HANDLE_VALUE; + BOOL bRet = ::SetEvent(event); + LogAssert(bRet != 0); +} diff --git a/DryadVertex/VertexHost/system/channel/src/channelhelpers.h b/DryadVertex/VertexHost/system/channel/src/channelhelpers.h new file mode 100644 index 0000000..d8b6686 --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/src/channelhelpers.h @@ -0,0 +1,201 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +class RChannelBuffer; +class RChannelBufferQueue; +class RChannelFifoWriterBase; + +#include "channelparser.h" +#include "channelreader.h" +#include "channelwriter.h" +#include + +enum RChannelUnitType { + RChannelUnit_Item, + RChannelUnit_BufferBoundary +}; + +class RChannelUnit +{ +public: + virtual ~RChannelUnit(); + + RChannelUnitType GetType(); + + virtual void ReturnToSupplier() = 0; + +protected: + RChannelUnit(RChannelUnitType type); + +private: + RChannelUnitType m_type; + DrBListEntry m_listPtr; + friend class DryadBList; +}; + +typedef class DryadBList ChannelUnitList; + +class RChannelItemUnit : public RChannelUnit +{ +public: + virtual ~RChannelItemUnit(); + + RChannelItemArray* GetItemArray(); + RChannelItem* GetTerminationItem(); + + void DiscardItems(); + + void SetSizes(UInt64 numberOfSubItems, UInt64 dataSize); + UInt64 GetNumberOfSubItems(); + UInt64 GetDataSize(); + +protected: + RChannelItemUnit(RChannelItemArray* payload, + RChannelItem* terminationItem); + + RChannelItemArrayRef m_payload; + RChannelItemRef m_terminationItem; + UInt64 m_numberOfSubItems; + UInt64 m_dataSize; +}; + +class RChannelSerializedUnit : public RChannelItemUnit +{ +public: + RChannelSerializedUnit(RChannelBufferQueue* parent, + RChannelItemArray* payload, + RChannelItem* terminationItem); + void ReturnToSupplier(); + +private: + RChannelBufferQueue* m_parent; +}; + +class RChannelFifoUnit : public RChannelItemUnit +{ +public: + RChannelFifoUnit(RChannelFifoWriterBase* parent); + ~RChannelFifoUnit(); + + void SetPayload(RChannelItemArray* itemArray, + RChannelItemArrayWriterHandler* handler, + RChannelItemType statusCode); + + void SetTerminationItem(RChannelItem* terminationItem); + + + void ReturnToSupplier(); + + void Disgorge(RChannelItemArrayWriterHandler** pHandler, + RChannelItemType* pStatusCode); + + void SetStatusCode(RChannelItemType statusCode); + RChannelItemType GetStatusCode(); + +private: + RChannelFifoWriterBase* m_parent; + 
RChannelItemArrayWriterHandler* m_handler; + RChannelItemType m_statusCode; +}; + +class RChannelBufferBoundaryUnit : public RChannelUnit +{ +public: + RChannelBufferBoundaryUnit(RChannelBuffer* buffer, + RChannelBufferPrefetchInfo* prefetchCookie); + ~RChannelBufferBoundaryUnit(); + + void ReturnToSupplier(); + +private: + RChannelBuffer* m_buffer; + RChannelBufferPrefetchInfo* m_prefetchCookie; +}; + +class RChannelProcessRequest : public WorkRequest +{ +public: + RChannelProcessRequest(RChannelReaderImpl* parent, + RChannelItemArrayReaderHandler* handler, + void* cancelCookie); + ~RChannelProcessRequest(); + + void Process(); + bool ShouldAbort(); + + void SetItemArray(RChannelItemArray* itemArray); + RChannelItemArray* GetItemArray(); + RChannelItemArrayReaderHandler* GetHandler(); + void* GetCookie(); + + void Cancel(); + +private: + LONG m_aborted; + RChannelReaderImpl* m_parent; + + RChannelItemArrayRef m_itemArray; + RChannelItemArrayReaderHandler* m_handler; + void* m_cookie; +}; + +class RChannelMarshalRequest : public WorkRequest +{ +public: + RChannelMarshalRequest(RChannelSerializedWriter* parent); + + void Process(); + bool ShouldAbort(); + +private: + RChannelSerializedWriter* m_parent; +}; + +class RChannelParseRequest : public WorkRequest +{ +public: + RChannelParseRequest(RChannelBufferQueue* parent, + bool useNewBuffer); + + void Process(); + bool ShouldAbort(); + +private: + RChannelBufferQueue* m_parent; + bool m_useNewBuffer; +}; + +class RChannelReaderSyncWaiter : public RChannelItemArrayReaderHandlerImmediate +{ +public: + RChannelReaderSyncWaiter(RChannelReaderImpl* parent, + HANDLE event, + RChannelItemArrayRef* itemDstArray); + ~RChannelReaderSyncWaiter(); + + void ProcessItemArray(RChannelItemArray* itemArray); + +private: + RChannelReaderImpl* m_parent; + HANDLE m_event; + RChannelItemArrayRef* m_itemDstArray; +}; diff --git a/DryadVertex/VertexHost/system/channel/src/channelitem.cpp 
b/DryadVertex/VertexHost/system/channel/src/channelitem.cpp new file mode 100644 index 0000000..8e34225 --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/src/channelitem.cpp @@ -0,0 +1,343 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include +#include +#include +#include + +#pragma unmanaged + + +DrResettableMemoryReader:: + DrResettableMemoryReader(DrMemoryBuffer* pMemoryBuffer) : + DrMemoryBufferReader(pMemoryBuffer) +{ +} + +void DrResettableMemoryReader::ResetToBufferOffset(Size_t offset) +{ + ResetMemoryReader(); + SetBufferOffset(offset); +} + +bool RChannelItem::IsTerminationItem(RChannelItemType type) +{ + return (type == RChannelItem_Restart || + type == RChannelItem_Abort || + type == RChannelItem_ParseError || + type == RChannelItem_MarshalError || + type == RChannelItem_EndOfStream); +} + +RChannelItem::RChannelItem(RChannelItemType type) +{ + m_type = type; + m_dataSequenceNumber = s_invalidSequenceNumber; + m_deliverySequenceNumber = s_invalidSequenceNumber; +} + +RChannelItem::~RChannelItem() +{ +} + +void RChannelItem::Clone(RChannelItemRef* pClonedItem) +{ + DrLogA( "Clone method not implemented"); + *pClonedItem = NULL; +} + +UInt64 RChannelItem::GetNumberOfSubItems() const +{ + /* the base class includes marker items and by default has no + subitems */ + return 0; +} + +void RChannelItem::TruncateSubItems(UInt64 
numberOfSubItems) +{ + /* the base class includes marker items and by default has no + subitems */ + LogAssert(false); +} + +UInt64 RChannelItem::GetItemSize() const +{ + /* the base class includes marker items and by default marshals to + zero size */ + return 0; +} + +RChannelItemType RChannelItem::GetType() +{ + return m_type; +} + +UInt64 RChannelItem::GetDataSequenceNumber() +{ + return m_dataSequenceNumber; +} + +void RChannelItem::SetDataSequenceNumber(UInt64 dataSequenceNumber) +{ + m_dataSequenceNumber = dataSequenceNumber; +} + +UInt64 RChannelItem::GetDeliverySequenceNumber() +{ + return m_deliverySequenceNumber; +} + +void RChannelItem::SetDeliverySequenceNumber(UInt64 deliverySequenceNumber) +{ + m_deliverySequenceNumber = deliverySequenceNumber; +} + +DrError RChannelItem::DeSerialize(DrResettableMemoryReader* reader, + Size_t availableSize) +{ + DrLogA("Default DeSerialize method cannot be called on RChannelItem"); + return DrError_Fail; +} + +DrError RChannelItem:: + DeSerializePartial(DrResettableMemoryReader* reader, + Size_t availableSize) +{ + /* by default, any partial buffer signifies an error */ + return DryadError_ItemParseError; +} + +DrError RChannelItem::Serialize(ChannelMemoryBufferWriter* writer) +{ + DrLogA("Default Serialize method cannot be called on RChannelItem"); + return DrError_Fail; +} + +DryadMetaData* RChannelItem::GetMetaData() +{ + return m_metaData.Ptr(); +} + +void RChannelItem::ReplaceMetaData(DryadMetaData* metaData) +{ + m_metaData.Set(metaData); +} + +DrError RChannelItem::GetErrorFromItem() +{ + if (m_metaData.Ptr() != NULL) + { + DryadMTagDrError* tag = + m_metaData.Ptr()->LookUpDrErrorTag(Prop_Dryad_ErrorCode); + if (tag != NULL) + { + return tag->GetDrError(); + } + } + + DrError err; + switch (m_type) + { + case RChannelItem_Data: + err = DrError_OK; + break; + + case RChannelItem_BufferHole: + err = DryadError_BufferHole; + break; + + case RChannelItem_ItemHole: + err = DryadError_ItemHole; + break; + + case 
RChannelItem_EndOfStream: + err = DrError_EndOfStream; + break; + + case RChannelItem_Restart: + err = DryadError_ChannelRestart; + break; + + case RChannelItem_Abort: + err = DryadError_ChannelAbort; + break; + + case RChannelItem_ParseError: + err = DryadError_ItemParseError; + break; + + case RChannelItem_MarshalError: + err = DryadError_ItemMarshalError; + break; + + default: + LogAssert(false); + err = DrError_InvalidParameter; + } + + return err; +} + +RChannelMarkerItem::RChannelMarkerItem(RChannelItemType type) : + RChannelItem(type) +{ +} + +RChannelMarkerItem::~RChannelMarkerItem() +{ +} + +RChannelMarkerItem* RChannelMarkerItem::Create(RChannelItemType type, + bool withMetaData) +{ + RChannelMarkerItem* item = new RChannelMarkerItem(type); + if (withMetaData) + { + DryadMetaDataRef emptyMetaData; + DryadMetaData::Create(&emptyMetaData); + item->ReplaceMetaData(emptyMetaData); + } + return item; +} + +void RChannelMarkerItem::Clone(RChannelItemRef* pClonedItem) +{ + RChannelMarkerItem* clone = new RChannelMarkerItem(GetType()); + clone->ReplaceMetaData(GetMetaData()); + pClonedItem->Attach(clone); +} + +// +// Create a custom error item with specified type and error code +// +RChannelItem* RChannelMarkerItem::CreateErrorItem(RChannelItemType itemType, + DrError errorCode) +{ + RChannelItem* item = Create(itemType, true); + DryadMetaData* m = item->GetMetaData(); + DryadMTagRef tag; + tag.Attach(DryadMTagDrError::Create(Prop_Dryad_ErrorCode, errorCode)); + m->Append(tag, false); + return item; +} + +RChannelItem* RChannelMarkerItem:: + CreateErrorItemWithDescription(RChannelItemType itemType, + DrError errorCode, + const char* errorDescription) +{ + RChannelItem* item = Create(itemType, true); + DryadMetaData* m = item->GetMetaData(); + DryadMTagRef tag; + + tag.Attach(DryadMTagDrError::Create(Prop_Dryad_ErrorCode, errorCode)); + m->Append(tag, false); + + tag.Attach(DryadMTagString::Create(Prop_Dryad_ErrorString, + errorDescription)); + m->Append(tag, 
false); + + return item; +} + + +RChannelDataItem::RChannelDataItem() : RChannelItem(RChannelItem_Data) +{ +} + +RChannelDataItem::~RChannelDataItem() +{ +} + +UInt64 RChannelDataItem::GetNumberOfSubItems() const +{ + /* by default an item is indivisible, i.e. has one subitem */ + return 1; +} + +UInt64 RChannelDataItem::GetItemSize() const +{ + /* the base class doesn't know the size, so just return 1 */ + return 1; +} + +RChannelItemArray::RChannelItemArray() +{ + m_numberOfItems = 0; + m_baseItemArray = NULL; + m_itemArray = NULL; +} + +RChannelItemArray::~RChannelItemArray() +{ + delete [] m_baseItemArray; +} + +void RChannelItemArray::SetNumberOfItems(UInt32 numberOfItems) +{ + m_numberOfItems = numberOfItems; + delete [] m_baseItemArray; + m_baseItemArray = new RChannelItemRef [m_numberOfItems]; + m_itemArray = m_baseItemArray; +} + +void RChannelItemArray::ExtendNumberOfItems(UInt32 numberOfItems) +{ + if (m_numberOfItems < numberOfItems) + { + RChannelItemRef* newArray = new RChannelItemRef[numberOfItems]; + LogAssert(newArray != NULL); + + UInt32 i; + for (i=0; i +#include + +#pragma unmanaged + + +RChannelItemMarshalerBase::RChannelItemMarshalerBase() +{ + m_maxMarshalBatchSize = 0; +} + +RChannelItemMarshalerBase::~RChannelItemMarshalerBase() +{ +} + +void RChannelItemMarshalerBase::Reset() +{ +} + +void RChannelItemMarshalerBase:: + SetMaxMarshalBatchSize(UInt32 maxMarshalBatchSize) +{ + m_maxMarshalBatchSize = maxMarshalBatchSize; +} + +UInt32 RChannelItemMarshalerBase::GetMaxMarshalBatchSize() +{ + return m_maxMarshalBatchSize; +} + +void RChannelItemMarshalerBase::SetMarshalerIndex(UInt32 index) +{ + m_index = index; +} + +UInt32 RChannelItemMarshalerBase::GetMarshalerIndex() +{ + return m_index; +} + +void RChannelItemMarshalerBase::SetMarshalerContext(RChannelContext* context) +{ + m_context = context; +} + +RChannelContext* RChannelItemMarshalerBase::GetMarshalerContext() +{ + return m_context; +} + + 
+RChannelItemMarshaler::~RChannelItemMarshaler() +{ +} + +RChannelStdItemMarshalerBase::~RChannelStdItemMarshalerBase() +{ +} + +DrError RChannelStdItemMarshalerBase:: + MarshalItem(ChannelMemoryBufferWriter* writer, + RChannelItem* item, + bool flush, + RChannelItemRef* pFailureItem) +{ + if (item->GetType() != RChannelItem_Data) + { + return MarshalMarker(writer, item, flush, pFailureItem); + } + + DrError err = item->Serialize(writer); + + if (err != DrError_OK && err != DrError_IncompleteOperation) + { + pFailureItem->Attach(RChannelMarkerItem:: + CreateErrorItem(RChannelItem_MarshalError, + err)); + return DryadError_ChannelAbort; + } + + return err; +} + +DrError RChannelStdItemMarshalerBase:: + MarshalMarker(ChannelMemoryBufferWriter* writer, + RChannelItem* item, + bool flush, + RChannelItemRef* pFailureItem) +{ + return DrError_OK; +} + +RChannelStdItemMarshaler::~RChannelStdItemMarshaler() +{ +} diff --git a/DryadVertex/VertexHost/system/channel/src/channelparser.cpp b/DryadVertex/VertexHost/system/channel/src/channelparser.cpp new file mode 100644 index 0000000..cfa93a0 --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/src/channelparser.cpp @@ -0,0 +1,879 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include +#include +#include +#include +#include + +#pragma unmanaged + + +void RChannelBufferRecord::SetData(DrMemoryBuffer* data) +{ + m_buffer = data; +} + +DrMemoryBuffer* RChannelBufferRecord::GetData() const +{ + return m_buffer; +} + + +RChannelBufferItem::RChannelBufferItem(DrMemoryBuffer* buffer) +{ + SetNumberOfRecords(1); + GetRecordArray()[0].SetData(buffer); +} + +RChannelBufferItem* RChannelBufferItem::Create(DrMemoryBuffer* buffer) +{ + return new RChannelBufferItem(buffer); +} + +DrMemoryBuffer* RChannelBufferItem::GetData() const +{ + return GetRecordArray()[0].GetData(); +} + +UInt64 RChannelBufferItem::GetItemSize() const +{ + return GetData()->GetAvailableSize(); +} + + +RChannelItemParserBase::RChannelItemParserBase() +{ + m_maxParseBatchSize = 0; + m_index = 0; +} + +RChannelItemParserBase::~RChannelItemParserBase() +{ +} + +void RChannelItemParserBase::SetMaxParseBatchSize(UInt32 maxParseBatchSize) +{ + m_maxParseBatchSize = maxParseBatchSize; +} + +UInt32 RChannelItemParserBase::GetMaxParseBatchSize() +{ + return m_maxParseBatchSize; +} + +void RChannelItemParserBase::SetParserIndex(UInt32 index) +{ + m_index = index; +} + +UInt32 RChannelItemParserBase::GetParserIndex() +{ + return m_index; +} + +void RChannelItemParserBase::SetParserContext(RChannelContext* context) +{ + m_context = context; +} + +RChannelContext* RChannelItemParserBase::GetParserContext() +{ + return m_context; +} + + +RChannelRawItemParser::~RChannelRawItemParser() +{ +} + + +RChannelItemTransformerBase::~RChannelItemTransformerBase() +{ +} + +void RChannelItemTransformerBase:: + InitializeTransformer(RChannelItemParserBase* parent, + DVErrorReporter* errorReporter) +{ +} + +void RChannelItemTransformerBase:: + TransformItem(RChannelItemRef& inputItem, + SyncItemWriterBase* writer, + DVErrorReporter* errorReporter) +{ + writer->WriteItemSyncConsumingReference(inputItem); +} + +void RChannelItemTransformerBase:: + FlushTransformer(SyncItemWriterBase* writer, + 
DVErrorReporter* errorReporter) +{ +} + +void RChannelItemTransformerBase:: + ReportTransformerErrorItem(RChannelItemRef& errorItem, + DVErrorReporter* errorReporter) +{ +} + + +RChannelItemTransformer::~RChannelItemTransformer() +{ +} + + +RChannelTransformerParserBase::RChannelTransformerParserBase() +{ + m_transformedAny = false; +} + +RChannelTransformerParserBase::~RChannelTransformerParserBase() +{ + LogAssert(m_itemList.IsEmpty()); +} + +void RChannelTransformerParserBase:: + SetTransformer(RChannelItemTransformerBase* transformer) +{ + m_transformer = transformer; +} + +RChannelItemTransformerBase* RChannelTransformerParserBase::GetTransformer() +{ + return m_transformer; +} + +RChannelItem* RChannelTransformerParserBase:: + RawParseItem(bool restartParser, + RChannelBuffer* inData, + RChannelBufferPrefetchInfo** outPrefetchCookie) +{ + *outPrefetchCookie = NULL; + + if (restartParser) + { + if (m_transformedAny) + { + DrLogI("Flushing transformer for restart"); + + m_transformer->FlushTransformer(this, this); + /* clear any errors */ + ReportError(DrError_OK, (DryadMetaData*) NULL); + + RChannelItemRef item; + while (m_itemList.IsEmpty() == false) + { + /* remove the saved item from the list; it will be + garbage collected when the refcount goes out of + scope */ + item.Attach(m_itemList.CastOut(m_itemList.RemoveHead())); + } + + m_transformedAny = false; + } + else + { + LogAssert(m_itemList.IsEmpty()); + m_transformer->InitializeTransformer(this, this); + } + } + + if (m_itemList.IsEmpty() == false) + { + LogAssert(m_transformedAny); + LogAssert(inData == NULL); + RChannelItem* item = m_itemList.CastOut(m_itemList.RemoveHead()); + return item; + } + + if (inData == NULL) + { + LogAssert(m_transformedAny); + /* we have sent back all the items that got written from the + last buffer, but the base class doesn't know this yet, so + inform it by returning NULL */ + return NULL; + } + + m_transformedAny = true; + + RChannelBufferType bType = inData->GetType(); 
+ if (bType == RChannelBuffer_Hole || + bType == RChannelBuffer_EndOfStream) + { + RChannelBufferMarker* mBuffer = + dynamic_cast(inData); + LogAssert(mBuffer != NULL); + RChannelItemRef markerItem = mBuffer->GetItem(); + + if (bType == RChannelBuffer_Hole) + { + m_transformer->ReportTransformerErrorItem(markerItem, this); + } + + m_transformer->FlushTransformer(this, this); + + if (!NoError()) + { + markerItem. + Attach(RChannelMarkerItem::Create(RChannelItem_ParseError, + false)); + markerItem->ReplaceMetaData(GetErrorMetaData()); + } + + m_itemList.InsertAsTail(m_itemList.CastIn(markerItem.Detach())); + } + else + { + LogAssert(bType == RChannelBuffer_Data); + RChannelBufferData* dBuffer = + dynamic_cast(inData); + LogAssert(dBuffer != NULL); + + DryadLockedMemoryBuffer* block = dBuffer->GetData(); + LogAssert(block->GetAvailableSize() > 0); + LogAssert(block->IsGrowable() == false); + + RChannelItemRef dataItem; + dataItem.Attach(RChannelBufferItem::Create(block)); + + m_transformer->TransformItem(dataItem, this, this); + + if (!NoError()) + { + RChannelItem* errorItem; + + errorItem = + RChannelMarkerItem::Create(RChannelItem_ParseError, false); + errorItem->ReplaceMetaData(GetErrorMetaData()); + + m_itemList.InsertAsTail(m_itemList.CastIn(errorItem)); + } + } + + if (m_itemList.IsEmpty()) + { + return NULL; + } + else + { + RChannelItem* item = m_itemList.CastOut(m_itemList.RemoveHead()); + return item; + } +} + +void RChannelTransformerParserBase:: + WriteItemSyncConsumingReference(RChannelItemRef& item) +{ + m_itemList.InsertAsTail(m_itemList.CastIn(item.Detach())); +} + +DrError RChannelTransformerParserBase::GetWriterStatus() +{ + return DrError_OK; +} + + +RChannelItemParserNoRefImpl::RChannelItemParserNoRefImpl() +{ + m_needsReset = true; +} + +RChannelItemParserNoRefImpl::~RChannelItemParserNoRefImpl() +{ + ResetParserInternal(); +} + +void RChannelItemParserNoRefImpl::ResetParser() +{ +} + +RChannelItem* RChannelItemParserNoRefImpl:: + 
ParsePartialItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + RChannelBufferMarker* + markerBuffer) +{ + return NULL; +} + +void RChannelItemParserNoRefImpl::ResetParserInternal() +{ + m_savedItem = NULL; + + DrBListEntry* listEntry = m_bufferList.GetHead(); + while (listEntry != NULL) + { + RChannelBufferData* buffer = m_bufferList.CastOut(listEntry); + listEntry = m_bufferList.GetNext(listEntry); + m_bufferList.Remove(m_bufferList.CastIn(buffer)); + buffer->DecRef(); + } + + m_bufferListStartOffset = 0; + + m_needsData = true; +} + +void RChannelItemParserNoRefImpl::DiscardBufferPrefix(Size_t discardLength) +{ + while (discardLength > 0) + { + LogAssert(m_bufferList.IsEmpty() == false); + RChannelBufferData* buffer = + m_bufferList.CastOut(m_bufferList.GetHead()); + DryadLockedMemoryBuffer* mBuffer = buffer->GetData(); + + LogAssert(mBuffer->GetAvailableSize() > m_bufferListStartOffset); + Size_t tailLength = + mBuffer->GetAvailableSize() - m_bufferListStartOffset; + + if (discardLength >= tailLength) + { + discardLength -= tailLength; + m_bufferList.Remove(m_bufferList.CastIn(buffer)); + buffer->DecRef(); + m_bufferListStartOffset = 0; + } + else + { + m_bufferListStartOffset += discardLength; + discardLength = 0; + } + } +} + +RChannelItem* RChannelItemParserNoRefImpl:: + DealWithPartialBuffer(RChannelBufferMarker* mBuffer) +{ + RChannelItem* item; + + item = ParsePartialItem(&m_bufferList, m_bufferListStartOffset, + mBuffer); + + if (m_bufferList.IsEmpty()) + { + LogAssert(m_bufferListStartOffset == 0); + } + else + { + DrBListEntry* listEntry = m_bufferList.GetHead(); + while (listEntry != NULL) + { + RChannelBufferData* buffer = m_bufferList.CastOut(listEntry); + listEntry = m_bufferList.GetNext(listEntry); + m_bufferList.Remove(m_bufferList.CastIn(buffer)); + buffer->DecRef(); + } + + m_bufferListStartOffset = 0; + } + + RChannelItemRef markerItem = mBuffer->GetItem(); + LogAssert(markerItem != NULL); + + if (mBuffer->GetType() == 
RChannelBuffer_Hole) + { + LogAssert(markerItem->GetType() == RChannelItem_BufferHole); + } + else + { + LogAssert(markerItem->GetType() == RChannelItem_EndOfStream); + } + + if (item == NULL) + { + item = markerItem.Detach(); + } + else + { + m_savedItem = markerItem; + } + + return item; +} + +RChannelItem* RChannelItemParserNoRefImpl:: + RawParseItem(bool restartParser, + RChannelBuffer* inData, + RChannelBufferPrefetchInfo** outPrefetchCookie) +{ + *outPrefetchCookie = NULL; + + if (m_needsReset) + { + LogAssert(restartParser == true); + } + + if (restartParser) + { + // todo: remove comment if not logging +// DrLogD( "resetting parser"); + + ResetParser(); + + ResetParserInternal(); + } + + LogAssert(m_needsData == (inData != NULL)); + + RChannelItem* item = NULL; + + if (inData != NULL) + { + LogAssert(m_savedItem == NULL); + + RChannelBufferType bType = inData->GetType(); + if (bType == RChannelBuffer_Hole || + bType == RChannelBuffer_EndOfStream) + { + item = DealWithPartialBuffer((RChannelBufferMarker *) inData); + } + else + { + LogAssert(bType == RChannelBuffer_Data); + RChannelBufferData* dBuffer = (RChannelBufferData *) inData; + dBuffer->IncRef(); + DryadLockedMemoryBuffer* block = dBuffer->GetData(); + LogAssert(block->GetAvailableSize() > 0); + LogAssert(block->IsGrowable() == false); + m_bufferList.InsertAsTail(m_bufferList.CastIn(dBuffer)); + } + } + + if (item == NULL) + { + if (m_savedItem != NULL) + { + item = m_savedItem.Detach(); + } + else + { + Size_t itemLength = 0; + + if (m_bufferList.IsEmpty() == false) + { + item = ParseNextItem(&m_bufferList, m_bufferListStartOffset, + &itemLength); + } + + if (item == NULL) + { + LogAssert(itemLength == 0); + } + else + { + DiscardBufferPrefix(itemLength); + } + } + } + + if (item == NULL) + { + m_needsReset = false; + m_needsData = true; + } + else + { + m_needsReset = RChannelItem::IsTerminationItem(item->GetType()); + m_needsData = false; + } + + return item; +} + 
+RChannelItemParser::~RChannelItemParser() +{ +} + + +RChannelStdItemParserNoRefImpl:: + RChannelStdItemParserNoRefImpl(DObjFactoryBase* factory) +{ + m_factory = factory; +} + +RChannelStdItemParserNoRefImpl::~RChannelStdItemParserNoRefImpl() +{ +} + +void RChannelStdItemParserNoRefImpl::ResetParser() +{ + m_pendingErrorItem = NULL; +} + +RChannelItem* RChannelStdItemParserNoRefImpl:: + ParseNextItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + Size_t* pOutLength) +{ + if (m_pendingErrorItem != NULL) + { + return m_pendingErrorItem.Detach(); + } + + LogAssert(bufferList->IsEmpty() == false); + + RChannelBufferData* tailBuffer = + bufferList->CastOut(bufferList->GetTail()); + Size_t tailBufferSize = + tailBuffer->GetData()->GetAvailableSize(); + + DrRef buffer; + buffer.Attach(new RChannelReaderBuffer(bufferList, + startOffset, + tailBufferSize)); + + DrResettableMemoryReader reader(buffer); + + RChannelItem* item = (RChannelItem *) m_factory->AllocateObjectUntyped(); + DrError err = item->DeSerialize(&reader, buffer->GetAvailableSize()); + + if (err == DrError_OK) + { + *pOutLength = reader.GetBufferOffset(); + return item; + } + + m_factory->FreeObjectUntyped(item); + + if (err == DrError_EndOfStream) + { + return NULL; + } + else + { + return RChannelMarkerItem::CreateErrorItem(RChannelItem_ParseError, + err); + } +} + +RChannelItem* RChannelStdItemParserNoRefImpl:: + ParsePartialItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + RChannelBufferMarker* + markerBuffer) +{ + if (m_pendingErrorItem != NULL) + { + return m_pendingErrorItem.Detach(); + } + + if (bufferList->IsEmpty()) + { + return NULL; + } + + RChannelBufferData* tailBuffer = + bufferList->CastOut(bufferList->GetTail()); + Size_t tailBufferSize = + tailBuffer->GetData()->GetAvailableSize(); + + DrRef buffer; + buffer.Attach(new RChannelReaderBuffer(bufferList, + startOffset, + tailBufferSize)); + + DrResettableMemoryReader reader(buffer); + + RChannelItem* item = 
(RChannelItem *) m_factory->AllocateObjectUntyped(); + DrError err = item->DeSerializePartial(&reader, + buffer->GetAvailableSize()); + + if (err == DrError_OK) + { + return item; + } + + m_factory->FreeObjectUntyped(item); + + if (err == DrError_EndOfStream) + { + return NULL; + } + else + { + return RChannelMarkerItem::CreateErrorItem(RChannelItem_ParseError, + err); + } +} + +RChannelStdItemParser::RChannelStdItemParser(DObjFactoryBase* factory) : + RChannelStdItemParserNoRefImpl(factory) +{ +} + +RChannelStdItemParser::~RChannelStdItemParser() +{ +} + + +RChannelLengthDelimitedItemParserNoRefImpl:: + RChannelLengthDelimitedItemParserNoRefImpl() +{ + ResetParser(); +} + +RChannelLengthDelimitedItemParserNoRefImpl:: + ~RChannelLengthDelimitedItemParserNoRefImpl() +{ +} + +void RChannelLengthDelimitedItemParserNoRefImpl::ResetParser() +{ + m_itemLength = 0; + m_accumulatedLength = 0; +} + +void RChannelLengthDelimitedItemParserNoRefImpl:: + AddMetaData(RChannelItem* item, + ChannelDataBufferList* + bufferList, + Size_t startOffset, + Size_t endOffset) +{ + LogAssert(bufferList->IsEmpty() == false); + + RChannelBufferData* buffer = bufferList->CastOut(bufferList->GetHead()); + DryadMetaDataRef startMetaData; + buffer->GetOffsetMetaData(true, startOffset, &startMetaData); + + buffer = bufferList->CastOut(bufferList->GetTail()); + DryadMetaDataRef endMetaData; + buffer->GetOffsetMetaData(false, endOffset, &endMetaData); + + DryadMetaData* m = item->GetMetaData(); + /* this is only called for items which already have metadata for + performance reasons */ + LogAssert(m != NULL); + + DryadMTagMetaDataRef tag; + bool brc; + + tag = m->LookUpMetaDataTag(DryadTag_ItemStart); + if (tag == NULL) + { + tag.Attach(DryadMTagMetaData::Create(DryadTag_ItemStart, + startMetaData, true)); + brc = m->Append(tag, false); + LogAssert(brc == true); + } + else + { + tag->GetMetaData()->AppendMetaDataTags(startMetaData, false); + } + + tag = m->LookUpMetaDataTag(DryadTag_ItemEnd); + if 
(tag == NULL) + { + tag.Attach(DryadMTagMetaData::Create(DryadTag_ItemEnd, + endMetaData, true)); + brc = m->Append(tag, false); + LogAssert(brc == true); + } + else + { + tag->GetMetaData()->AppendMetaDataTags(endMetaData, false); + } +} + +RChannelItem* RChannelLengthDelimitedItemParserNoRefImpl:: + FetchItem(ChannelDataBufferList* + bufferList, + Size_t startOffset, + Size_t tailBufferSize) +{ + Size_t tailGapLength = m_accumulatedLength - m_itemLength; + + LogAssert(tailBufferSize > tailGapLength); + + Size_t endOffset = tailBufferSize - tailGapLength; + + RChannelReaderBuffer* itemBuffer = + new RChannelReaderBuffer(bufferList, startOffset, endOffset); + LogAssert(itemBuffer->GetAvailableSize() == m_itemLength); + + RChannelItem* item = ParseItemWithLength(itemBuffer, m_itemLength); + LogAssert(item != NULL); + LogAssert(item->GetType() == RChannelItem_Data || + item->GetType() == RChannelItem_ItemHole || + item->GetType() == RChannelItem_ParseError); + + itemBuffer->DecRef(); + + if (item->GetMetaData() != NULL) + { + AddMetaData(item, bufferList, startOffset, endOffset); + } + + ResetParser(); + + return item; +} + +RChannelItem* RChannelLengthDelimitedItemParserNoRefImpl:: + MaybeFetchItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + Size_t* pOutLength) +{ + /* make sure there are at least two items */ + LogAssert(bufferList->GetHead() != bufferList->GetTail()); + + RChannelBufferData* buffer = bufferList->CastOut(bufferList->GetTail()); + Size_t tailBufferSize = buffer->GetData()->GetAvailableSize(); + + LogAssert(m_accumulatedLength < m_itemLength); + m_accumulatedLength += tailBufferSize; + + if (m_accumulatedLength >= m_itemLength) + { + Size_t thisLength = m_itemLength; + RChannelItem* item = FetchItem(bufferList, startOffset, + tailBufferSize); + *pOutLength = thisLength; + return item; + } + else + { + *pOutLength = 0; + return NULL; + } +} + +RChannelItem* RChannelLengthDelimitedItemParserNoRefImpl:: + 
ParseNextItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + Size_t* pOutLength) +{ + if (m_itemLength > 0) + { + return MaybeFetchItem(bufferList, startOffset, pOutLength); + } + else + { + LogAssert(bufferList->IsEmpty() == false); + + RChannelBufferData* buffer = + bufferList->CastOut(bufferList->GetTail()); + Size_t tailBufferSize = + buffer->GetData()->GetAvailableSize(); + LogAssert(bufferList->GetHead() != bufferList->GetTail() || + startOffset < tailBufferSize); + + RChannelReaderBuffer* lengthCheckBuffer = + new RChannelReaderBuffer(bufferList, startOffset, tailBufferSize); + + m_accumulatedLength = lengthCheckBuffer->GetAvailableSize(); + + RChannelItem* errorItem = NULL; + LengthStatus ls = GetNextItemLength(lengthCheckBuffer, + &m_itemLength, + &errorItem); + + lengthCheckBuffer->DecRef(); + + switch (ls) + { + default: + LogAssert(false); + return NULL; + + case LS_ParseError: + LogAssert(errorItem != NULL); + LogAssert(errorItem->GetType() == RChannelItem_ParseError); + LogAssert(errorItem->GetMetaData() != NULL); + AddMetaData(errorItem, bufferList, startOffset, tailBufferSize); + ResetParser(); + *pOutLength = 0; + return errorItem; + + case LS_NeedsData: + LogAssert(errorItem == NULL); + m_itemLength = 0; + *pOutLength = 0; + return NULL; + + case LS_Ok: + LogAssert(errorItem == NULL); + LogAssert(m_itemLength > 0); + + if (m_accumulatedLength >= m_itemLength) + { + Size_t thisLength = m_itemLength; + RChannelItem* item = FetchItem(bufferList, startOffset, + tailBufferSize); + *pOutLength = thisLength; + return item; + } + else + { + *pOutLength = 0; + return NULL; + } + } + } +} + +RChannelItem* RChannelLengthDelimitedItemParserNoRefImpl:: + ParsePartialItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + RChannelBufferMarker* + markerBuffer) +{ + if (bufferList->IsEmpty()) + { + return NULL; + } + else + { + RChannelBufferData* buffer = + bufferList->CastOut(bufferList->GetTail()); + Size_t tailBufferSize = + 
buffer->GetData()->GetAvailableSize(); + RChannelItem* item = + RChannelMarkerItem::Create(RChannelItem_ItemHole, true); + AddMetaData(item, bufferList, startOffset, tailBufferSize); + ResetParser(); + return item; + } +} + +RChannelLengthDelimitedItemParser::~RChannelLengthDelimitedItemParser() +{ +} + +DryadParserFactoryBase::~DryadParserFactoryBase() +{ +} + +DryadParserFactory::~DryadParserFactory() +{ +} + +DryadMarshalerFactoryBase::~DryadMarshalerFactoryBase() +{ +} + +DryadMarshalerFactory::~DryadMarshalerFactory() +{ +} diff --git a/DryadVertex/VertexHost/system/channel/src/channelreader.cpp b/DryadVertex/VertexHost/system/channel/src/channelreader.cpp new file mode 100644 index 0000000..64d164c --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/src/channelreader.cpp @@ -0,0 +1,1485 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include +#include "channelreader.h" +#include "channelparser.h" +#include "channelbufferqueue.h" +#include "channelhelpers.h" +#include "workqueue.h" +#include "dryaderrordef.h" + +#pragma unmanaged + + +/* put this here for want of a better source file */ +RChannelContext::~RChannelContext() +{ +} + + +RChannelReader::~RChannelReader() +{ +} + +bool RChannelReader::FetchNextItem(RChannelItemRef* pOutItem, + DrTimeInterval timeOut) +{ + RChannelItemArrayRef itemArray; + bool delivered = FetchNextItemArray(1, &itemArray, timeOut); + LogAssert(itemArray != NULL); + if (delivered) + { + if (itemArray->GetNumberOfItems() == 1) + { + RChannelItemRef* a = itemArray->GetItemArray(); + pOutItem->TransferFrom(a[0]); + } + else + { + LogAssert(itemArray->GetNumberOfItems() == 0); + *pOutItem = NULL; + } + } + else + { + LogAssert(itemArray->GetNumberOfItems() == 0); + *pOutItem = NULL; + } + + return delivered; +} + +DrError RChannelReader::ReadItemSync(RChannelItemRef* pOutItem) +{ + FetchNextItem(pOutItem, DrTimeInterval_Infinite); + + if ((*pOutItem) == NULL) + { + RChannelItemRef writerTermination; + RChannelItemRef readerTermination; + GetTerminationItems(&writerTermination, + &readerTermination); + + if (writerTermination == NULL) + { + LogAssert(readerTermination != NULL); + *pOutItem = readerTermination; + } + else + { + *pOutItem = writerTermination; + } + } + + return DrError_OK; +} + +DrError RChannelReader::GetTerminationStatus(DryadMetaDataRef* pErrorData) +{ + DrError status; + + RChannelItemRef writerTermination; + RChannelItemRef readerTermination; + GetTerminationItems(&writerTermination, + &readerTermination); + + LogAssert(readerTermination.Ptr() != NULL); + if (writerTermination.Ptr() != NULL && + writerTermination->GetType() != RChannelItem_EndOfStream) + { + status = writerTermination->GetErrorFromItem(); + *pErrorData = writerTermination->GetMetaData(); + } + else if (readerTermination->GetType() == RChannelItem_EndOfStream) + { + status = 
DrError_EndOfStream; + *pErrorData = readerTermination->GetMetaData(); + } + else + { + status = DryadError_ProcessingInterrupted; + *pErrorData = readerTermination->GetMetaData(); + } + + return status; +} + +RChannelReaderSupplier::~RChannelReaderSupplier() +{ +} + +RChannelItemArrayReaderHandler::RChannelItemArrayReaderHandler() +{ + m_maximumArraySize = 1; +} + +RChannelItemArrayReaderHandler::~RChannelItemArrayReaderHandler() +{ +} + +void RChannelItemArrayReaderHandler::SetMaximumArraySize(UInt32 maximumArraySize) +{ + m_maximumArraySize = maximumArraySize; +} + +UInt32 RChannelItemArrayReaderHandler::GetMaximumArraySize() +{ + return m_maximumArraySize; +} + +bool RChannelItemArrayReaderHandlerQueued::ImmediateDispatch() +{ + return false; +} + +bool RChannelItemArrayReaderHandlerImmediate::ImmediateDispatch() +{ + return true; +} + +RChannelItemReaderHandler::~RChannelItemReaderHandler() +{ +} + +void RChannelItemReaderHandler:: + ProcessItemArray(RChannelItemArray* deliveredArray) +{ + if (deliveredArray->GetNumberOfItems() == 1) + { + RChannelItemRef* a = deliveredArray->GetItemArray(); + ProcessItem(a[0]); + } + else + { + LogAssert(deliveredArray->GetNumberOfItems() == 0); + ProcessItem(NULL); + } +} + +bool RChannelItemReaderHandlerQueued::ImmediateDispatch() +{ + return false; +} + +bool RChannelItemReaderHandlerImmediate::ImmediateDispatch() +{ + return true; +} + + +RChannelReaderImpl::RChannelReaderImpl() +{ + m_workQueue = NULL; + m_interruptHandler = NULL; + m_supplier = NULL; + m_lazyStart = false; + + m_state = RS_Stopped; + m_drainEvent = ::CreateEvent(NULL, TRUE, FALSE, NULL); + LogAssert(m_drainEvent != NULL); + m_interruptEvent = ::CreateEvent(NULL, TRUE, FALSE, NULL); + LogAssert(m_interruptEvent != NULL); + m_startedSupplierEvent = ::CreateEvent(NULL, TRUE, FALSE, NULL); + LogAssert(m_startedSupplierEvent != NULL); + m_startedSupplier = false; + m_numberOfSubItemsRead = 0; + m_dataSizeRead = 0; +} + 
+RChannelReaderImpl::~RChannelReaderImpl() +{ + LogAssert(m_state == RS_Closed); + LogAssert(m_cookieMap.empty()); + LogAssert(m_eventMap.empty()); + LogAssert(m_unitList.IsEmpty()); + LogAssert(m_handlerList.IsEmpty()); + LogAssert(m_interruptHandler == NULL); + LogAssert(m_readerTerminationItem == NULL); + LogAssert(m_writerTerminationItem == NULL); + + BOOL bRet = ::CloseHandle(m_drainEvent); + LogAssert(bRet != 0); + bRet = ::CloseHandle(m_interruptEvent); + LogAssert(bRet != 0); + bRet = ::CloseHandle(m_startedSupplierEvent); + LogAssert(bRet != 0); +} + +void RChannelReaderImpl::Initialize(RChannelReaderSupplier* supplier, + WorkQueue* workQueue, bool lazyStart) +{ + m_supplier = supplier; + m_workQueue = workQueue; + m_lazyStart = lazyStart; +} + +void RChannelReaderImpl::SetURI(const char* uri) +{ + m_uri = uri; +} + +const char* RChannelReaderImpl::GetURI() +{ + return m_uri; +} + +void RChannelReaderImpl::Start(RChannelBufferPrefetchInfo* prefetchCookie) +{ + bool startSupplier = false; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_supplier != NULL); + LogAssert(m_state == RS_Stopped); + LogAssert(m_cookieMap.empty()); + LogAssert(m_eventMap.empty()); + LogAssert(m_unitList.IsEmpty()); + LogAssert(m_handlerList.IsEmpty()); + LogAssert(m_interruptHandler == NULL); + LogAssert(m_startedSupplier == false); + m_sendLatch.Start(); + m_unitLatch.Start(); + + m_state = RS_Running; + m_readerTerminationItem = NULL; + m_writerTerminationItem = NULL; + m_numberOfSubItemsRead = 0; + m_dataSizeRead = 0; + m_prefetchCookie = prefetchCookie; + + if (m_lazyStart == false) + { + /* when we exit the lock, we will start the + supplier. 
After the supplier has been started we'll set + m_startedSupplierEvent to prevent a race condition in + Interrupt */ + startSupplier = true; + BOOL bRet = ::ResetEvent(m_startedSupplierEvent); + LogAssert(bRet != 0); + m_startedSupplier = true; + } + } + + if (startSupplier) + { + //todo: remove comments if not logging +// DrLogI( "Starting Supplier"); + m_supplier->StartSupplier(m_prefetchCookie); +// DrLogI( "Started Supplier"); + BOOL bRet = ::SetEvent(m_startedSupplierEvent); + LogAssert(bRet != 0); + } +} + +/* called with RChannelReaderImpl::m_baseDR held */ +void RChannelReaderImpl:: + FillEmptyHandlers(ChannelProcessRequestList* handlerDispatch) +{ + while (m_handlerList.IsEmpty() == false) + { + RChannelProcessRequest* request = + m_handlerList.CastOut(m_handlerList.GetHead()); + + LogAssert(request->GetItemArray() == NULL); + RChannelItemArrayRef emptyArray; + emptyArray.Attach(new RChannelItemArray()); + request->SetItemArray(emptyArray); + + handlerDispatch->TransitionToTail(handlerDispatch-> + CastIn(request)); + } +} + +/* called with RChannelReaderImpl::m_baseDR held */ +void RChannelReaderImpl:: + TransferWaitingItems(const char* caller, + ChannelProcessRequestList* requestList, + ChannelUnitList* returnUnitList) +{ + LogAssert(m_handlerList.IsEmpty() == false); + LogAssert(m_unitList.IsEmpty() == false); + + do + { + RChannelProcessRequest* request = + m_handlerList.CastOut(m_handlerList.GetHead()); + + LogAssert(m_writerTerminationItem == NULL); + + /* the request needs to be taken off the handler list now, + since if we hit a termination item below, the remains of + the handler list will be dumped on to the end of the + request list */ + requestList->TransitionToTail(requestList->CastIn(request)); + + UInt32 maxArraySize = request->GetHandler()->GetMaximumArraySize(); + LogAssert(request->GetItemArray() == NULL); + + RChannelUnit* unit = m_unitList.CastOut(m_unitList.GetHead()); + LogAssert(unit->GetType() == RChannelUnit_Item); + 
RChannelItemUnit* itemUnit = (RChannelItemUnit *) unit; + + RChannelItemArray* srcArray = itemUnit->GetItemArray(); + UInt32 nItems = srcArray->GetNumberOfItems(); + + if (nItems <= maxArraySize) + { + //todo: remove comments if not logging +// DrLogD( "Transferring bulk item array", +// "Size %u max acceptable %u", nItems, maxArraySize); + + m_numberOfSubItemsRead += itemUnit->GetNumberOfSubItems(); + m_dataSizeRead += itemUnit->GetDataSize(); + + request->SetItemArray(srcArray); + + m_writerTerminationItem = itemUnit->GetTerminationItem(); + if (m_writerTerminationItem != NULL) + { + FillEmptyHandlers(requestList); + } + + itemUnit->DiscardItems(); + + returnUnitList->TransitionToTail(returnUnitList->CastIn(unit)); + + /* send off any blocked buffer boundary units */ + while (!m_unitList.IsEmpty() && + (unit = m_unitList.CastOut(m_unitList.GetHead()))-> + GetType() != RChannelUnit_Item) + { + LogAssert(unit->GetType() == RChannelUnit_BufferBoundary); + returnUnitList->TransitionToTail(returnUnitList-> + CastIn(unit)); + // todo: remove comments if not logging +// DrLogD( +// "adding buffer boundary dispatch request", +// "caller: %s", caller); + } + } + else + { +// DrLogD( "Transferring partial item array", +// "Size %u max acceptable %u", nItems, maxArraySize); + + RChannelItemArrayRef dstArray; + dstArray.Attach(new RChannelItemArray()); + dstArray->SetNumberOfItems(maxArraySize); + RChannelItemRef* dstItems = dstArray->GetItemArray(); + RChannelItemRef* srcItems = srcArray->GetItemArray(); + + UInt64 subItemsRead = 0; + UInt64 dataSizeRead = 0; + UInt32 i; + for (i=0; iGetNumberOfSubItems(); + dataSizeRead += dstItems[i]->GetItemSize(); + } + + LogAssert(subItemsRead <= itemUnit->GetNumberOfSubItems()); + LogAssert(dataSizeRead <= itemUnit->GetDataSize()); + + m_numberOfSubItemsRead += subItemsRead; + m_dataSizeRead += dataSizeRead; + itemUnit->SetSizes(itemUnit->GetNumberOfSubItems() - subItemsRead, + itemUnit->GetDataSize() - dataSizeRead); + + 
request->SetItemArray(dstArray); + srcArray->DiscardPrefix(maxArraySize); + } + } while (m_handlerList.IsEmpty() == false && + m_unitList.IsEmpty() == false); +} + +/* called with RChannelReaderImpl::m_baseDR held */ +void RChannelReaderImpl::AddUnitToQueue(const char* caller, + RChannelUnit* unit, + ChannelProcessRequestList* requestList, + ChannelUnitList* returnUnitList) +{ + LogAssert(m_state == RS_Running || m_state == RS_InterruptingSupplier); + + if (unit->GetType() == RChannelUnit_Item) + { + /* make sure we haven't already sent out a termination item for + processing */ + LogAssert(m_writerTerminationItem == NULL && + m_readerTerminationItem == NULL); + + if (m_handlerList.IsEmpty()) + { +// DrLogD( +// "queueing item list", +// "caller: %s", caller); + + m_unitList.InsertAsTail(m_unitList.CastIn(unit)); + unit = NULL; + } + else + { +// DrLogD( +// "adding process work request", +// "caller: %s", caller); + + LogAssert(m_unitList.IsEmpty()); + + /* put this in the unit list since that's where + TransferWaitingItems expects to find it */ + m_unitList.InsertAsTail(m_unitList.CastIn(unit)); + + TransferWaitingItems(caller, + requestList, + returnUnitList); + } + } + else + { + LogAssert(unit->GetType() == RChannelUnit_BufferBoundary); + if (m_unitList.IsEmpty()) + { + returnUnitList->InsertAsTail(returnUnitList->CastIn(unit)); +// DrLogD( +// "adding buffer boundary dispatch request", +// "caller: %s", caller); + } + else + { + m_unitList.InsertAsTail(m_unitList.CastIn(unit)); +// DrLogD( +// "queueing buffer boundary dispatch request", +// "caller: %s", caller); + } + } +} + +/* called with RChannelReaderImpl::m_baseDR held */ +void RChannelReaderImpl:: + AddHandlerToQueue(const char* caller, + RChannelProcessRequest* request, + ChannelProcessRequestList* requestList, + ChannelUnitList* returnUnitList) +{ + LogAssert(m_state == RS_Running); + + if (m_unitList.IsEmpty()) + { +// DrLogD( +// "queueing handler", +// "caller: %s", caller); + 
m_handlerList.InsertAsTail(m_handlerList.CastIn(request)); + } + else + { +// DrLogD( +// "adding process work request", +// "caller: %s", caller); + + LogAssert(m_handlerList.IsEmpty()); + + /* put this in the handler list since that's where + TransferWaitingItems expects to find it */ + m_handlerList.InsertAsTail(m_handlerList.CastIn(request)); + + TransferWaitingItems(caller, + requestList, + returnUnitList); + } +} + +void RChannelReaderImpl::ReturnUnits(ChannelUnitList* unitList) +{ + while (unitList->IsEmpty() == false) + { + DrBListEntry* listEntry = unitList->RemoveHead(); + while (listEntry != NULL) + { + RChannelUnit* unit = unitList->CastOut(listEntry); +// DrLogD( "returning unit"); + unit->ReturnToSupplier(); + listEntry = unitList->RemoveHead(); + } + + { + AutoCriticalSection acs(&m_baseDR); + + m_unitLatch.TransferList(unitList); + } + } +} + +void RChannelReaderImpl:: + DispatchRequests(const char* caller, + ChannelProcessRequestList* requestList, + ChannelUnitList* unitList) +{ + ReturnUnits(unitList); + + while (requestList->IsEmpty() == false) + { + DrBListEntry* listEntry = requestList->RemoveHead(); + while (listEntry != NULL) + { + RChannelProcessRequest* request = requestList->CastOut(listEntry); + + if (request->GetHandler()->ImmediateDispatch()) + { + request->Process(); + delete request; + } + else + { +// DrLogD( caller, +// "adding work request"); + bool bRet = m_workQueue->EnQueue(request); + LogAssert(bRet == true); + } + + listEntry = requestList->RemoveHead(); + } + + { + AutoCriticalSection acs(&m_baseDR); + + m_sendLatch.TransferList(requestList); + } + } +} + +static void FOO(ChannelProcessRequestList* requestList) +{ + RChannelProcessRequest* request = + requestList->CastOut(requestList->GetHead()); + while (request != NULL) + { + LogAssert(request->GetItemArray() != NULL); + request = requestList->GetNextTyped(request); + } +} + +void RChannelReaderImpl::AddUnitList(ChannelUnitList* unitList) +{ + ChannelProcessRequestList 
requestList; + ChannelUnitList returnUnitList; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_state == RS_Running || + m_state == RS_InterruptingSupplier); + + DrBListEntry* listEntry = unitList->RemoveHead(); + while (listEntry != NULL) + { + RChannelUnit* unit = unitList->CastOut(listEntry); + AddUnitToQueue("RChannelReaderImpl::AddUnitList", + unit, &requestList, &returnUnitList); + listEntry = unitList->RemoveHead(); + } + + FOO(&requestList); + m_sendLatch.AcceptList(&requestList); + m_unitLatch.AcceptList(&returnUnitList); + } + + DispatchRequests("RChannelReaderImpl::AddUnitList", + &requestList, &returnUnitList); +} + +void RChannelReaderImpl::SupplyHandler(RChannelItemArrayReaderHandler* handler, + void* cancelCookie) +{ + ChannelProcessRequestList requestList; + ChannelUnitList returnUnitList; + bool queuedHandler = false; + bool startSupplier = false; + + { + AutoCriticalSection acs(&m_baseDR); + + if (m_state == RS_Running) + { + if (m_writerTerminationItem == NULL && + m_readerTerminationItem == NULL) + { + if (m_startedSupplier == false) + { + /* when we exit the lock, we will start the + supplier since this is the first read we have + received. 
After the supplier has been started + we'll set m_startedSupplierEvent to prevent a + race condition in Interrupt */ + startSupplier = true; + BOOL bRet = ::ResetEvent(m_startedSupplierEvent); + LogAssert(bRet != 0); + m_startedSupplier = true; + } + + /* we haven't yet sent a termination item for + processing, so there's going to be at least one + more item coming from the parser */ + RChannelProcessRequest* request = + new RChannelProcessRequest(this, handler, cancelCookie); + m_cookieMap.insert(std::make_pair(cancelCookie, request)); + AddHandlerToQueue("RChannelReaderImpl::SupplyHandler", + request, &requestList, &returnUnitList); + queuedHandler = true; + + FOO(&requestList); + m_sendLatch.AcceptList(&requestList); + m_unitLatch.AcceptList(&returnUnitList); + } + } + else + { + LogAssert(m_state != RS_Closed); + } + } + + if (startSupplier) + { +// DrLogI( "Starting Supplier"); + m_supplier->StartSupplier(m_prefetchCookie); +// DrLogI( "Started Supplier"); + BOOL bRet = ::SetEvent(m_startedSupplierEvent); + LogAssert(bRet != 0); + } + + if (queuedHandler) + { + DispatchRequests("RChannelReaderImpl::SupplyHandler", + &requestList, &returnUnitList); + } + else + { + LogAssert(requestList.IsEmpty()); + LogAssert(returnUnitList.IsEmpty()); + /* we've already sent out a termination item or started to + drain so just return the handler immediately as there will + never be any more items to send */ + RChannelItemArrayRef emptyArray; + emptyArray.Attach(new RChannelItemArray()); + handler->ProcessItemArray(emptyArray); + } +} + +void RChannelReaderImpl:: + ThreadSafeSetItemArray(RChannelItemArrayRef* dstItemArray, + RChannelItemArray* srcItemArray) +{ + { + AutoCriticalSection acs(&m_baseDR); + + dstItemArray->Set(srcItemArray); + } +} + +bool RChannelReaderImpl::FetchNextItemArray(UInt32 maxArraySize, + RChannelItemArrayRef* pItemArray, + DrTimeInterval csTimeOut) +{ + ChannelProcessRequestList requestList; + ChannelUnitList returnUnitList; + DryadHandleListEntry* 
event = NULL; + bool timedOut = false; + bool mustBlock = false; + bool startSupplier = false; + + { + AutoCriticalSection acs(&m_baseDR); + + if (m_state != RS_Running) + { + LogAssert(m_state != RS_Closed); + /* return an empty array */ + pItemArray->Attach(new RChannelItemArray()); + return true; + } + + if (!m_unitList.IsEmpty()) + { + LogAssert(m_startedSupplier == true); + LogAssert(m_handlerList.IsEmpty()); + + RChannelReaderSyncWaiter dummyHandler(NULL, + INVALID_HANDLE_VALUE, + NULL); + dummyHandler.SetMaximumArraySize(maxArraySize); + RChannelProcessRequest dummyRequest(NULL, &dummyHandler, NULL); + m_handlerList.InsertAsTail(m_handlerList.CastIn(&dummyRequest)); + + TransferWaitingItems("FetchNextItemArray", + &requestList, &returnUnitList); + + RChannelProcessRequest* dummyReturn = + requestList.CastOut(requestList.RemoveHead()); + LogAssert(dummyReturn == &dummyRequest); + + pItemArray->Set(dummyRequest.GetItemArray()); + LogAssert((*pItemArray)->GetNumberOfItems() > 0); + } + else if (m_writerTerminationItem != NULL || + m_readerTerminationItem != NULL) + { + /* we have already sent a termination item for processing + so there aren't going to be any more arriving on the + queue and we can return an empty list immediately */ + pItemArray->Attach(new RChannelItemArray()); + } + else if (csTimeOut > DrTimeInterval_Zero) + { + if (m_startedSupplier == false) + { + /* when we exit the lock, we will start the supplier + since this is the first read we have + received. 
After the supplier has been started we'll + set m_startedSupplierEvent to prevent a race + condition in Interrupt */ + startSupplier = true; + BOOL bRet = ::ResetEvent(m_startedSupplierEvent); + LogAssert(bRet != 0); + m_startedSupplier = true; + } + + /* we are going to block */ + mustBlock = true; + event = m_eventCache.GetEvent(true); + } + else + { + /* there's no item available and a zero timeout --- return + an empty list immediately */ + pItemArray->Attach(new RChannelItemArray()); + LogAssert(csTimeOut == DrTimeInterval_Zero); + timedOut = true; + } + + m_unitLatch.AcceptList(&returnUnitList); + FOO(&requestList); + m_sendLatch.AcceptList(&requestList); + } + + if (startSupplier) + { +// DrLogI( "Starting Supplier"); + m_supplier->StartSupplier(m_prefetchCookie); +// DrLogI( "Started Supplier"); + BOOL bRet = ::SetEvent(m_startedSupplierEvent); + LogAssert(bRet != 0); + } + + DispatchRequests("FetchNextItemArray", &requestList, &returnUnitList); + + if (mustBlock == false) + { + LogAssert(event == NULL); + } + else + { + LogAssert(event != NULL); + + *pItemArray = NULL; + + RChannelReaderSyncWaiter* waiter = + new RChannelReaderSyncWaiter(this, event->GetHandle(), pItemArray); + this->SupplyHandler(waiter, waiter); + + DWORD timeOut = DrGetTimerMsFromInterval(csTimeOut); + DWORD dRet = ::WaitForSingleObject(event->GetHandle(), timeOut); + + if (dRet == WAIT_TIMEOUT) + { + /* when cancel returns it's guaranteed the handler has + been called */ + this->Cancel(waiter); + } + else + { + LogAssert(dRet == WAIT_OBJECT_0); + } + + LogAssert(*pItemArray != NULL); + + delete waiter; + + { + AutoCriticalSection acs(&m_baseDR); + + /* save this event in case we need one again in the + future */ + m_eventCache.ReturnEvent(event); + + /* even if we actually timed out the wait, the handler may + have supplied an item immediately afterwards or during + the cancel. 
We only return a timed out value if there's + no item ready but we haven't processed a termination + item, so it's worth waiting for another one. */ + timedOut = ((*pItemArray)->GetNumberOfItems() == 0 && + m_writerTerminationItem == NULL && + m_readerTerminationItem == NULL); + } + } + + return (!timedOut); +} + +/* called with RChannelReaderImpl::m_baseDR held */ +void RChannelReaderImpl::RemoveFromCancelMap(RChannelProcessRequest* request, + void* cancelCookie) +{ + /* get the first occurrence of this cookie in the multimap */ + CookieHandlerMap::iterator hIter = m_cookieMap.find(cancelCookie); + /* then look for the actual matching request */ + while (hIter != m_cookieMap.end() && + hIter->first == cancelCookie && + hIter->second != request) + { + ++hIter; + } + LogAssert(hIter != m_cookieMap.end() && + hIter->first == cancelCookie); + m_cookieMap.erase(hIter); +} + +/* called with RChannelReaderImpl::m_baseDR held */ +void RChannelReaderImpl::MaybeTriggerCancelEvent(void* cancelCookie) +{ + /* there is an event in the event map if somebody is blocked on + cancelling this cookie */ + CookieEventMap::iterator eIter = m_eventMap.find(cancelCookie); + if (eIter != m_eventMap.end()) + { + /* check to see if there are any remaining requests with this + cookie which haven't been processed yet */ + CookieHandlerMap::iterator hIter = m_cookieMap.find(cancelCookie); + if (hIter == m_cookieMap.end()) + { + /* there's nothing left so signal the event and remove the + cookie from the map */ + BOOL bRet = ::SetEvent(eIter->second); + LogAssert(bRet != 0); + m_eventMap.erase(eIter); + } + } +} + +void RChannelReaderImpl:: + ProcessItemArrayRequest(RChannelProcessRequest* request) +{ + RChannelItemArray* itemArray = request->GetItemArray(); + RChannelItemArrayReaderHandler* handler = request->GetHandler(); + LogAssert(handler != NULL); + LogAssert(itemArray != NULL); + + void* cancelCookie = request->GetCookie(); + + handler->ProcessItemArray(itemArray); + + { + 
AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_state != RS_Stopped); + + /* now that the request has been processed remove it from our + accounting */ + RemoveFromCancelMap(request, cancelCookie); + /* if it was the last request with this cookie and somebody + was blocked cancelling the cookie, wake them up */ + MaybeTriggerCancelEvent(cancelCookie); + + /* if it was the last request in the system and we are + draining, wake up the drain thread */ + if (m_state == RS_Stopping && m_cookieMap.empty()) + { + LogAssert(m_eventMap.empty()); + BOOL bRet = ::SetEvent(m_drainEvent); + LogAssert(bRet != 0); + } + } +} + +void RChannelReaderImpl::AlertApplication(RChannelItem* item) +{ + RChannelInterruptHandler* interruptHandler = NULL; + + { + AutoCriticalSection acs(&m_baseDR); + + if (m_interruptHandler != NULL) + { + interruptHandler = m_interruptHandler; + m_interruptHandler = NULL; + } + } + + if (interruptHandler != NULL) + { + interruptHandler->ProcessInterrupt(item); + } +} + +bool RChannelReaderImpl::IsRunning() +{ + bool retval; + + { + AutoCriticalSection acs(&m_baseDR); + + retval = (m_state == RS_Running); + } + + return retval; +} + +void RChannelReaderImpl::Cancel(void* cancelCookie) +{ + BOOL bRet; + bool mustClean = false; + ChannelProcessRequestList handlerDispatch; + DrBListEntry* listEntry; + CookieHandlerMap::iterator cIter; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_state == RS_Running || m_state == RS_InterruptingSupplier); + + /* first find any handlers with this cookie which haven't yet + been dispatched with an item, and put them on the + handlerDispatch list */ + listEntry = m_handlerList.GetHead(); + while (listEntry != NULL) + { + RChannelProcessRequest* request = m_handlerList.CastOut(listEntry); + listEntry = m_handlerList.GetNext(listEntry); + + if (request->GetCookie() == cancelCookie) + { + LogAssert(request->GetItemArray() == NULL); + handlerDispatch.TransitionToTail(handlerDispatch. 
+ CastIn(request)); + RemoveFromCancelMap(request, cancelCookie); + } + } + + /* now find any handlers with this cookie which are still + around (i.e. have already been put on the work queue but + not yet processed) and try to cancel them, which will + spring them from the work queue early when we clean it + below */ + cIter = m_cookieMap.find(cancelCookie); + if (cIter != m_cookieMap.end()) + { + while (cIter != m_cookieMap.end() && + cIter->first == cancelCookie) + { + (cIter->second)->Cancel(); + ++cIter; + } + mustClean = true; + } + } + + /* these handlers had been submitted but not assigned an item, so + we can send them back immediately */ + listEntry = handlerDispatch.GetHead(); + while (listEntry != NULL) + { + RChannelProcessRequest* request = handlerDispatch.CastOut(listEntry); + listEntry = handlerDispatch.GetNext(listEntry); + handlerDispatch.Remove(handlerDispatch.CastIn(request)); + RChannelItemArrayRef emptyArray; + emptyArray.Attach(new RChannelItemArray()); + request->GetHandler()->ProcessItemArray(emptyArray); + delete request; + } + + if (mustClean) + { + /* at least one handler had already been sent for processing: + get the queue to trigger it if it's still hanging around + there */ + m_workQueue->Clean(); + } + + DryadHandleListEntry* event = NULL; + + { + AutoCriticalSection acs(&m_baseDR); + + /* see if there are still any handlers around after the work + queue cleaning which haven't triggered yet. 
If so, we'll + have to add an event to the event map and wait for it */ + cIter = m_cookieMap.find(cancelCookie); + if (cIter != m_cookieMap.end()) + { + event = m_eventCache.GetEvent(true); + + std::pair retval; + retval = m_eventMap.insert(std::make_pair(cancelCookie, + event->GetHandle())); + /* it's not legal to overlap two calls to cancel the same + cookie */ + LogAssert(retval.second == true); + } + } + + if (event != NULL) + { + /* we decided we had to wait for an event */ + bRet = ::WaitForSingleObject(event->GetHandle(), INFINITE); + LogAssert(bRet == WAIT_OBJECT_0); + + { + AutoCriticalSection acs(&m_baseDR); + + /* sanity check that there really aren't any handlers with + this cookie still hanging around */ + cIter = m_cookieMap.find(cancelCookie); + LogAssert(cIter == m_cookieMap.end()); + + /* sanity check that it got removed from the event map at + the same time */ + CookieEventMap::iterator eIter = m_eventMap.find(cancelCookie); + LogAssert(eIter == m_eventMap.end()); + + /* save the event in case we want to use it again later */ + m_eventCache.ReturnEvent(event); + } + } +} + +// +// Interrupt an input channel for provided reason +// +void RChannelReaderImpl::Interrupt(RChannelItem* interruptItemBase) +{ + RChannelItemRef interruptItem = interruptItemBase; + bool doInterrupt; + bool startedSupplier = false; + BOOL bRet; + + // + // Enter a critical section and update state from "Running" to "Interrupting" + // + { + AutoCriticalSection acs(&m_baseDR); + + if (m_state == RS_Running) + { + doInterrupt = true; + bRet = ::ResetEvent(m_interruptEvent); + LogAssert(bRet != 0); + m_state = RS_InterruptingSupplier; + startedSupplier = m_startedSupplier; + } + else + { + doInterrupt = false; + } + } + + // + // If already interrupting, wait for interrupt event to be handled + // + if (doInterrupt == false) + { + bRet = ::WaitForSingleObject(m_interruptEvent, INFINITE); + LogAssert(bRet == WAIT_OBJECT_0); + return; + } + + ChannelProcessRequestList 
handlerDispatch; + ChannelUnitList unitDispatch; + DrBListEntry* listEntry; + RChannelInterruptHandler* interruptHandler = NULL; + + if (startedSupplier) + { + // + // startedSupplier just means that we (potentially in another + // thread) have called or are about to call StartSupplier, + // outside the lock. We'll wait here until it really gets + // called before calling InterruptSupplier + // + bRet = ::WaitForSingleObject(m_startedSupplierEvent, INFINITE); + LogAssert(bRet == WAIT_OBJECT_0); + + // + // when this returns the buffer reader will not be generating + // any new buffers and the parser will not be generating any + // new items. + // + m_supplier->InterruptSupplier(); + } + + { + AutoCriticalSection acs(&m_baseDR); + + if (interruptItem == NULL) + { + if (m_writerTerminationItem == NULL) + { + interruptItem.Attach(RChannelMarkerItem:: + Create(RChannelItem_Abort, false)); + } + else + { + RChannelItemType interruptType = + m_writerTerminationItem->GetType(); + interruptItem.Attach(RChannelMarkerItem:: + Create(interruptType, false)); + interruptItem->ReplaceMetaData(m_writerTerminationItem-> + GetMetaData()); + } + } + else + { + RChannelItemType interruptType = interruptItem->GetType(); + LogAssert(RChannelItem::IsTerminationItem(interruptType)); + } + + LogAssert(m_state == RS_InterruptingSupplier); + /* sanity check that nobody accidentally started the supplier + while we were getting here */ + LogAssert(startedSupplier == m_startedSupplier); + + m_state = RS_Stopping; + + if (m_unitList.IsEmpty()) + { + /* gather up any handlers which haven't been given an item + yet */ + FillEmptyHandlers(&handlerDispatch); + } + else + { + LogAssert(m_handlerList.IsEmpty()); + /* gather up any units which have been put on the queue + but don't have a handler waiting */ + bool gotTermination = (m_writerTerminationItem != NULL); + while (m_unitList.IsEmpty() == false) + { + RChannelUnit* unit = m_unitList.CastOut(m_unitList.GetHead()); + if (unit->GetType() == 
RChannelUnit_Item) + { + /* sanity check that the item sequence is + correct */ + RChannelItemUnit* itemUnit = (RChannelItemUnit *) unit; + RChannelItemArray* itemArray = itemUnit->GetItemArray(); + + UInt32 nItems = itemArray->GetNumberOfItems(); + LogAssert(nItems > 0); + RChannelItemRef* a = itemArray->GetItemArray(); + + UInt32 i; + for (i=0; iGetType(); + LogAssert(gotTermination == false); + if (RChannelItem::IsTerminationItem(itemType)) + { + gotTermination = true; + } + } + + itemUnit->DiscardItems(); + } + unitDispatch.TransitionToTail(unitDispatch.CastIn(unit)); + } + } + + /* prepare to trigger any interrupt handler the application + sent in earlier */ + if (m_interruptHandler != NULL) + { + interruptHandler = m_interruptHandler; + m_interruptHandler = NULL; + } + + m_unitLatch.AcceptList(&unitDispatch); + } + + /* these are handlers which had been submitted but not assigned an + item */ + listEntry = handlerDispatch.GetHead(); + while (listEntry != NULL) + { + RChannelProcessRequest* request = handlerDispatch.CastOut(listEntry); + listEntry = handlerDispatch.GetNext(listEntry); + handlerDispatch.Remove(handlerDispatch.CastIn(request)); + /* call the process event (which calls back into + RChannelReaderImpl::ProcessUnit) instead of just calling the + handler in case there is a thread blocked on a cancellation + which needs to be woken up. 
*/ + request->Process(); + delete request; + } + + ReturnUnits(&unitDispatch); + + if (startedSupplier) + { + /* assuming we ever started it, wait for the buffer reader to + process all its returned buffers and tell us whether we can + restart or not */ + m_supplier->DrainSupplier(interruptItem); + } + + if (interruptHandler != NULL) + { + interruptHandler->ProcessInterrupt(interruptItem); + } + + { + AutoCriticalSection acs(&m_baseDR); + + /* sanity check that nobody accidentally started the supplier + while we were getting here */ + LogAssert(startedSupplier == m_startedSupplier); + + LogAssert(m_unitList.IsEmpty()); + LogAssert(m_handlerList.IsEmpty()); + LogAssert(m_interruptHandler == NULL); + LogAssert(m_readerTerminationItem == NULL); + m_readerTerminationItem = interruptItem; + m_startedSupplier = false; + } + + bRet = ::SetEvent(m_interruptEvent); + LogAssert(bRet != 0); +} + +void RChannelReaderImpl::Drain() +{ + bool waitForCookie = false; + bool waitForLatch = false; + BOOL bRet; + + Interrupt(NULL); + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_state == RS_Stopping); + LogAssert(m_startedSupplier == false); + LogAssert(m_unitList.IsEmpty()); + LogAssert(m_handlerList.IsEmpty()); + + if (!m_cookieMap.empty()) + { + /* there are handlers which have been submitted and sent + for processing, but haven't triggered yet. Mark them + for cancellation so they will get evicted from the work + queue in the clean below. 
*/ + CookieHandlerMap::iterator cookieIter; + for (cookieIter = m_cookieMap.begin(); + cookieIter != m_cookieMap.end(); + ++cookieIter) + { + (cookieIter->second)->Cancel(); + } + waitForCookie = true; + bRet = ::ResetEvent(m_drainEvent); + LogAssert(bRet != 0); + } + + waitForLatch = m_sendLatch.Interrupt(); + } + + if (waitForLatch) + { + m_sendLatch.Wait(); + } + + { + AutoCriticalSection acs(&m_baseDR); + + waitForLatch = m_unitLatch.Interrupt(); + } + + if (waitForLatch) + { + m_unitLatch.Wait(); + } + + if (waitForCookie) + { + m_workQueue->Clean(); + + /* wait until all handlers which had been sent to the queue + have returned */ + bRet = ::WaitForSingleObject(m_drainEvent, INFINITE); + LogAssert(bRet == WAIT_OBJECT_0); + } + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_cookieMap.empty()); + LogAssert(m_eventMap.empty()); + LogAssert(m_unitList.IsEmpty()); + LogAssert(m_handlerList.IsEmpty()); + LogAssert(m_interruptHandler == NULL); + LogAssert(m_startedSupplier == false); + m_sendLatch.Stop(); + m_unitLatch.Stop(); + + m_state = RS_Stopped; + } +} + +void RChannelReaderImpl:: + GetTerminationItems(RChannelItemRef* pWriterDrainItem, + RChannelItemRef* pReaderDrainItem) +{ + { + AutoCriticalSection acs(&m_baseDR); + + *pWriterDrainItem = m_writerTerminationItem; + *pReaderDrainItem = m_readerTerminationItem; + } +} + +UInt64 RChannelReaderImpl::GetDataSizeRead() +{ + return m_dataSizeRead; +} + +void RChannelReaderImpl::Close() +{ + m_supplier->CloseSupplier(); + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_cookieMap.empty()); + LogAssert(m_eventMap.empty()); + LogAssert(m_unitList.IsEmpty()); + LogAssert(m_handlerList.IsEmpty()); + LogAssert(m_interruptHandler == NULL); + LogAssert(m_state == RS_Stopped); + m_state = RS_Closed; + m_readerTerminationItem = NULL; + m_writerTerminationItem = NULL; + } +} + +RChannelSerializedReader:: + RChannelSerializedReader(RChannelBufferReader* bufferReader, + RChannelItemParserBase* parser, + 
UInt32 maxParseBatchSize, + UInt32 maxOutstandingUnits, + bool lazyStart, + WorkQueue* workQueue) +{ + m_bufferQueue = new RChannelBufferQueue(this, + bufferReader, + parser, + maxParseBatchSize, + maxOutstandingUnits, + workQueue); + RChannelReaderImpl::Initialize(m_bufferQueue, workQueue, lazyStart); + + m_expectedLength = (UInt64) -1; +} + +RChannelSerializedReader::~RChannelSerializedReader() +{ + delete m_bufferQueue; +} + +bool RChannelSerializedReader::GetTotalLength(UInt64* pLen) +{ + return this->m_bufferQueue->GetTotalLength(pLen); +} + +bool RChannelSerializedReader::GetExpectedLength(UInt64* pLen) +{ + if (m_expectedLength == (UInt64) -1) + { + *pLen = 0; + return false; + } + else + { + *pLen = m_expectedLength; + return true; + } +} + +void RChannelSerializedReader::SetExpectedLength(UInt64 expectedLength) +{ + m_expectedLength = expectedLength; +} + +RChannelBufferReaderHandler::~RChannelBufferReaderHandler() +{ +} + +RChannelBufferReader::~RChannelBufferReader() +{ +} + +void RChannelBufferReader::FillInStatus(DryadChannelDescription* status) +{ +} + +void RChannelNullBufferReader:: + Start(RChannelBufferPrefetchInfo* prefetchCookie, + RChannelBufferReaderHandler* handler) +{ + RChannelItem* item = + RChannelMarkerItem::Create(RChannelItem_EndOfStream, false); + + RChannelBuffer* buffer = + RChannelBufferMarkerDefault::Create(RChannelBuffer_EndOfStream, + item, + this); + + handler->ProcessBuffer(buffer); +} + +void RChannelNullBufferReader::Interrupt() +{ +} + +void RChannelNullBufferReader::Drain(RChannelItem* drainItem) +{ +} + +void RChannelNullBufferReader::Close() +{ +} + +void RChannelNullBufferReader::ReturnBuffer(RChannelBuffer* buffer) +{ + buffer->DecRef(); +} + +bool RChannelNullBufferReader::GetTotalLength(UInt64* pLen) +{ + *pLen = 0; + return true; +} diff --git a/DryadVertex/VertexHost/system/channel/src/channelreader.h b/DryadVertex/VertexHost/system/channel/src/channelreader.h new file mode 100644 index 0000000..c4a11ef --- 
/dev/null +++ b/DryadVertex/VertexHost/system/channel/src/channelreader.h @@ -0,0 +1,385 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once +#pragma warning(disable:4512) // KLUDGE -- build for now, fix later. + +#include +#include +#include +#include +#include +#include +#pragma warning(disable:4995) +#include + +class WorkQueue; +class WorkRequest; + +class RChannelUnit; + +class RChannelParseRequest; +class RChannelProcessRequest; +class RChannelBufferBoundaryUnit; +class RChannelItemUnit; +class RChannelBuffer; +class RChannelBufferReader; +class RChannelBufferQueue; +class RChannelBufferPrefetchInfo; + +typedef DryadBList ChannelUnitList; + +typedef DryadBListDerived + ChannelProcessRequestList; + +class RChannelReaderSupplier +{ +public: + virtual ~RChannelReaderSupplier(); + virtual void StartSupplier(RChannelBufferPrefetchInfo* prefetchCookie) = 0; + virtual void InterruptSupplier() = 0; + virtual void DrainSupplier(RChannelItem* drainItem) = 0; + virtual void CloseSupplier() = 0; +}; + +class RChannelReaderImpl : public RChannelReader +{ +public: + virtual ~RChannelReaderImpl(); + + /* a client must call Start to cause the channel to start + generating items for the first time, or after a Drain has + completed before the channel is restarted. 
+ + prefetchCookie is passed to the byte-oriented buffer layer and + should be NULL for now. + */ + void Start(RChannelBufferPrefetchInfo* prefetchCookie); + + /* SupplyHandler passes a handler which will be "returned" via a + matching call to handler->ProcessItem. Multiple handlers may be + outstanding at a given time. The channel blocks when there are + no handlers available, so this is the primary flow control + mechanism to allow the reader to exert back-pressure on a + channel. + + If the channel has not been started or has delivered a + termination item to the application since the last call to + Start, handler->ProcessItem will be called with value NULL on + the calling thread before SupplyHandler returns. + + If there is an item waiting to be delivered and handler is an + RChannelItemReaderHandlerImmediate then handler will be called + back with the item on the calling thread before SupplyHandler + returns. + + If handler is an RChannelItemReaderHandlerQueued it will never + be called back on the calling thread. If there is an item + waiting to be delivered, a processing request will be queued + for the item with handler, otherwise handler will be queued + waiting for the next item to be ready. + + If thread B submits an RChannelItemReaderHandlerImmediate with + a call to SupplyHandler that is overlapped with thread A's call + to SupplyHandler, B's immediate handler may be called on A's + calling thread before A's call to SupplyHandler returns. + + After a handler has been submitted to SupplyHandler with a + given value of cancelCookie, that handler is guaranteed to be + returned before any matching call to Cancel with the same value + of cancelCookie returns. If there is no item available at the + time the handler is cancelled it will be called with + NULL. cancelCookie may take any value, including NULL, however + it is safest to use NULL or the address of an object owned by + the caller. 
The cancellation mechanism is used internally by + the synchronous FetchNextItem call which uses an allocated heap + address as its cancelCookie. Using the heap address of an + object which has not been freed before the call to Cancel will + avoid any danger of a cookie collision. + + For any item A which has already been returned via an async + handler or call to FetchNextItem when SupplyHandler is called + it is guaranteed that A's sequence number is less than or equal + to the sequence number of the item eventually returned on + handler. Beyond this constraint, if multiple handlers are + outstanding at once, or while calls to FetchNextItem are in + progress, the order of delivered items is undefined. + */ + void SupplyHandler(RChannelItemArrayReaderHandler* handler, + void* cancelCookie); + + /* any handler passed to SupplyHandler with value cancelCookie is + guaranteed to be returned before a call to Cancel completes. */ + void Cancel(void* cancelCookie); + + /* FetchNextItem blocks waiting until an item is available, a + termination item is delivered on another thread, or the timeOut + interval has elapsed (timeOut can be DrTimeInterval_Infinite in + which case FetchNextItem will block indefinitely). + + If the timeout expires FetchNextItem returns false and *outItem + is NULL. Otherwise FetchNextItem returns true and the returned + item is stored in *outItem. *outItem may be NULL even if + FetchNextItem returns true: this will happen if Start has not + been called or a termination item has already been delivered on + the channel since the last call to Start. + */ + bool FetchNextItemArray(UInt32 maxListSize, + RChannelItemArrayRef* itemArray, + DrTimeInterval timeOut); + + /* Instruct the channel to return all outstanding handlers via + their handler->ProcessItem callbacks in preparation for either + closing or restarting the channel. 
Once the Drain method + returns all outstanding handler callbacks will have completed + and all waiting calls to FetchNextItem will have been unblocked + (though of course they may not have returned to the calling + thread). For obvious reasons, Drain may not be called from a + handler's ProcessItem callback. + + The drainItem may be of type RChannelItem_Abort, + RChannelItem_Restart, RChannelItem_EndOfStream or + RChannelItem_ParseError. If the channel is a pipe which has not + broken, the drainItem will be delivered to the remote end as + part of the shutdown procedure. + */ + void Interrupt(RChannelItem* interruptItem); + void Drain(); + + /* return the drain item passed to the most recent call to Drain, + or NULL if Drain has not been called since the last call to + Start. The return value is undefined if this call overlaps + with a call to Start or Drain, and it is illegal to call + GetTerminationItems after Close has been called. */ + void GetTerminationItems(RChannelItemRef* pWriterDrainItem, + RChannelItemRef* pReaderDrainItem); + + UInt64 GetDataSizeRead(); + + /* Close may only be called if Start has never been called, or if + Drain has completed since the last call to Start. Close must be + called before the RChannelReader is destroyed. After Close has + been called no other methods may be called on RChannelReader. + */ + void Close(); + + /* Get/Set the URI of the channel. 
*/ + const char* GetURI(); + void SetURI(const char* uri); + +protected: + RChannelReaderImpl(); + void Initialize(RChannelReaderSupplier* supplier, + WorkQueue* workQueue, + bool lazyStart); + +private: + typedef std::multimap + CookieHandlerMap; + typedef std::map + CookieEventMap; + + enum ReaderState { + RS_Stopped, + RS_Running, + RS_InterruptingSupplier, + RS_Stopping, + RS_Closed + }; + + bool IsRunning(); + + void FillEmptyHandlers(ChannelProcessRequestList* handlerDispatch); + void TransferWaitingItems(const char* caller, + ChannelProcessRequestList* requestList, + ChannelUnitList* returnUnitList); + void ReturnUnits(ChannelUnitList* unitList); + void DispatchRequests(const char* caller, + ChannelProcessRequestList* requestList, + ChannelUnitList* returnUnitList); + void AddUnitToQueue(const char* caller, + RChannelUnit* unit, + ChannelProcessRequestList* requestList, + ChannelUnitList* returnUnitList); + void AddUnitList(ChannelUnitList* unitList); + void AddHandlerToQueue(const char* caller, + RChannelProcessRequest* request, + ChannelProcessRequestList* requestList, + ChannelUnitList* returnUnitList); + void RemoveFromCancelMap(RChannelProcessRequest* request, + void* cancelCookie); + void MaybeTriggerCancelEvent(void* cancelCookie); + void ProcessItemArrayRequest(RChannelProcessRequest* handler); + void AlertApplication(RChannelItem* item); + void ThreadSafeSetItemArray(RChannelItemArrayRef* dstItemArray, + RChannelItemArray* srcItemArray); + + WorkQueue* m_workQueue; + + CookieHandlerMap m_cookieMap; + CookieEventMap m_eventMap; + ChannelUnitList m_unitList; + ChannelProcessRequestList m_handlerList; + RChannelInterruptHandler* m_interruptHandler; + + ReaderState m_state; + DryadOrderedSendLatch m_sendLatch; + DryadOrderedSendLatch m_unitLatch; + HANDLE m_drainEvent; + HANDLE m_interruptEvent; + HANDLE m_startedSupplierEvent; + bool m_lazyStart; + bool m_startedSupplier; + DryadEventCache m_eventCache; + RChannelItemRef m_readerTerminationItem; + 
RChannelItemRef m_writerTerminationItem; + UInt64 m_numberOfSubItemsRead; + UInt64 m_dataSizeRead; + + RChannelBufferPrefetchInfo* m_prefetchCookie; + RChannelReaderSupplier* m_supplier; + + DrStr128 m_uri; + + CRITSEC m_baseDR; + + friend class RChannelParseRequest; + friend class RChannelProcessRequest; + friend class RChannelBufferBoundaryRequest; + friend class RChannelReaderSyncWaiter; + friend class RChannelItemUnit; + friend class RChannelBufferQueue; + friend class RChannelFifoWriterBase; +}; + +class RChannelSerializedReader : public RChannelReaderImpl +{ +public: + RChannelSerializedReader(RChannelBufferReader* bufferReader, + RChannelItemParserBase* parser, + UInt32 maxParseBatchSize, + UInt32 maxOutstandingUnits, + bool lazyStart, + WorkQueue* workQueue); + ~RChannelSerializedReader(); + + bool GetTotalLength(UInt64* pLen); + bool GetExpectedLength(UInt64* pLen); + void SetExpectedLength(UInt64 expectedLength); + +private: + UInt64 m_expectedLength; + RChannelBufferQueue* m_bufferQueue; +}; + +/* interface used to signal that a buffer has arrived and is ready for + reading */ +class RChannelBufferReaderHandler +{ +public: + virtual ~RChannelBufferReaderHandler(); + + /* When an i/o completes on a buffer reader, the buffer is + delivered via this callback to the consumer. + buffer->ProcessingComplete should be called when the consumer + has finished using the data in the buffer. Buffers are + delivered in order, and to enforce this there will be at most + one call to ProcessBuffer in progress at a time on any given + channel. + + If buffer->GetType() is a termination type + (RChannelBuffer::IsTerminationBuffer returns true) then no more + calls to ProcessBuffer will be made before + RChannelBufferReader::Drain has completed. + + The completion callback mechanism is used to implement flow + control, as the buffer-oriented i/o will only allow the + consumer to hold a bounded number of outstanding buffers before + blocking further reads. 
+ */ + virtual void ProcessBuffer(RChannelBuffer* buffer) = 0; +}; + +/* base class to wrap byte-oriented read implementations */ +class RChannelBufferReader +{ +public: + virtual ~RChannelBufferReader(); + + /* Instruct the i/o reader to start fetching buffers from the + start of the Channel and delivering them via handler. + + prefetchCookie is an optional hint which may be used to + influence initial read buffer sizes or prefetching + behaviour. It is dependent on the implementation of the + underlying buffer-oriented i/o class and should be NULL for + now. + */ + virtual void Start(RChannelBufferPrefetchInfo* prefetchCookie, + RChannelBufferReaderHandler* handler) = 0; + + /* Instruct the i/o reader to prepare to Drain the Channel of + buffers. After a call to Interrupt returns, the BufferReader + will not make any more calls to handler unless the Channel is + Drained and restarted. + + A client which wants to abort or restart the Channel should + first call Interrupt (which guarantees no new buffers will be + delivered), then ensure that any outstanding buffers have been + returned via their completion handlers, then call Drain. + */ + virtual void Interrupt() = 0; + + virtual void FillInStatus(DryadChannelDescription* status); + + /* Complete the synchronisation in the case of restarting or + closing the stream. Drain will not return until all outstanding + buffer completion handlers have been called. + + drainItem will have type RChannelItem_Abort, + RChannelItem_Restart, RChannelItem_EndOfStream or + RChannelItem_ParseError and if the channel is a pipe it should + be communicated to the process at the remote end. + + After Drain returns, Start and Close are the only legal method + calls. + */ + virtual void Drain(RChannelItem* drainItem) = 0; + + /* shut down the channel. After Close returns no further calls can + be made to this interface. 
+ */ + virtual void Close() = 0; + + virtual bool GetTotalLength(UInt64* pLen) = 0; +}; + +class RChannelNullBufferReader : + public RChannelBufferReader, public RChannelBufferDefaultHandler +{ +public: + void Start(RChannelBufferPrefetchInfo* prefetchCookie, + RChannelBufferReaderHandler* handler); + void Interrupt(); + void Drain(RChannelItem* drainItem); + void Close(); + void ReturnBuffer(RChannelBuffer* buffer); + bool GetTotalLength(UInt64* pLen); +}; diff --git a/DryadVertex/VertexHost/system/channel/src/channelwriter.cpp b/DryadVertex/VertexHost/system/channel/src/channelwriter.cpp new file mode 100644 index 0000000..36ebc7d --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/src/channelwriter.cpp @@ -0,0 +1,1524 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include +#include +#include +#include +#include +#include "channelhelpers.h" +#include "channelmemorybuffers.h" + +#pragma unmanaged + + +void SyncItemWriterBase:: + WriteItemSyncConsumingFreeReference(RChannelItem* item) +{ + RChannelItemRef refItem; + refItem.Attach(item); + WriteItemSyncConsumingReference(refItem); +} + +void SyncItemWriterBase::WriteItemSync(RChannelItem* item) +{ + RChannelItemRef refItem = item; + WriteItemSyncConsumingReference(refItem); +} + +RChannelWriter::~RChannelWriter() +{ +} + +void RChannelWriter::WriteItem(RChannelItem* item, + bool flushAfter, + RChannelItemArrayWriterHandler* handler) +{ + RChannelItemArrayRef singleArray; + singleArray.Attach(new RChannelItemArray()); + singleArray->SetNumberOfItems(1); + singleArray->GetItemArray()[0] = item; + WriteItemArray(singleArray, flushAfter, handler); +} + +RChannelItemType RChannelWriter:: + WriteItemSync(RChannelItem* item, + bool flush, + RChannelItemRef* pMarshalFailureItem) +{ + RChannelItemArrayRef singleArray; + singleArray.Attach(new RChannelItemArray()); + singleArray->SetNumberOfItems(1); + singleArray->GetItemArray()[0] = item; + + RChannelItemArrayRef failureArray; + + RChannelItemType retVal = + WriteItemArraySync(singleArray, flush, &failureArray); + + if (pMarshalFailureItem != NULL) + { + if (failureArray == NULL) + { + *pMarshalFailureItem = NULL; + } + else + { + LogAssert(failureArray->GetNumberOfItems() == 1); + LogAssert(failureArray->GetItemArray()[0] != NULL); + *pMarshalFailureItem = failureArray->GetItemArray()[0]; + } + } + + return retVal; +} + +void RChannelWriter::WriteItemSyncConsumingReference(RChannelItemRef& item) +{ + WriteItemSync(item, false, NULL); + item = NULL; +} + +DrError RChannelWriter::GetTerminationStatus(DryadMetaDataRef* pErrorData) +{ + DrError status; + + RChannelItemRef writerTermination; + RChannelItemRef readerTermination; + GetTerminationItems(&writerTermination, &readerTermination); + + LogAssert(writerTermination != 
NULL); + + *pErrorData = writerTermination->GetMetaData(); + + if (writerTermination->GetType() == RChannelItem_EndOfStream) + { + if (readerTermination != NULL && + readerTermination->GetType() != RChannelItem_EndOfStream) + { + status = readerTermination->GetErrorFromItem(); + *pErrorData = readerTermination->GetMetaData(); + } + else + { + status = DrError_EndOfStream; + } + } + else + { + status = writerTermination->GetErrorFromItem(); + } + + return status; +} + +DrError RChannelWriter::GetWriterStatus() +{ + DrError status = DrError_OK; + + RChannelItemRef writerTermination; + RChannelItemRef readerTermination; + GetTerminationItems(&writerTermination, &readerTermination); + + if (readerTermination != NULL) + { + if (readerTermination->GetType() == RChannelItem_EndOfStream) + { + status = DrError_EndOfStream; + } + else + { + status = readerTermination->GetErrorFromItem(); + } + } + + return status; +} + +RChannelItemArrayWriterHandler::~RChannelItemArrayWriterHandler() +{ +} + +RChannelItemWriterHandler::~RChannelItemWriterHandler() +{ +} + +void RChannelItemWriterHandler:: + ProcessWriteArrayCompleted(RChannelItemType returnCode, + RChannelItemArray* failureArray) +{ + if (failureArray != NULL) + { + LogAssert(failureArray->GetNumberOfItems() == 1); + ProcessWriteCompleted(returnCode, failureArray->GetItemArray()[0]); + } + else + { + ProcessWriteCompleted(returnCode, NULL); + } +} + +void RChannelSerializedWriter::DummyItemHandler:: + ProcessWriteArrayCompleted(RChannelItemType returnCode, + RChannelItemArray* failureArray) +{ + LogAssert(failureArray == NULL); + delete this; +} + +RChannelSerializedWriter::WriteRequest:: + WriteRequest(RChannelItemArray* itemArray, + bool flushAfter, + RChannelItemArrayWriterHandler* handler) +{ + LogAssert(itemArray != NULL); + m_itemArray = itemArray; + m_flushAfter = flushAfter; + m_handler = handler; + m_currentItem = 0; + m_aborted = false; +} + +RChannelSerializedWriter::WriteRequest::~WriteRequest() +{ + 
LogAssert(m_itemArray == NULL); + LogAssert(m_handler == NULL); + LogAssert(m_failureArray == NULL); +} + +void RChannelSerializedWriter::WriteRequest:: + SetHandler(RChannelItemArrayWriterHandler* handler) +{ + LogAssert(m_handler == NULL); + m_handler = handler; +} + +RChannelItem* RChannelSerializedWriter::WriteRequest::GetNextItem() +{ + LogAssert(m_currentItem < m_itemArray->GetNumberOfItems() && + m_aborted == false); + return m_itemArray->GetItemArray()[m_currentItem]; +} + +void RChannelSerializedWriter::WriteRequest::SetSuccessItem() +{ + LogAssert(m_aborted == false); + if (m_failureArray != NULL) + { + LogAssert(m_currentItem < m_failureArray->GetNumberOfItems()); + LogAssert(m_failureArray->GetItemArray()[m_currentItem] == NULL); + } + ++m_currentItem; +} + +void RChannelSerializedWriter::WriteRequest:: + SetFailureItem(RChannelItem* marshalFailureItem, bool abort) +{ + LogAssert(m_aborted == false); + m_aborted = abort; + + if (m_failureArray == NULL) + { + m_failureArray.Attach(new RChannelItemArray()); + m_failureArray->SetNumberOfItems(m_itemArray->GetNumberOfItems()); + } + + LogAssert(m_currentItem < m_failureArray->GetNumberOfItems()); + LogAssert(m_failureArray->GetItemArray()[m_currentItem] == NULL); + m_failureArray->GetItemArray()[m_currentItem] = marshalFailureItem; + + ++m_currentItem; +} + +bool RChannelSerializedWriter::WriteRequest::ShouldFlush() +{ + return m_flushAfter; +} + +bool RChannelSerializedWriter::WriteRequest::LastItem() +{ + LogAssert(m_aborted == false); + return (m_currentItem == m_itemArray->GetNumberOfItems()-1); +} + +bool RChannelSerializedWriter::WriteRequest::Completed() +{ + return (m_currentItem == m_itemArray->GetNumberOfItems() || + m_aborted == true); +} + +void RChannelSerializedWriter::WriteRequest:: + ProcessMarshalCompleted(RChannelItemType returnCode) +{ + LogAssert(m_handler != NULL); + + if (m_aborted == false) + { + LogAssert(m_currentItem == m_itemArray->GetNumberOfItems()); + } + + 
m_handler->ProcessWriteArrayCompleted(returnCode, m_failureArray);
+
+    m_itemArray = NULL;
+    m_failureArray = NULL;
+    m_handler = NULL;
+}
+
+/* A fresh handler has no event attached, a zero m_usingEvent count and a
+   default status of RChannelItem_EndOfStream. */
+RChannelSerializedWriter::SyncHandler::SyncHandler()
+{
+    m_event = NULL;
+    m_statusCode = RChannelItem_EndOfStream;
+    m_usingEvent = 0;
+}
+/* The event must already have been taken back with GetEvent() (which
+   resets m_event to NULL) before the handler is destroyed. */
+RChannelSerializedWriter::SyncHandler::~SyncHandler()
+{
+    LogAssert(m_event == NULL);
+}
+/* Attaches the event that Wait() blocks on. The asserts require this to
+   be the first increment of m_usingEvent, i.e. ProcessWriteArrayCompleted
+   must not have run on this handler yet. */
+void RChannelSerializedWriter::SyncHandler::
+    UseEvent(DryadHandleListEntry* event)
+{
+    LONG postIncrement = ::InterlockedIncrement(&m_usingEvent);
+    LogAssert(postIncrement == 1);
+    LogAssert(m_event == NULL);
+    m_event = event;
+}
+/* Completion callback: stores the status and failure array, then bumps
+   m_usingEvent. A post-increment value of 2 means UseEvent ran first, so
+   the attached event is signalled; otherwise the value must be 1 with no
+   event attached, and nothing is signalled. */
+void RChannelSerializedWriter::SyncHandler::
+    ProcessWriteArrayCompleted(RChannelItemType statusCode,
+                               RChannelItemArray* failureArray)
+{
+    m_statusCode = statusCode;
+    m_failureArray = failureArray;
+
+    LONG postIncrement = ::InterlockedIncrement(&m_usingEvent);
+    if (postIncrement == 2)
+    {
+        LogAssert(m_event != NULL);
+        BOOL bRet = ::SetEvent(m_event->GetHandle());
+        LogAssert(bRet != 0);
+    }
+    else
+    {
+        LogAssert(postIncrement == 1);
+        LogAssert(m_event == NULL);
+    }
+}
+/* True while an event is attached (after UseEvent, before GetEvent). */
+bool RChannelSerializedWriter::SyncHandler::UsingEvent()
+{
+    return (m_event != NULL);
+}
+/* Returns the status stored by ProcessWriteArrayCompleted, or the
+   constructor default RChannelItem_EndOfStream if it has not run. */
+RChannelItemType RChannelSerializedWriter::SyncHandler::GetStatusCode()
+{
+    return m_statusCode;
+}
+/* Copies the stored failure array reference to *pFailureArray; the
+   pointer is dereferenced unconditionally. */
+void RChannelSerializedWriter::SyncHandler::
+    GetFailureItemArray(RChannelItemArrayRef* pFailureArray)
+{
+    *pFailureArray = m_failureArray;
+}
+/* Blocks with no timeout until the attached event is signalled; an event
+   must have been attached with UseEvent. */
+void RChannelSerializedWriter::SyncHandler::Wait()
+{
+    LogAssert(m_event != NULL);
+    DWORD dRet = ::WaitForSingleObject(m_event->GetHandle(), INFINITE);
+    LogAssert(dRet == WAIT_OBJECT_0);
+}
+/* Detaches and returns the attached event, leaving m_event NULL. */
+DryadHandleListEntry* RChannelSerializedWriter::SyncHandler::GetEvent()
+{
+    LogAssert(m_event != NULL);
+    DryadHandleListEntry* retVal = m_event;
+    m_event = NULL;
+    return retVal;
+}
+
+/* Stores the collaborators, starts in CW_Stopped with zeroed counters and
+   creates the two events (manual-reset, initially unsignalled). The
+   marshaler's own GetMaxMarshalBatchSize() wins; the maxMarshalBatchSize
+   argument is only used when that returns 0. */
+RChannelSerializedWriter::
+    RChannelSerializedWriter(RChannelBufferWriter* writer,
+                             RChannelItemMarshalerBase* marshaler,
+                             bool breakBufferOnRecordBoundaries,
+                             UInt32 maxMarshalBatchSize,
+                             WorkQueue* workQueue)
+{
+    m_breakBufferOnRecordBoundaries = breakBufferOnRecordBoundaries;
+    m_writer = writer;
+    m_marshaler = marshaler;
+    m_workQueue = workQueue;
+
+    m_maxMarshalBatchSize = m_marshaler->GetMaxMarshalBatchSize();
+    if (m_maxMarshalBatchSize == 0)
+    {
+        m_maxMarshalBatchSize = maxMarshalBatchSize;
+    }
+
+    m_state = CW_Stopped;
+    m_outstandingBuffers = 0;
+    m_outstandingHandlers = 0;
+    m_marshaledTerminationItem = false;
+    m_channelTermination = RChannelItem_Data;
+    m_cachedWriter = NULL;
+
+    m_handlerReturnEvent = ::CreateEvent(NULL, TRUE, FALSE, NULL);
+    LogAssert(m_handlerReturnEvent != NULL);
+    m_marshaledLastItemEvent = ::CreateEvent(NULL, TRUE, FALSE, NULL);
+    LogAssert(m_marshaledLastItemEvent != NULL);
+}
+/* Requires the writer to be CW_Closed with no termination items and no
+   cached writer left, then closes both event handles. */
+RChannelSerializedWriter::~RChannelSerializedWriter()
+{
+    {
+        AutoCriticalSection acs(&m_baseDR);
+
+        DrLogD("Uninitializing RChannelSerializedWriter, uri = %s", m_uri != NULL ? m_uri : "");
+
+        LogAssert(m_state == CW_Closed);
+        LogAssert(m_writerTerminationItem == NULL);
+        LogAssert(m_readerTerminationItem == NULL);
+        LogAssert(m_cachedWriter == NULL);
+
+        BOOL bRet = ::CloseHandle(m_handlerReturnEvent);
+        LogAssert(bRet != 0);
+        m_handlerReturnEvent = INVALID_HANDLE_VALUE;
+
+        bRet = ::CloseHandle(m_marshaledLastItemEvent);
+        LogAssert(bRet != 0);
+        m_marshaledLastItemEvent = INVALID_HANDLE_VALUE;
+    }
+}
+
+/* Stores the channel URI in m_uri; no lock is taken. */
+void RChannelSerializedWriter::SetURI(const char* uri)
+{
+    m_uri = uri;
+}
+/* Returns the URI stored by SetURI. */
+const char* RChannelSerializedWriter::GetURI()
+{
+    return m_uri;
+}
+/* The size hint is not kept here; it is read from the buffer writer. */
+UInt64 RChannelSerializedWriter::GetInitialSizeHint()
+{
+    return m_writer->GetInitialSizeHint();
+}
+/* Forwards the size hint to the buffer writer. */
+void RChannelSerializedWriter::SetInitialSizeHint(UInt64 hint)
+{
+    m_writer->SetInitialSizeHint(hint);
+}
+/* Moves an idle CW_Stopped writer to CW_Empty: asserts that nothing is
+   queued or outstanding, clears the termination state and starts the
+   return latch under the lock, then resets the marshaler and starts the
+   buffer writer outside the lock. */
+void RChannelSerializedWriter::Start()
+{
+    {
+        AutoCriticalSection acs(&m_baseDR);
+
+        LogAssert(m_state == CW_Stopped);
+        LogAssert(m_pendingList.IsEmpty());
+        LogAssert(m_blockedHandlerList.IsEmpty());
+
LogAssert(m_bufferList.IsEmpty());
+        LogAssert(m_outstandingBuffers == 0);
+        LogAssert(m_outstandingHandlers == 0);
+        LogAssert(m_marshaledTerminationItem == false);
+        LogAssert(m_cachedWriter == NULL);
+
+        m_writerTerminationItem = NULL;
+        m_readerTerminationItem = NULL;
+        m_state = CW_Empty;
+        m_channelTermination = RChannelItem_Data;
+        m_returnLatch.Start();
+    }
+
+    m_marshaler->Reset();
+    m_writer->Start();
+}
+/* Copies both termination item references out under m_baseDR. Both
+   out-pointers are dereferenced unconditionally. */
+void RChannelSerializedWriter::
+    GetTerminationItems(RChannelItemRef* pWriterDrainItem,
+                        RChannelItemRef* pReaderDrainItem)
+{
+    {
+        AutoCriticalSection acs(&m_baseDR);
+
+        *pWriterDrainItem = m_writerTerminationItem;
+        *pReaderDrainItem = m_readerTerminationItem;
+    }
+}
+/* Closes the buffer writer first (outside the lock), then asserts the
+   writer is idle in CW_Stopped, drops both termination items and moves
+   to CW_Closed, the state the destructor requires. */
+void RChannelSerializedWriter::Close()
+{
+    m_writer->Close();
+
+    {
+        AutoCriticalSection acs(&m_baseDR);
+
+        LogAssert(m_state == CW_Stopped);
+        LogAssert(m_pendingList.IsEmpty());
+        LogAssert(m_blockedHandlerList.IsEmpty());
+        LogAssert(m_bufferList.IsEmpty());
+        LogAssert(m_outstandingBuffers == 0);
+        LogAssert(m_outstandingHandlers == 0);
+        LogAssert(m_marshaledTerminationItem == false);
+        LogAssert(m_cachedWriter == NULL);
+
+        m_writerTerminationItem = NULL;
+        m_readerTerminationItem = NULL;
+        m_state = CW_Closed;
+    }
+}
+/* Lazily creates m_cachedWriter on top of m_bufferList; does nothing when
+   a cached writer already exists. */
+void RChannelSerializedWriter::MakeCachedWriter()
+{
+    if (m_cachedWriter == NULL)
+    {
+        DrRef buffer; /* NOTE(review): DrRef's template argument looks stripped by extraction -- confirm against the original file */
+        buffer.Attach(new RChannelWriterBuffer(m_writer, &m_bufferList));
+        m_cachedWriter = new ChannelMemoryBufferWriter(buffer, &m_bufferList);
+    }
+}
+/* Closes and deletes the cached writer if there is one. Returns the value
+   GetLastRecordBoundary() reported just before closing, or 0 when no
+   writer was cached. */
+Size_t RChannelSerializedWriter::DisposeOfCachedWriter()
+{
+    Size_t preMarshalAvailableSize = 0;
+    if (m_cachedWriter != NULL)
+    {
+        preMarshalAvailableSize = m_cachedWriter->GetLastRecordBoundary();
+        DrError err = m_cachedWriter->CloseMemoryWriter();
+        LogAssert(err == DrError_OK);
+        delete m_cachedWriter;
+        m_cachedWriter = NULL;
+    }
+    return preMarshalAvailableSize;
+}
+
+/* called with m_baseDR held. Subtracts handlerCount from
+   m_outstandingHandlers and, once the writer is CW_Stopping and none
+   remain, signals m_handlerReturnEvent (the event Drain waits on). */
+void RChannelSerializedWriter::AcceptReturningHandlers(UInt32 handlerCount)
+{
+
LogAssert(m_outstandingHandlers >= handlerCount);
+    m_outstandingHandlers -= handlerCount;
+
+    if (m_state == CW_Stopping && m_outstandingHandlers == 0)
+    {
+        BOOL bRet = ::SetEvent(m_handlerReturnEvent);
+        LogAssert(bRet != 0);
+    }
+}
+/* Calls ProcessMarshalCompleted(returnCode) on every request in
+   completedList and deletes it, without holding m_baseDR. After each
+   drained batch the lock is taken to call m_returnLatch.TransferList
+   (presumably handing over requests queued on the latch meanwhile --
+   confirm in DryadOrderedSendLatch) and to account for the handlers just
+   returned. The loop ends once completedList stays empty. */
+void RChannelSerializedWriter::ReturnHandlers(WriteRequestList* completedList,
+                                              RChannelItemType returnCode)
+{
+    while (completedList->IsEmpty() == false)
+    {
+        UInt32 returnedHandlerCount = 0;
+
+        WriteRequest* returnRequest =
+            completedList->CastOut(completedList->RemoveHead());
+
+        do
+        {
+            ++returnedHandlerCount;
+            returnRequest->ProcessMarshalCompleted(returnCode);
+            delete returnRequest;
+            returnRequest =
+                completedList->CastOut(completedList->RemoveHead());
+        } while (returnRequest != NULL);
+
+        {
+            AutoCriticalSection acs(&m_baseDR);
+
+            m_returnLatch.TransferList(completedList);
+
+            AcceptReturningHandlers(returnedHandlerCount);
+        }
+    }
+}
+
+/* called with m_baseDR held. Returns true when a termination item has
+   already been recorded or the writer is CW_Stopped. Otherwise scans
+   itemArray; a termination item found there is recorded as
+   m_writerTerminationItem and the array is truncated just after it, and
+   false is returned. */
+bool RChannelSerializedWriter::
+    CheckForTerminationItem(RChannelItemArray* itemArray)
+{
+    if (m_writerTerminationItem != NULL || m_state == CW_Stopped)
+    {
+        return true;
+    }
+    else
+    {
+        UInt32 nItems = itemArray->GetNumberOfItems();
+        LogAssert(nItems > 0);
+        RChannelItemRef* items = itemArray->GetItemArray();
+
+        UInt32 i;
+        for (i=0; iGetType(); /* NOTE(review): loop header and the start of its body look lost in extraction (text between '<' and '>' missing) -- restore from the original file */
+            if (RChannelItem::IsTerminationItem(itemType))
+            {
+                m_writerTerminationItem = item;
+                itemArray->TruncateToSize(i+1);
+                nItems = i+1;
+            }
+        }
+
+        return false;
+    }
+}
+
+/* called **without** m_baseDR held but this is OK since we are in the
+   marshaling state which means no other thread is touching
+   m_bufferList.
+   Rolls m_bufferList back to its pre-marshal contents: walking from the
+   tail, every buffer except the head (or every buffer, when
+   preMarshalAvailableSize is 0) is returned to the buffer writer, and
+   the remaining head buffer's available size is reset to
+   preMarshalAvailableSize. */
+void RChannelSerializedWriter::
+    RestorePreMarshalBuffers(Size_t preMarshalAvailableSize)
+{
+    LogAssert(m_state == CW_Marshaling);
+    LogAssert(m_cachedWriter == NULL);
+
+    if (m_bufferList.IsEmpty())
+    {
+        LogAssert(preMarshalAvailableSize == 0);
+        return;
+    }
+
+    DrBListEntry* listEntry = m_bufferList.GetTail();
+    while (listEntry != NULL &&
+           (m_bufferList.GetPrev(listEntry) != NULL ||
+            preMarshalAvailableSize == 0))
+    {
+        DryadFixedMemoryBuffer* buffer = m_bufferList.CastOut(listEntry);
+        listEntry = m_bufferList.GetPrev(listEntry);
+        m_bufferList.Remove(m_bufferList.CastIn(buffer));
+        m_writer->ReturnUnusedBuffer(buffer);
+    }
+
+    if (preMarshalAvailableSize > 0)
+    {
+        LogAssert(m_bufferList.CountLinks() == 1);
+        DryadFixedMemoryBuffer* buffer =
+            m_bufferList.CastOut(m_bufferList.GetHead());
+        LogAssert(buffer->GetAllocatedSize() > preMarshalAvailableSize);
+        LogAssert(buffer->GetAvailableSize() >= preMarshalAvailableSize);
+        buffer->SetAvailableSize(preMarshalAvailableSize);
+    }
+}
+/* Disposes of the cached writer, optionally re-aligns the buffers on
+   record boundaries, then hands the buffers of m_bufferList to the
+   buffer writer, followed by WriteTermination when terminationType is
+   not RChannelItem_Data. m_outstandingBuffers is raised under the lock
+   before anything is sent. Returns true if a termination was written or
+   the last WriteBuffer call returned true. */
+bool RChannelSerializedWriter::
+    SendCompletedBuffers(bool shouldFlush, RChannelItemType terminationType)
+{
+    Size_t preMarshalAvailableSize = DisposeOfCachedWriter();
+
+    if (m_breakBufferOnRecordBoundaries)
+    {
+        ShuffleBuffersOnRecordBoundaries(preMarshalAvailableSize);
+    }
+
+    UInt32 bufferCount = m_bufferList.CountLinks();
+    DryadFixedBufferList sendList;
+    sendList.TransitionToTail(&m_bufferList);
+
+//    DrLogE( "SendCompletedBuffers",
+//        "count %u", bufferCount);
+
+    if (shouldFlush == false && terminationType == RChannelItem_Data)
+    {
+        /* we're only here because we filled a buffer, and we may not
+           want to send the partial one at the end of the list */
+        DryadFixedMemoryBuffer* lastBuffer =
+            sendList.CastOut(sendList.GetTail());
+        if (lastBuffer->GetAvailableSize() != lastBuffer->GetAllocatedSize())
+        {
+            m_bufferList.TransitionToTail(m_bufferList.CastIn(lastBuffer));
+            --bufferCount;
+        }
+    }
+
+    {
+        AutoCriticalSection acs(&m_baseDR);
+
+        if (terminationType != RChannelItem_Data)
+        {
+            LogAssert(m_writerTerminationItem != NULL);
+            /* the additional "buffer" accounts for the call to
+               writer->WriteTermination below */
+            ++bufferCount;
+        }
+
+        m_outstandingBuffers += bufferCount;
+    }
+
+    bool shouldBlock = false;
+
+    DryadFixedMemoryBuffer* nextBuffer =
+        sendList.CastOut(sendList.RemoveHead());
+    while (nextBuffer != NULL)
+    {
+        shouldBlock = /* overwritten on every pass: only the last buffer's result is kept */
+            m_writer->WriteBuffer(nextBuffer, shouldFlush, this);
+        nextBuffer = sendList.CastOut(sendList.RemoveHead());
+    }
+
+    if (terminationType != RChannelItem_Data)
+    {
+        shouldBlock = true;
+        m_writer->WriteTermination(terminationType, this);
+    }
+
+    return shouldBlock;
+}
+/* Marshals the request's next item into the cached writer. Returns
+   RChannelItem_Data unless the item itself is a termination item, or the
+   marshaler reported DryadError_ChannelAbort / DryadError_ChannelRestart,
+   in which case the matching termination type is returned.
+   DrError_IncompleteOperation leaves the request on the same item. Any
+   other error rolls the buffers back to the last record boundary and
+   marshals the failure item (plus a termination marker when aborting)
+   in place of the failed item. */
+RChannelItemType RChannelSerializedWriter::
+    PerformSingleMarshal(WriteRequest* writeRequest)
+{
+    LogAssert(m_cachedWriter != NULL);
+
+    RChannelItem* item = writeRequest->GetNextItem();
+    RChannelItemType itemType = item->GetType();
+    if (RChannelItem::IsTerminationItem(itemType) == false)
+    {
+        itemType = RChannelItem_Data;
+    }
+
+    bool shouldFlush = (writeRequest->ShouldFlush() &&
+                        writeRequest->LastItem());
+
+    RChannelItemRef marshalFailure;
+    DrError marshalStatus =
+        m_marshaler->MarshalItem(m_cachedWriter, item, shouldFlush,
+                                 &marshalFailure);
+    DrError errTmp = m_cachedWriter->FlushMemoryWriter();
+    LogAssert(errTmp == DrError_OK);
+
+    if (marshalStatus == DrError_IncompleteOperation)
+    {
+        /* do nothing, we will call MarshalItem again on the same item
+           next time around */
+        LogAssert(itemType == RChannelItem_Data);
+    }
+    else if (marshalStatus == DrError_OK)
+    {
+        writeRequest->SetSuccessItem();
+    }
+    else
+    {
+        if (marshalFailure == NULL)
+        {
+            marshalFailure.Attach(RChannelMarkerItem::
+                                  Create(RChannelItem_MarshalError, false));
+        }
+
+        Size_t preMarshalAvailableSize = DisposeOfCachedWriter();
+
+        RestorePreMarshalBuffers(preMarshalAvailableSize);
+
+        MakeCachedWriter();
+
+        RChannelItemRef secondFailure;
+        m_marshaler->MarshalItem(m_cachedWriter, marshalFailure, shouldFlush, /* result and secondFailure are not checked */
+                                 &secondFailure);
+        errTmp = m_cachedWriter->FlushMemoryWriter();
+        LogAssert(errTmp == DrError_OK);
+
+        bool aborted = (marshalStatus == DryadError_ChannelAbort ||
+                        marshalStatus == DryadError_ChannelRestart);
+
+        writeRequest->SetFailureItem(marshalFailure, aborted);
+
+        if (aborted)
+        {
+            itemType = (marshalStatus == DryadError_ChannelAbort) ?
+                (RChannelItem_Abort) : (RChannelItem_Restart);
+            RChannelItemRef terminationItem;
+            terminationItem.
+                Attach(RChannelMarkerItem::
+                       CreateErrorItemWithDescription(itemType,
+                                                      marshalStatus,
+                                                      "Marshal failure caused "
+                                                      "termination"));
+            m_marshaler->MarshalItem(m_cachedWriter,
+                                     terminationItem, shouldFlush,
+                                     &secondFailure);
+            errTmp = m_cachedWriter->FlushMemoryWriter();
+            LogAssert(errTmp == DrError_OK);
+        }
+    }
+
+    return itemType;
+}
+/* Replaces every buffer in m_bufferList with a single custom buffer
+   holding their available bytes concatenated in list order. The old
+   buffers are returned to the buffer writer. Requires at least one byte
+   of data in total. */
+void RChannelSerializedWriter::CollapseToSingleBuffer()
+{
+    Size_t totalSize = 0;
+    DryadFixedMemoryBuffer* buffer =
+        m_bufferList.CastOut(m_bufferList.GetHead());
+    while (buffer != NULL)
+    {
+        totalSize += buffer->GetAvailableSize();
+        buffer = m_bufferList.GetNextTyped(buffer);
+    }
+
+    LogAssert(totalSize > 0);
+
+    DryadFixedMemoryBuffer* combinedBuffer =
+        m_writer->GetCustomWriteBuffer(totalSize);
+
+    totalSize = 0; /* reused from here on as the running write offset */
+    buffer = m_bufferList.CastOut(m_bufferList.RemoveHead());
+    while (buffer != NULL)
+    {
+        Size_t contiguous;
+
+        void* dstPtr =
+            combinedBuffer->GetWriteAddress(totalSize,
+                                            buffer->GetAvailableSize(),
+                                            &contiguous);
+        LogAssert(contiguous >= buffer->GetAvailableSize());
+
+        const void* srcPtr =
+            buffer->GetReadAddress(0, &contiguous);
+        LogAssert(contiguous >= buffer->GetAvailableSize());
+
+        ::memcpy(dstPtr, srcPtr, buffer->GetAvailableSize());
+
+        totalSize += buffer->GetAvailableSize();
+
+        m_writer->ReturnUnusedBuffer(buffer);
+
+        buffer = m_bufferList.CastOut(m_bufferList.RemoveHead());
+    }
+
+    LogAssert(totalSize == combinedBuffer->GetAllocatedSize());
+    combinedBuffer->SetAvailableSize(totalSize);
+
+    m_bufferList.InsertAsTail(m_bufferList.CastIn(combinedBuffer));
+}
+/* Called from SendCompletedBuffers when m_breakBufferOnRecordBoundaries
+   is set. preMarshalAvailableSize is the last record boundary reported
+   by the cached writer. Moves the bytes of the head buffer that lie past
+   that boundary (the "overhang") into the next buffer, or collapses the
+   whole list into one buffer when that is not possible. */
+void RChannelSerializedWriter::
+    ShuffleBuffersOnRecordBoundaries(Size_t preMarshalAvailableSize)
+{
+    UInt32 bufferCount = m_bufferList.CountLinks();
+
+    if (bufferCount < 2)
+    {
+        /* there is nothing to write or the record boundary coincides
+           with the buffer boundary, so we don't need to move anything
+           around */
+
DrLogD( "ShuffleBuffers taking no action");
+        return;
+    }
+
+    DryadFixedMemoryBuffer* headBuffer =
+        m_bufferList.CastOut(m_bufferList.GetHead());
+    LogAssert(headBuffer != NULL);
+
+    LogAssert(headBuffer->GetAvailableSize() ==
+              headBuffer->GetAllocatedSize());
+    LogAssert(preMarshalAvailableSize < headBuffer->GetAvailableSize());
+    Size_t overhang = headBuffer->GetAvailableSize() - preMarshalAvailableSize; /* bytes in the head buffer past the record boundary */
+
+    DryadFixedMemoryBuffer* nextBuffer =
+        m_bufferList.GetNextTyped(headBuffer);
+    LogAssert(nextBuffer != NULL);
+
+    Size_t remainingSpace =
+        nextBuffer->GetAllocatedSize() - nextBuffer->GetAvailableSize();
+
+    if (preMarshalAvailableSize == 0 || bufferCount > 2 ||
+        remainingSpace < overhang)
+    {
+        /* there isn't space to just shift the overhang into the next
+           buffer (the common case) so collapse everything written up
+           to now into a single buffer to be written as a unit */
+        CollapseToSingleBuffer();
+    }
+    else
+    {
+        /* shift up the overhang into the next buffer, copying the
+           data that is already there out of the way first */
+
+        Size_t copySize = nextBuffer->GetAvailableSize();
+        Size_t availableSize;
+        const void* srcPtr = nextBuffer->GetReadAddress(0, &availableSize);
+        LogAssert(availableSize >= copySize);
+        void* dstPtr = nextBuffer->GetWriteAddress(overhang,
+                                                   copySize,
+                                                   &availableSize);
+        LogAssert(availableSize >= copySize);
+
+        DrLogD( "ShuffleBuffers shifting data. Overhang size %Iu/%Iu copy size %Iu/%Iu",
+            overhang, headBuffer->GetAvailableSize(),
+            copySize, nextBuffer->GetAllocatedSize());
+
+        ::memmove(dstPtr, srcPtr, copySize); /* memmove: source and destination are both inside nextBuffer */
+
+        dstPtr = nextBuffer->GetWriteAddress(0, overhang, &availableSize);
+        LogAssert(availableSize >= overhang);
+        srcPtr = headBuffer->GetReadAddress(preMarshalAvailableSize,
+                                            &availableSize);
+        LogAssert(availableSize >= overhang);
+        ::memcpy(dstPtr, srcPtr, overhang);
+
+        headBuffer->SetAvailableSize(preMarshalAvailableSize);
+        nextBuffer->SetAvailableSize(copySize + overhang);
+    }
+}
+/* Marshals items starting with the head request of pendingRequestList.
+   It stops when a buffer fills, a completed request asked for a flush, a
+   termination item has been marshaled, the pending list is empty, or
+   m_maxMarshalBatchSize items have been processed. Completed requests
+   are moved to completedRequestList. If it stopped for one of the first
+   three reasons the buffers are sent with SendCompletedBuffers. Returns
+   true when the caller should block: SendCompletedBuffers returned true,
+   a flush was requested, or a termination item was marshaled. */
+bool RChannelSerializedWriter::
+    PerformMarshal(WriteRequestList* pendingRequestList,
+                   WriteRequestList* completedRequestList)
+{
+    LogAssert(pendingRequestList->IsEmpty() == false);
+
+    MakeCachedWriter();
+
+    UInt32 marshaledItemCount = 0;
+
+    bool filledBuffer = m_cachedWriter->MarkRecordBoundary();
+    LogAssert(filledBuffer == false);
+    bool shouldFlush = false;
+    RChannelItemType terminationType = RChannelItem_Data;
+
+    do
+    {
+        WriteRequest* writeRequest =
+            pendingRequestList->CastOut(pendingRequestList->GetHead());
+
+        LogAssert(writeRequest->Completed() == false);
+
+        do
+        {
+            ++marshaledItemCount;
+
+            terminationType = PerformSingleMarshal(writeRequest);
+
+            filledBuffer = m_cachedWriter->MarkRecordBoundary();
+
+            if (terminationType != RChannelItem_Data)
+            {
+                LogAssert(writeRequest->Completed());
+            }
+        } while (filledBuffer == false &&
+                 writeRequest->Completed() == false &&
+                 marshaledItemCount < m_maxMarshalBatchSize);
+
+        if (writeRequest->Completed())
+        {
+            completedRequestList->
+                TransitionToTail(completedRequestList->CastIn(writeRequest));
+            shouldFlush = writeRequest->ShouldFlush();
+        }
+    } while (filledBuffer == false &&
+             shouldFlush == false &&
+             terminationType == RChannelItem_Data &&
+             pendingRequestList->IsEmpty() == false &&
+             marshaledItemCount < m_maxMarshalBatchSize);
+
+    bool shouldBlock = false;
+
+    if (filledBuffer || shouldFlush || terminationType != RChannelItem_Data)
+    {
+
shouldBlock =
+            SendCompletedBuffers(shouldFlush, terminationType) ||
+            shouldFlush || terminationType != RChannelItem_Data;
+    }
+
+    return shouldBlock;
+}
+/* Runs after a marshal pass; the state must be CW_Marshaling. Under the
+   lock it merges requests that arrived during the pass back into
+   m_pendingList (unmarshaled leftovers first) and picks the next state:
+     - CW_Blocking when asked to block and buffers are still outstanding;
+       the completed requests are parked on m_blockedHandlerList;
+     - CW_Stopping or CW_Empty when nothing is pending, depending on
+       whether a termination item has been recorded;
+     - CW_InWorkQueue, with a new marshal work item, otherwise.
+   When syncHandler is non-NULL and no request is left in
+   completedRequestList after the latch accepted it, an event is attached
+   so the synchronous caller can wait. Handlers are returned and the work
+   item is queued outside the lock. */
+void RChannelSerializedWriter::
+    PostMarshal(WriteRequestList* pendingRequestList,
+                WriteRequestList* completedRequestList,
+                bool shouldBlock,
+                SyncHandler* syncHandler)
+{
+    bool marshaledLast = false;
+    WorkRequest* nextRequest = NULL;
+    RChannelItemType returnCode;
+
+    {
+        AutoCriticalSection acs(&m_baseDR);
+
+        LogAssert(m_state == CW_Marshaling);
+
+        /* stick any newly arrived pending requests on the end of our
+           list */
+        pendingRequestList->TransitionToTail(&m_pendingList);
+        /* then move the whole thing back where it belongs */
+        m_pendingList.TransitionToTail(pendingRequestList);
+
+        /* block when requested unless the outstanding buffers got
+           returned on another thread between the call to
+           SendCompletedBuffers and here */
+        if (shouldBlock && m_outstandingBuffers > 0)
+        {
+            m_blockedHandlerList.TransitionToTail(completedRequestList);
+            m_state = CW_Blocking;
+        }
+        else if (m_pendingList.IsEmpty())
+        {
+            if (m_writerTerminationItem != NULL)
+            {
+                LogAssert(m_outstandingHandlers > 0);
+                m_state = CW_Stopping;
+            }
+            else
+            {
+                m_state = CW_Empty;
+            }
+        }
+        else
+        {
+            m_state = CW_InWorkQueue;
+            nextRequest = new RChannelMarshalRequest(this);
+        }
+
+        returnCode = m_channelTermination;
+
+        if (m_state != CW_InWorkQueue &&
+            m_writerTerminationItem != NULL &&
+            m_pendingList.IsEmpty())
+        {
+            LogAssert(m_marshaledTerminationItem == false);
+            m_marshaledTerminationItem = true;
+            marshaledLast = true;
+        }
+
+        m_returnLatch.AcceptList(completedRequestList);
+
+        if (completedRequestList->IsEmpty() && syncHandler != NULL)
+        {
+            syncHandler->UseEvent(m_eventCache.GetEvent(true));
+        }
+    }
+
+    if (marshaledLast)
+    {
+        DrLogD("Writer marshalled last item.");
+        BOOL bRet = ::SetEvent(m_marshaledLastItemEvent); /* the event Drain waits on before draining the buffer writer */
+        LogAssert(bRet != 0, "SetEvent() failed with error code = %d", GetLastError());
+    }
+
+
ReturnHandlers(completedRequestList, returnCode);
+
+    if (nextRequest != NULL)
+    {
+        m_workQueue->EnQueue(nextRequest);
+    }
+}
+/* Performs one marshal pass over everything currently pending. The state
+   must be CW_InWorkQueue on entry, so this is presumably what the queued
+   RChannelMarshalRequest runs (confirm in that class). PostMarshal is
+   passed a NULL syncHandler. */
+void RChannelSerializedWriter::MarshalItems()
+{
+    WriteRequestList pendingRequestList;
+    WriteRequestList completedRequestList;
+
+    {
+        AutoCriticalSection acs(&m_baseDR);
+
+        LogAssert(m_state == CW_InWorkQueue);
+
+        m_state = CW_Marshaling;
+
+        pendingRequestList.TransitionToTail(&m_pendingList);
+
+        LogAssert(pendingRequestList.IsEmpty() == false);
+    }
+
+    bool shouldBlock = PerformMarshal(&pendingRequestList,
+                                      &completedRequestList);
+
+    PostMarshal(&pendingRequestList, &completedRequestList,
+                shouldBlock, NULL);
+
+    LogAssert(pendingRequestList.IsEmpty());
+    LogAssert(completedRequestList.IsEmpty());
+}
+/* RChannelBufferWriterHandler callback: one buffer (or the termination
+   write) sent by SendCompletedBuffers has completed. Decrements
+   m_outstandingBuffers and leaves CW_Blocking when either
+     - at most 2 buffers remain and no termination item is recorded, or
+     - no buffers remain at all.
+   On unblocking, the parked handlers are released and, if requests are
+   pending, a new marshal work item is queued. A termination status other
+   than RChannelItem_EndOfStream is remembered in m_channelTermination
+   and becomes the code returned to handlers. */
+void RChannelSerializedWriter::ProcessWriteCompleted(RChannelItemType status)
+{
+    WriteRequestList unblockList;
+    WorkRequest* nextRequest = NULL;
+    RChannelItemType returnCode;
+
+    {
+        AutoCriticalSection acs(&m_baseDR);
+
+        LogAssert(m_outstandingBuffers > 0);
+        --m_outstandingBuffers;
+
+        if (m_outstandingBuffers <= 2 &&
+            m_state == CW_Blocking &&
+            m_writerTerminationItem == NULL)
+        {
+            if (m_pendingList.IsEmpty())
+            {
+                m_state = CW_Empty;
+            }
+            else
+            {
+                m_state = CW_InWorkQueue;
+                nextRequest = new RChannelMarshalRequest(this);
+            }
+
+            unblockList.TransitionToHead(&m_blockedHandlerList);
+        }
+
+        if (m_outstandingBuffers == 0)
+        {
+            if (m_state == CW_Blocking)
+            {
+                if (m_pendingList.IsEmpty())
+                {
+                    if (m_writerTerminationItem != NULL)
+                    {
+                        m_state = CW_Stopping;
+                        /* if we are unblocking there must be at least
+                           one handler queued to be released after the
+                           block */
+                        LogAssert(m_blockedHandlerList.IsEmpty() == false);
+                        LogAssert(m_outstandingHandlers > 0);
+                    }
+                    else
+                    {
+                        m_state = CW_Empty;
+                    }
+                }
+                else
+                {
+                    m_state = CW_InWorkQueue;
+                    nextRequest = new RChannelMarshalRequest(this);
+                }
+
+                unblockList.TransitionToHead(&m_blockedHandlerList);
+            }
+            else
+            {
+                LogAssert(m_state == CW_InWorkQueue ||
+                          m_state == CW_Marshaling ||
+                          m_state == CW_Empty);
+            }
+        }
+
+        if (RChannelItem::IsTerminationItem(status))
+        {
+            if (status != RChannelItem_EndOfStream)
+            {
+                m_channelTermination = status;
+            }
+        }
+
+        returnCode = m_channelTermination;
+
+        m_returnLatch.AcceptList(&unblockList);
+    }
+
+    ReturnHandlers(&unblockList, returnCode);
+
+    if (nextRequest != NULL)
+    {
+        m_workQueue->EnQueue(nextRequest);
+    }
+}
+/* Asynchronous write. Takes the array reference away from the caller and
+   queues a WriteRequest on m_pendingList; a marshal work item is queued
+   only when the writer was CW_Empty. If the channel has already
+   terminated or is stopped, nothing is queued and the handler is called
+   back at once, on the calling thread, with RChannelItem_EndOfStream and
+   a NULL failure array. handler must not be NULL and the array must not
+   be empty. */
+void RChannelSerializedWriter::
+    WriteItemArray(RChannelItemArrayRef& itemArrayIn,
+                   bool flushAfter,
+                   RChannelItemArrayWriterHandler* handler)
+{
+    RChannelItemArrayRef itemArray;
+    /* ensure that the caller no longer holds a reference to this array */
+    itemArray.TransferFrom(itemArrayIn);
+
+    LogAssert(itemArray->GetNumberOfItems() > 0);
+    LogAssert(handler != NULL);
+
+    WorkRequest* workRequest = NULL;
+
+    bool alreadyTerminated;
+
+    {
+        AutoCriticalSection acs(&m_baseDR);
+
+        alreadyTerminated = CheckForTerminationItem(itemArray);
+
+        if (alreadyTerminated)
+        {
+            LogAssert(m_state == CW_InWorkQueue ||
+                      m_state == CW_Marshaling ||
+                      m_state == CW_Blocking ||
+                      m_state == CW_Stopping ||
+                      m_state == CW_Stopped);
+        }
+        else
+        {
+            if (m_state == CW_Empty)
+            {
+                LogAssert(m_pendingList.IsEmpty());
+                m_state = CW_InWorkQueue;
+                workRequest = new RChannelMarshalRequest(this);
+            }
+            else
+            {
+                LogAssert(m_state == CW_InWorkQueue ||
+                          m_state == CW_Marshaling ||
+                          m_state == CW_Blocking);
+            }
+
+            WriteRequest* writeRequest =
+                new WriteRequest(itemArray, flushAfter, handler);
+            m_pendingList.InsertAsTail(m_pendingList.CastIn(writeRequest));
+
+            ++m_outstandingHandlers;
+        }
+    }
+
+    if (alreadyTerminated)
+    {
+        LogAssert(workRequest == NULL);
+        handler->ProcessWriteArrayCompleted(RChannelItem_EndOfStream, NULL);
+    }
+    else
+    {
+        if (workRequest != NULL)
+        {
+            m_workQueue->EnQueue(workRequest);
+        }
+    }
+}
+/* Synchronous write: returns once the items have been processed and the
+   request's handler has been released. Returns RChannelItem_EndOfStream
+   (with *pFailureArray set to NULL when given) if the channel had
+   already terminated; otherwise returns the status delivered to the
+   request's SyncHandler. When the writer is CW_Empty the marshaling is
+   done on the calling thread; otherwise the request is queued and the
+   caller waits on an event. pFailureArray may be NULL. */
+RChannelItemType RChannelSerializedWriter::
+    WriteItemArraySync(RChannelItemArrayRef& itemArrayIn,
+                       bool flush,
+                       RChannelItemArrayRef* pFailureArray)
+{
+
RChannelItemArrayRef itemArray;
+    /* ensure that the caller no longer holds a reference to this array */
+    itemArray.TransferFrom(itemArrayIn);
+
+    LogAssert(itemArray->GetNumberOfItems() > 0);
+
+    bool alreadyTerminated;
+    bool shouldMarshal = false;
+
+    SyncHandler* handler = NULL;
+    WriteRequestList pendingRequestList;
+
+    {
+        AutoCriticalSection acs(&m_baseDR);
+
+        alreadyTerminated = CheckForTerminationItem(itemArray);
+
+        if (alreadyTerminated)
+        {
+            LogAssert(m_state == CW_InWorkQueue ||
+                      m_state == CW_Marshaling ||
+                      m_state == CW_Blocking ||
+                      m_state == CW_Stopping ||
+                      m_state == CW_Stopped);
+        }
+        else
+        {
+            handler = new SyncHandler();
+            WriteRequest* writeRequest =
+                new WriteRequest(itemArray, flush, handler);
+
+            if (m_state == CW_Empty)
+            {
+                m_state = CW_Marshaling; /* marshal on this thread; the work queue is bypassed */
+
+                pendingRequestList.InsertAsTail(pendingRequestList.
+                                                CastIn(writeRequest));
+                shouldMarshal = true;
+            }
+            else
+            {
+                LogAssert(m_state == CW_InWorkQueue ||
+                          m_state == CW_Marshaling ||
+                          m_state == CW_Blocking);
+
+                m_pendingList.InsertAsTail(m_pendingList.CastIn(writeRequest));
+                writeRequest = NULL;
+                handler->UseEvent(m_eventCache.GetEvent(true)); /* attached under the lock, before another thread can take the request from m_pendingList */
+            }
+
+            ++m_outstandingHandlers;
+        }
+    }
+
+    if (alreadyTerminated)
+    {
+        LogAssert(handler == NULL);
+        LogAssert(pendingRequestList.IsEmpty());
+        if (pFailureArray != NULL)
+        {
+            *pFailureArray = NULL;
+        }
+        return RChannelItem_EndOfStream;
+    }
+
+    if (shouldMarshal)
+    {
+        LogAssert(pendingRequestList.IsEmpty() == false);
+        WriteRequestList completedRequestList;
+        bool shouldBlock = false;
+        do
+        {
+            /* keep marshaling until we have done everything in this
+               request */
+            shouldBlock = PerformMarshal(&pendingRequestList,
+                                         &completedRequestList) ||
+                shouldBlock;
+        } while (pendingRequestList.IsEmpty() == false);
+
+        PostMarshal(&pendingRequestList, &completedRequestList,
+                    shouldBlock, handler);
+
+        LogAssert(pendingRequestList.IsEmpty());
+        LogAssert(completedRequestList.IsEmpty());
+    }
+
+    if (handler->UsingEvent())
+    {
+//        DrLogE( "Waiting for handler");
+        handler->Wait();
+//        DrLogE( "Waited for handler");
+
+        {
+            AutoCriticalSection acs(&m_baseDR);
+
+            m_eventCache.ReturnEvent(handler->GetEvent());
+        }
+    }
+
+    RChannelItemType returnCode = handler->GetStatusCode();
+    if (pFailureArray != NULL)
+    {
+        handler->GetFailureItemArray(pFailureArray);
+    }
+    delete handler;
+
+    LogAssert(returnCode != RChannelItem_EndOfStream);
+
+    return returnCode;
+}
+/* Completes shutdown after a termination item has been written; asserts
+   that one has been recorded. Waits, as needed, for the termination item
+   to be marshaled, then drains the buffer writer, then waits for every
+   outstanding handler to return. The writer ends in CW_Stopped, with the
+   buffer writer's drain item stored as the reader termination item and
+   copied to *pRemoteStatus when that is non-NULL.
+   csTimeOut is never read: both waits use INFINITE. */
+void RChannelSerializedWriter::Drain(DrTimeInterval csTimeOut,
+                                     RChannelItemRef* pRemoteStatus)
+{
+    bool mustWaitForHandler = false;
+    bool mustWaitForMarshal = false;
+
+    {
+        AutoCriticalSection acs(&m_baseDR);
+
+        LogAssert(m_writerTerminationItem != NULL);
+
+        if (m_outstandingHandlers > 0 || m_state != CW_Stopping)
+        {
+            BOOL bRet = ::ResetEvent(m_handlerReturnEvent);
+            LogAssert(bRet != 0);
+            mustWaitForHandler = true;
+        }
+
+        if (m_marshaledTerminationItem == false)
+        {
+            BOOL bRet = ::ResetEvent(m_marshaledLastItemEvent);
+            LogAssert(bRet != 0);
+            mustWaitForMarshal = true;
+        }
+    }
+
+    if (mustWaitForMarshal)
+    {
+        DWORD dRet = ::WaitForSingleObject(m_marshaledLastItemEvent, INFINITE);
+        LogAssert(dRet == WAIT_OBJECT_0);
+    }
+
+    RChannelItemRef returnItem;
+    m_writer->Drain(&returnItem);
+    LogAssert(returnItem != NULL);
+
+    if (mustWaitForHandler)
+    {
+        DWORD dRet = ::WaitForSingleObject(m_handlerReturnEvent, INFINITE);
+        LogAssert(dRet == WAIT_OBJECT_0);
+    }
+
+    {
+        AutoCriticalSection acs(&m_baseDR);
+
+        LogAssert(m_state == CW_Stopping);
+
+        LogAssert(m_pendingList.IsEmpty());
+        LogAssert(m_blockedHandlerList.IsEmpty());
+        LogAssert(m_bufferList.IsEmpty());
+        LogAssert(m_outstandingBuffers == 0);
+        LogAssert(m_outstandingHandlers == 0);
+        LogAssert(m_writerTerminationItem != NULL);
+        LogAssert(m_marshaledTerminationItem == true);
+        LogAssert(m_readerTerminationItem == NULL);
+        LogAssert(m_cachedWriter == NULL);
+
+        m_readerTerminationItem = returnItem;
+
+        m_returnLatch.Stop();
+        m_marshaledTerminationItem = false;
+        m_state = CW_Stopped;
+    }
+
+    if (pRemoteStatus != NULL)
+    {
+        *pRemoteStatus = returnItem;
+    }
+}
+/* Out-of-line empty destructor for the handler interface. */
+RChannelBufferWriterHandler::~RChannelBufferWriterHandler()
+{
+}
+/* Out-of-line empty destructor for the buffer writer interface. */
+RChannelBufferWriter::~RChannelBufferWriter()
+{
+}
+/* Default implementation: leaves *status untouched. */
+void RChannelBufferWriter::FillInStatus(DryadChannelDescription* status)
+{
+}
+
+/* A writer that discards everything written to it: it only remembers the
+   last termination item it has seen. Starts out not started. */
+RChannelNullWriter::RChannelNullWriter(const char* uri)
+{
+    m_uri = uri;
+    m_started = false;
+}
+/* Returns the URI given to the constructor. */
+const char* RChannelNullWriter::GetURI()
+{
+    return m_uri;
+}
+/* The null writer keeps no size hint and always reports 0. */
+UInt64 RChannelNullWriter::GetInitialSizeHint()
+{
+    return 0;
+}
+/* The hint is ignored. */
+void RChannelNullWriter::SetInitialSizeHint(UInt64 /*hint*/)
+{
+}
+/* Clears any remembered termination item and marks the writer started. */
+void RChannelNullWriter::Start()
+{
+    {
+        AutoCriticalSection acs(&m_baseDR);
+
+        m_writeTerminationItem = NULL;
+        m_started = true;
+    }
+}
+/* Runs the synchronous write, then calls the handler back on the calling
+   thread with its status and a NULL failure array. handler is not
+   checked for NULL. */
+void RChannelNullWriter::
+    WriteItemArray(RChannelItemArrayRef& itemArray,
+                   bool flushAfter,
+                   RChannelItemArrayWriterHandler* handler)
+{
+    RChannelItemType status = WriteItemArraySync(itemArray, flushAfter, NULL);
+    handler->ProcessWriteArrayCompleted(status, NULL);
+}
+/* Discards the items, remembering a termination item if the array holds
+   one. Returns RChannelItem_Data while the writer is started and no
+   termination item has been seen, RChannelItem_EndOfStream otherwise.
+   The caller's array reference is cleared. flush and pFailureArray are
+   never read. */
+RChannelItemType RChannelNullWriter::
+    WriteItemArraySync(RChannelItemArrayRef& itemArray,
+                       bool flush,
+                       RChannelItemArrayRef* pFailureArray)
+{
+    RChannelItemType status;
+
+    {
+        AutoCriticalSection acs(&m_baseDR);
+
+        if (m_started)
+        {
+            UInt32 numberOfItems = itemArray->GetNumberOfItems();
+            RChannelItemRef* array = itemArray->GetItemArray();
+            UInt32 i;
+            for (i=0; iGetType())) /* NOTE(review): loop header and the start of its body look lost in extraction (text between '<' and '>' missing) -- restore from the original file */
+                {
+                    m_writeTerminationItem = item;
+                }
+            }
+
+            if (m_writeTerminationItem == NULL)
+            {
+                status = RChannelItem_Data;
+            }
+            else
+            {
+                status = RChannelItem_EndOfStream;
+            }
+        }
+        else
+        {
+            status = RChannelItem_EndOfStream;
+        }
+    }
+
+    itemArray = NULL;
+
+    return status;
+}
+/* Requires a started writer that has seen a termination item. Hands that
+   item back as the remote status and marks the writer stopped. timeOut
+   is never read. */
+void RChannelNullWriter::Drain(DrTimeInterval timeOut,
+                               RChannelItemRef* pRemoteStatus)
+{
+    {
+        AutoCriticalSection acs(&m_baseDR);
+
+        LogAssert(m_started);
+        LogAssert(m_writeTerminationItem != NULL);
+
+        *pRemoteStatus = m_writeTerminationItem; /* NOTE(review): dereferenced without the NULL check RChannelSerializedWriter::Drain makes */
+        m_started = false;
+    }
+}
+/* Both out-parameters receive the same remembered termination item. */
+void
RChannelNullWriter::GetTerminationItems(RChannelItemRef* pWriterDrainItem,
+                                             RChannelItemRef* pReaderDrainItem)
+{
+    {
+        AutoCriticalSection acs(&m_baseDR);
+
+        *pWriterDrainItem = m_writeTerminationItem;
+        *pReaderDrainItem = m_writeTerminationItem;
+    }
+}
+/* Only checks that the writer is not started, i.e. Start was never
+   called or Drain has run since. */
+void RChannelNullWriter::Close()
+{
+    {
+        AutoCriticalSection acs(&m_baseDR);
+
+        LogAssert(m_started == false);
+    }
+}
diff --git a/DryadVertex/VertexHost/system/channel/src/channelwriter.h b/DryadVertex/VertexHost/system/channel/src/channelwriter.h
new file mode 100644
index 0000000..571042c
--- /dev/null
+++ b/DryadVertex/VertexHost/system/channel/src/channelwriter.h
@@ -0,0 +1,318 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.  You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+#pragma once
+
+#pragma once /* NOTE(review): duplicate of the #pragma once two lines above */
+#pragma warning(disable:4512)
+
+#include /* NOTE(review): the five include targets look stripped by extraction -- restore from the original header */
+#include
+#include
+#include
+#include
+
+class RChannelMarshalRequest;
+class RChannelBufferWriter;
+class WorkQueue;
+/* Callback interface through which a buffer writer reports that a write
+   handed to it (a buffer or a termination) has completed. */
+class RChannelBufferWriterHandler
+{
+public:
+    virtual ~RChannelBufferWriterHandler();
+
+    virtual void ProcessWriteCompleted(RChannelItemType status) = 0;
+};
+
+/*
+    The RChannelWriter is the primary mechanism for writing
+    application-specific structured items to an underlying byte-oriented
+    channel.
+
+    The application first calls Start, then sends a series of items,
+    primarily of type RChannelItem_Data, though markers may be
+    interspersed, followed by a termination item of type
+    RChannelItem_Restart, RChannelItem_Abort or
+    RChannelItem_EndOfStream. In the case of a pipe which has not
+    broken, the termination item will be sent to the remote end so that
+    the consuming process learns the reason for the pipe closure.
+
+    After sending a termination item no further items will be sent on
+    the channel until a call to Drain has completed. After this point
+    the application may call Start again in the case of a restart, and
+    start sending items again.
+
+    Each item is marshaled into the byte-stream using an
+    application-specific marshaler object which must co-operate with the
+    application so that bare RChannelItem objects can be cast by it into
+    objects containing meaningful data.
+ */
+class RChannelSerializedWriter :
+    public RChannelWriter,
+    public RChannelBufferWriterHandler
+{
+public:
+    RChannelSerializedWriter(RChannelBufferWriter* writer,
+                             RChannelItemMarshalerBase* marshaler,
+                             bool breakBufferOnRecordBoundaries,
+                             UInt32 maxMarshalBatchSize,
+                             WorkQueue* workQueue);
+    ~RChannelSerializedWriter();
+
+    /* a client must call Start before writing items to the channel
+       the first time, or after a Drain has completed before writing
+       to a restarted channel.
+    */
+    void Start();
+    /* asynchronous write: ownership of the array reference is taken
+       from the caller and handler is called back on completion. */
+    void WriteItemArray(RChannelItemArrayRef& itemArray,
+                        bool flushAfter,
+                        RChannelItemArrayWriterHandler* handler);
+    /* synchronous write: returns the completion status; pFailureArray
+       may be NULL. */
+    RChannelItemType WriteItemArraySync(RChannelItemArrayRef& itemArray,
+                                        bool flush,
+                                        RChannelItemArrayRef* pFailureArray);
+    /* must follow a termination item; the implementation waits without
+       a timeout and does not read timeOut. */
+    virtual void Drain(DrTimeInterval timeOut,
+                       RChannelItemRef* pRemoteStatus);
+
+    void GetTerminationItems(RChannelItemRef* pWriterDrainItem,
+                             RChannelItemRef* pReaderDrainItem);
+
+    /* Close may not be called unless Start has never been called or
+       Drain has completed since the last call to Start. After Close
+       is called the channel may not be restarted. */
+    void Close();
+
+    /* this is the implementation of the RChannelBufferWriterHandler
+       interface and must not be called by clients. */
+    void ProcessWriteCompleted(RChannelItemType status);
+
+    /* Get/Set the URI of the channel. */
+    const char* GetURI();
+    void SetURI(const char* uri);
+
+    /* Get/set a hint about the total length the channel is expected
+       to be. Some channel implementations can use this to improve
+       write performance and decrease disk fragmentation. A value of 0
+       (the default) means that the size is unknown. */
+    UInt64 GetInitialSizeHint();
+    void SetInitialSizeHint(UInt64 hint);
+
+private:
+    class DummyItemHandler : public RChannelItemArrayWriterHandler
+    {
+    public:
+        void ProcessWriteArrayCompleted(RChannelItemType returnCode,
+                                        RChannelItemArray* failureArray);
+    };
+    /* one queued client write: the item array, its flush flag and the
+       handler to call back once it has been marshaled. */
+    class WriteRequest
+    {
+    public:
+        WriteRequest(RChannelItemArray* itemArray,
+                     bool flushAfter,
+                     RChannelItemArrayWriterHandler* handler);
+        ~WriteRequest();
+
+        void SetHandler(RChannelItemArrayWriterHandler* handler);
+        bool ShouldFlush();
+
+        RChannelItem* GetNextItem();
+        void SetSuccessItem();
+        void SetFailureItem(RChannelItem* marshalFailureItem, bool abort);
+        bool LastItem();
+        bool Completed();
+
+        void ProcessMarshalCompleted(RChannelItemType returnType);
+
+    private:
+        RChannelItemArrayRef            m_itemArray;
+        bool                            m_flushAfter;
+        RChannelItemArrayWriterHandler* m_handler;
+        UInt32                          m_currentItem;
+        RChannelItemArrayRef            m_failureArray;
+        bool                            m_aborted;
+        DrBListEntry                    m_listPtr;
+        friend class DryadBList; /* NOTE(review): template arguments look stripped by extraction here and in the typedef below */
+    };
+
+    typedef DryadBList WriteRequestList;
+    /* handler used by WriteItemArraySync: records the completion status
+       and, when an event has been attached, signals it so the waiting
+       caller wakes up. */
+    class SyncHandler : public RChannelItemArrayWriterHandler
+    {
+    public:
+        SyncHandler();
+        ~SyncHandler();
+
+        void ProcessWriteArrayCompleted(RChannelItemType returnType,
+                                        RChannelItemArray* failureArray);
+
+        RChannelItemType GetStatusCode();
+        void GetFailureItemArray(RChannelItemArrayRef* pFailureArray);
+
+        void UseEvent(DryadHandleListEntry* event);
+
DryadHandleListEntry* GetEvent(); + + bool UsingEvent(); + void Wait(); + + private: + LONG m_usingEvent; + DryadHandleListEntry* m_event; + RChannelItemType m_statusCode; + RChannelItemArrayRef m_failureArray; + }; + + enum CWState { + CW_Closed, + CW_Empty, + CW_InWorkQueue, + CW_Marshaling, + CW_Blocking, + CW_Stopping, + CW_Stopped + }; + + void MakeCachedWriter(); + Size_t DisposeOfCachedWriter(); + void ReturnHandlers(WriteRequestList* completedList, + RChannelItemType returnCode); + bool CheckForTerminationItem(RChannelItemArray* itemArray); + void RestorePreMarshalBuffers(Size_t preMarshalAvailableSize); + void CollapseToSingleBuffer(); + void ShuffleBuffersOnRecordBoundaries(Size_t preMarshalAvailableSize); + RChannelItemType PerformSingleMarshal(WriteRequest* writeRequest); + bool PerformMarshal(WriteRequestList* pendingRequestList, + WriteRequestList* completedRequestList); + void PostMarshal(WriteRequestList* pendingRequestList, + WriteRequestList* completedRequestList, + bool shouldBlock, + SyncHandler* syncHandler); + bool SendCompletedBuffers(bool shouldFlush, + RChannelItemType terminationType); + void MarshalItems(); + void AcceptReturningHandlers(UInt32 handlerCount); + + bool m_breakBufferOnRecordBoundaries; + RChannelBufferWriter* m_writer; + RChannelItemMarshalerRef m_marshaler; + WorkQueue* m_workQueue; + + UInt32 m_maxMarshalBatchSize; + + CWState m_state; + /* the pending list is the writes which have been submitted but + not yet queued for the marshaler. */ + WriteRequestList m_pendingList; + /* the blocked handler list is the writes which have been + marshaled and sent to the buffer writer and are waiting for the + channel to unblock */ + WriteRequestList m_blockedHandlerList; + /* outstandingBuffers is the count of buffers which have been sent + to the buffer writer and not yet completed */ + UInt32 m_outstandingBuffers; + /* outstandingBuffers is the count of handlers which have been + received from the client and not yet returned. 
/* Abstract interface for the buffer-level writer that a channel writer
   hands its filled buffers to. Every member is pure virtual except the
   destructor and FillInStatus, which are declared here and defined
   elsewhere. */
class RChannelBufferWriter
{
public:
    virtual ~RChannelBufferWriter();

    /* NOTE(review): presumably must be called before any buffers are
       requested or written -- confirm against the implementations. */
    virtual void Start() = 0;

    /* Buffer acquisition. GetCustomWriteBuffer takes an explicit size;
       GetNextWriteBuffer takes none, so its size is chosen by the
       implementation. */
    virtual DryadFixedMemoryBuffer* GetNextWriteBuffer() = 0;
    virtual DryadFixedMemoryBuffer* GetCustomWriteBuffer(Size_t bufferSize) = 0;

    /* Submit a buffer for writing. handler is the completion callback
       interface; the meaning of the bool result is not visible from this
       declaration -- TODO confirm (looks like a "channel is blocking"
       style signal given the CW_Blocking state in the writer above). */
    virtual bool WriteBuffer(DryadFixedMemoryBuffer* buffer,
                             bool flushAfter,
                             RChannelBufferWriterHandler* handler) = 0;

    /* Give back a buffer obtained from one of the Get*WriteBuffer calls
       without writing it. */
    virtual void ReturnUnusedBuffer(DryadFixedMemoryBuffer* buffer) = 0;

    /* Write the end-of-channel marker carrying reasonCode; completion is
       reported through handler. */
    virtual void WriteTermination(RChannelItemType reasonCode,
                                  RChannelBufferWriterHandler* handler) = 0;

    /* Not pure virtual: a base implementation exists outside this view. */
    virtual void FillInStatus(DryadChannelDescription* status);

    /* pReturnItem is an out-parameter; presumably receives the channel's
       final status item -- confirm against the implementations. */
    virtual void Drain(RChannelItemRef* pReturnItem) = 0;

    /* shut down the channel. After Close returns no further calls can
       be made to this interface.
     */
    virtual void Close() = 0;

    /* Get/set a hint about the total length the channel is expected
       to be. Some channel implementations can use this to improve
       write performance and decrease disk fragmentation. A value of 0
       (the default) means that the size is unknown. */
    virtual UInt64 GetInitialSizeHint() = 0;
    virtual void SetInitialSizeHint(UInt64 hint) = 0;
};
+ +*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef TIDYFS +#include +#endif + +#pragma once +#pragma warning(disable:4512) +#pragma warning(disable:4511) +#include +#include + +#pragma unmanaged + + +static const char* s_filePrefix = "file://"; +static const char* s_fifoPrefix = "fifo://"; +static const char* s_nullPrefix = "null://"; +static const char* s_tidyfsPrefix = "tidyfs://"; +static const char* s_dscStreamPrefix = "hpcdsc://"; +static const char* s_dscPartitionPrefix = "hpcdscpt://"; +static const char* s_azureBlobPrefix = "http://"; + +// +// Use 6 to match up with retry count used between GM and VS +// +static const int s_dscRetryMax = 6; + +// +// Check if vertex host running on Azure node +// +static bool IsAzure() +{ + WCHAR buf[MAX_PATH]; + ZeroMemory(buf, sizeof(buf)); + + int nSchedulerTypeLen = ::GetEnvironmentVariable(L"CCP_SCHEDULERTYPE", buf, _countof(buf)); + if (_wcsicmp(buf,L"AZURE") == 0) + { + return true; + } + + ZeroMemory(buf, sizeof(buf)); + ::GetEnvironmentVariable(L"DEBUG_AZURE", buf, _countof(buf)); + if (nSchedulerTypeLen != 0 && _wcsicmp(buf, L"0") != 0) + { + return true; + } + return false; +} + +// +// Check if channel URI is reading from an on-premise NTFS file +// +bool ConcreteRChannel::IsNTFSFile(const char* uri) +{ + size_t prefixLen = ::strlen(s_filePrefix); + size_t uriLen = ::strlen(uri); + + if (uriLen < prefixLen + 2) + { + return false; + } + + if (_strnicmp(uri, s_filePrefix, ::strlen(s_filePrefix)) == 0) + { + if (!IsAzure() || uri[prefixLen] != '\\' || uri[prefixLen+1] != '\\') + { + return true; + } + } + return false; + +} + + +// +// Check if channel URI is a DSC partition by comparing the prefix to hpcdscpt:// +// +bool ConcreteRChannel::IsDscPartition(const char* uri) +{ + return (_strnicmp(uri, s_dscPartitionPrefix, ::strlen(s_dscPartitionPrefix)) == 0); +} + +// +// Check if channel URI is a HDFS file by comparing the prefix to hpchdfs:// +// +bool 
ConcreteRChannel::IsHdfsFile(const char* uri) +{ + return (_strnicmp(uri, + RChannelBufferHdfsWriter::s_hdfsFilePrefix, + ::strlen(RChannelBufferHdfsWriter:: + s_hdfsFilePrefix)) == 0); +} + +// +// Check if channel URI is a HDFS partition by comparing the prefix to hpchdfspt:// +// +bool ConcreteRChannel::IsHdfsPartition(const char* uri) +{ + return (_strnicmp(uri, + RChannelBufferHdfsReader::s_hdfsPartitionPrefix, + ::strlen(RChannelBufferHdfsReader:: + s_hdfsPartitionPrefix)) == 0); +} + +// +// Check if the channel URI is an Azure blob by comparing the prefix to http:// +// +bool ConcreteRChannel::IsAzureBlob(const char* uri) +{ + return (_strnicmp(uri, s_azureBlobPrefix, ::strlen(s_azureBlobPrefix)) == 0); +} + +// +// Check if the channel URI is a UNC path in Azure +// +bool ConcreteRChannel::IsUncPath(const char* uri) +{ + size_t prefixLen = ::strlen(s_filePrefix); + size_t uriLen = ::strlen(uri); + + if (uriLen < prefixLen + 2) + { + return false; + } + + if (_strnicmp(uri, s_filePrefix, ::strlen(s_filePrefix)) == 0) + { + if (IsAzure() && uri[prefixLen] == '\\' && uri[prefixLen+1] == '\\') + { + return true; + } + } + return false; +} +// +// Check if the channel URI is a DSC stream by comparing prefix to hpcdsc:// +// +bool ConcreteRChannel::IsDscStream(const char* uri) +{ + return (_strnicmp(uri, s_dscStreamPrefix, ::strlen(s_dscStreamPrefix)) == 0); +} + +// +// Check if the channel URI is a fifo by comparing prefix to fifo:// +// +bool ConcreteRChannel::IsFifo(const char* uri) +{ + return (_strnicmp(uri, s_fifoPrefix, ::strlen(s_fifoPrefix)) == 0); +} + +// +// Check if the channel URI is a null channel by comparing prefix to null:// +// +bool ConcreteRChannel::IsNull(const char* uri) +{ + return (_strnicmp(uri, s_nullPrefix, ::strlen(s_nullPrefix)) == 0); +} + +// +// Check if channel URI is a named pipe by comparing prefix to \pipe\ +// +bool ConcreteRChannel::IsNamedPipe(const char* uri) +{ + // uri is of form <\\CompName\pipe\> + const char* pipe = 
::strstr(uri, "\\pipe\\"); + if(pipe) + { + while(pipe != uri) + { + --pipe; + if(*pipe == '\\') + { + return (--pipe == uri); + } + } + } + + return false; +} + +// +// Check if channel URI is a tidyfs stream by comparing prefix to tidyfs:// +// +bool ConcreteRChannel::IsTidyFSStream(const char* uri) +{ + return (_strnicmp(uri, s_tidyfsPrefix, ::strlen(s_tidyfsPrefix)) == 0); +} + +// +// Nothing to dispose +// todo: ensure this is correct +// +RChannelThrottledStream::~RChannelThrottledStream() +{ +} + +// +// Create a new dispatch with provided stream as source +// +RChannelOpenThrottler::Dispatch::Dispatch(RChannelThrottledStream* stream) +{ + m_stream = stream; +} + +void RChannelOpenThrottler::Dispatch::Process() +{ + m_stream->OpenAfterThrottle(); +} + +bool RChannelOpenThrottler::Dispatch::ShouldAbort() +{ + return false; +} + +// +// Constructor - Initialize properties (workqueue and max input) +// +RChannelOpenThrottler::RChannelOpenThrottler(UInt32 maxOpenFiles, + WorkQueue* workQueue) +{ + LogAssert(maxOpenFiles > 0); + m_maxOpenFiles = maxOpenFiles; + m_workQueue = workQueue; + m_openFileCount = 0; +} + +RChannelOpenThrottler::~RChannelOpenThrottler() +{ + LogAssert(m_openFileCount == 0); + LogAssert(m_blockedFileList.empty()); +} + +bool RChannelOpenThrottler::QueueOpen(RChannelThrottledStream* stream) +{ + bool openImmediately = false; + + { + AutoCriticalSection acs(&m_baseDR); + + if (m_openFileCount < m_maxOpenFiles) + { + ++m_openFileCount; + openImmediately = true; + } + else + { + LogAssert(m_openFileCount == m_maxOpenFiles); + m_blockedFileList.push_back(stream); + DrLogI( "Queueing open. 
%u files open, %Iu files blocked", + m_openFileCount, m_blockedFileList.size()); + } + } + + return openImmediately; +} + +// +// Notify throttler that a file is complete and a new one can be opened +// +void RChannelOpenThrottler::NotifyFileCompleted() +{ + Dispatch* dispatch = NULL; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_openFileCount > 0); + if (m_openFileCount == m_maxOpenFiles && + m_blockedFileList.empty() == false) + { + // + // If current throttle is full and there's more files ready to be opened, + // Open first blocked file + // + dispatch = new Dispatch(m_blockedFileList.front()); + m_blockedFileList.pop_front(); + DrLogI( "Unblocking open. %u files open, %Iu files still blocked", + m_openFileCount, m_blockedFileList.size()); + } + else + { + --m_openFileCount; + } + } + + if (dispatch != NULL) + { + // + // If a new file can be opened, add it to the work queue + // + bool b = m_workQueue->EnQueue(dispatch); + LogAssert(b); + } +} + +// +// Create a file reader. Every type of file path eventually ends up here. 
+// +static RChannelBufferReader* + CreateNativeFileReader(UInt32 numberOfReaders, + RChannelOpenThrottler* openThrottler, + WorkQueue* workQueue, + const char* fileName, + DryadMetaData* metaData, + DVErrorReporter* errorReporter, + LPDWORD localInputChannels) +{ + UInt32 blockSize = 4*1024; + UInt32 numberOfBlocksPerBuffer = 64 / numberOfReaders; + if (numberOfBlocksPerBuffer < 16) + { + numberOfBlocksPerBuffer = 16; + } + + RChannelBufferReaderNativeFile* fileReader = + new RChannelBufferReaderNativeFile(numberOfBlocksPerBuffer*blockSize, + blockSize, 4, + g_dryadNativePort, + workQueue, openThrottler); + if (fileReader == NULL) + { + return NULL; + } + + // + // Open the specified file + // + if (!fileReader->OpenA(fileName)) + { + delete fileReader; + + DrError errorCode = DrGetLastError(); + errorReporter->ReportError(errorCode, + "Can't open native file '%s' to read", + fileName); + + return NULL; + } + + // + // If counting local input channels, proceed with testing + // this is set to null in the case of Azure and in-memory fifo sources + // + if(localInputChannels != NULL) + { + DrStr myFilePath(fileName); + if(!myFilePath.StartsWith("\\\\")) + { + // + // If path doesnt start with "\\", then it must be local + // + (*localInputChannels)++; + } + else + { + // + // Otherwise, check host name to see if it's local + // + + // replace below to enable DNS hostname support + //DWORD hostLength = DNS_MAX_LABEL_BUFFER_LENGTH; + //char hostname[DNS_MAX_LABEL_BUFFER_LENGTH]; + //if (!GetComputerNameExA(ComputerNameDnsHostname, hostname, &hostLength)) + + DWORD hostLength = MAX_COMPUTERNAME_LENGTH + 1; + char hostname[MAX_COMPUTERNAME_LENGTH + 1]; + if (!GetComputerNameA(hostname, &hostLength)) + { + DrLogE( "Error calling GetComputerName. 
File: %s ErrorCode: %u", + + fileName, GetLastError()); + return NULL; + } + + DrStr myHostName(hostname); + + // + // Find first index of \ after \\ prefix + // this should find the end of machine name + // should always find an instance of the \ character in valid path (bug if not) + // + size_t firstIndexOfSlash = myFilePath.IndexOfChar('\\', 2); + LogAssert(firstIndexOfSlash != DrStr_InvalidIndex); + + if(firstIndexOfSlash - 2 == myHostName.GetLength()) + { + if(myFilePath.SubstrIsEqualNoCase(2, myHostName, myHostName.GetLength())) + { + // + // If length of host name is same and strings match, then file is local + // + (*localInputChannels)++; + } + } + } + } + + return fileReader; +} + +// +// Create a HDFS block reader. HDFS read path eventually ends up here. +// +static RChannelBufferReader* CreateHdfsBlockReader(const char* uri) +{ + return new RChannelBufferHdfsReaderLineRecord(uri); +} + + +static RChannelBufferWriter* CreateHdfsFileWriter(const char* uri) +{ + return new RChannelBufferHdfsWriter(uri); +} + + +// hide azure related +#if 0 +static RChannelBufferReader* + CreateUncFileReader(UInt32 numberOfReaders, + RChannelOpenThrottler* openThrottler, + WorkQueue* workQueue, + const char* streamName, + DryadMetaData* metaData, + DVErrorReporter* errorReporter, + LPDWORD localInputChannels) +{ + // Ensure that the stream name passed in is really a Azure Blob URL + LogAssert(ConcreteRChannel::IsUncPath(streamName)); + + // + // Copy the blob locally to a temp file and call + // CreateNativeFileReader on the temp file. 
#ifdef TIDYFS
//
// Resolve a TidyFS stream name to a readable native path and open it with
// CreateNativeFileReader.
//
// Returns the reader, or NULL after reporting through errorReporter.
//
// Changes from the previous version:
//  - the MDClient is released on every exit path (it was leaked when
//    Initialize failed);
//  - the two trace calls passed a tag as the first argument, which made the
//    tag the format string and dropped the real message; the tag is now
//    part of the format;
//  - CreateNativeFileReader is given its seventh argument. NULL means "do
//    not count local input channels".
//
static RChannelBufferReader*
CreateTidyFSStreamReader(UInt32 numberOfReaders,
                         RChannelOpenThrottler* openThrottler,
                         WorkQueue* workQueue,
                         const char* streamName,
                         DryadMetaData* metaData,
                         DVErrorReporter* errorReporter)
{
    MDClient *client = new MDClient();
    LogAssert(client != NULL);
    DrError result = client->Initialize("rsl.ini");
    if (result != DrError_OK)
    {
        errorReporter->ReportError(result,
                                   "Error initializing TidyFS client: %u: %s",
                                   result, GetDrErrorDescription(result));
        delete client;
        return NULL;
    }
    const char *hostname = Configuration::GetRawMachineName();

    char path[2048];
    result = client->GetReadPath(path, 2048, streamName, hostname);

    //
    // The client is only needed for the lookup above
    //
    delete client;

    DrLogI( "CreateTidyFSStreamReader: Stream: %s, Host: %s", streamName, hostname);
    if (result != DrError_OK)
    {
        errorReporter->ReportError(result,
                                   "Error calling GetReadPath on '%s': %u: %s",
                                   streamName, result, GetDrErrorDescription(result));
        return NULL;
    }
    DrLogI( "CreateTidyFSStreamReader: Path: %s", path);

    return CreateNativeFileReader(numberOfReaders, openThrottler, workQueue,
                                  path, metaData, errorReporter, NULL);
}
#endif
DVErrorReporter* errorReporter) +{ + UInt32 blockSize = 4*1024; + UInt32 numberOfBlocksPerBuffer = 64 / numberOfReaders; + if (numberOfBlocksPerBuffer < 16) + { + numberOfBlocksPerBuffer = 16; + } + + RChannelBufferReaderNativeXComputeFile* fileReader = + new RChannelBufferReaderNativeXComputeFile(numberOfBlocksPerBuffer*blockSize*64, + blockSize, 2, + g_dryadNativePort, + workQueue, openThrottler); + if (fileReader == NULL) + { + return NULL; + } + + if (!fileReader->OpenA(fileName)) + { + delete fileReader; + + DrError errorCode = DrGetLastError(); + errorReporter->ReportError(errorCode, + "Can't open xcompute file '%s' to read", + fileName); + + return NULL; + } + + return fileReader; +} + +static RChannelBufferReader* + CreateDryadStreamReader(UInt32 numberOfReaders, + RChannelOpenThrottler* openThrottler, + WorkQueue* workQueue, + const char* streamName, + DryadMetaData* metaData, + DVErrorReporter* errorReporter) +{ + RChannelBufferReaderDryadStream* streamReader = + new RChannelBufferReaderDryadStream(2*1024*1024, 1, + g_dryadNativePort, + workQueue, openThrottler); + if (streamReader == NULL) + { + return NULL; + } + + DrError cse = streamReader->OpenA(streamName); + if (cse != DrError_OK) + { + delete streamReader; + + errorReporter->ReportError(cse, + "Can't open cosmos stream '%s' to read" + " --- %s", + streamName, DRERRORSTRING(cse)); + + return NULL; + } + + return streamReader; +} + +static RChannelBufferReader* + CreateDryadPipeReader(UInt32 numberOfReaders, + const char* pipeName, + DryadMetaData* metaData, + DVErrorReporter* errorReporter) +{ + RChannelBufferReaderDryadPipe* pipeReader = + new RChannelBufferReaderDryadPipe(256*1024, 1, + g_dryadNativePort); + if (pipeReader == NULL) + { + return NULL; + } + + DrError cse = pipeReader->OpenA(pipeName); + if (cse != DrError_OK) + { + delete pipeReader; + + errorReporter->ReportError(cse, + "Can't open cosmos pipe '%s' to read" + " --- %s", + pipeName, DRERRORSTRING(cse)); + + return NULL; + } + + 
return pipeReader; +} +*/ + +// +// todo: figure out what and why +// +static RChannelBufferReader* + CreateNullReader(const char* uri, + DryadMetaData* metaData, + DVErrorReporter* errorReporter) +{ + return new RChannelNullBufferReader(); +} + +static RChannelBufferWriter* + CreateNativeFileWriter(UInt32 numberOfWriters, + RChannelOpenThrottler* openThrottler, + const char* fileName, + DryadMetaData* metaData, + bool* pBreakOnBufferBoundaries, + DVErrorReporter* errorReporter) +{ + UInt32 blockSize = 4*1024; + UInt32 numberOfBlocksPerBuffer = 8*64 / numberOfWriters; + if (numberOfBlocksPerBuffer < 16) + { + numberOfBlocksPerBuffer = 16; + } + + RChannelBufferWriterNativeFile* fileWriter = + new RChannelBufferWriterNativeFile(numberOfBlocksPerBuffer*blockSize, + blockSize, 2, 6, + g_dryadNativePort, openThrottler); + if (fileWriter == NULL) + { + return NULL; + } + + DrError cse = fileWriter->SetMetaData(metaData); + if (cse != DrError_OK) + { + delete fileWriter; + + const char* text = metaData->GetText(); + + errorReporter->ReportError(cse, + "Can't read native file metadata %s for '%s' to write" + " --- %s", + text, fileName, DRERRORSTRING(cse)); + + delete [] text; + + return NULL; + } + + if (!fileWriter->OpenA(fileName)) + { + delete fileWriter; + + DrError errorCode = DrGetLastError(); + errorReporter->ReportError(errorCode, + "Can't open native file '%s' to write", + fileName); + + return NULL; + } + + return fileWriter; +} + +/* JC +static RChannelBufferWriter* + CreateDryadStreamWriter(UInt32 numberOfWriters, + RChannelOpenThrottler* openThrottler, + const char* streamName, + DryadMetaData* metaData, + bool* pBreakOnBufferBoundaries, + DVErrorReporter* errorReporter) +{ + RChannelBufferWriterDryadStream* streamWriter = + new RChannelBufferWriterDryadStream(2*1024*1024, 0, 1, + g_dryadNativePort, openThrottler); + if (streamWriter == NULL) + { + return NULL; + } + + DrError cse = streamWriter->SetMetaData(metaData); + if (cse != DrError_OK) + { + delete 
#ifdef TIDYFS
//
// Create and open a TidyFS stream writer for streamName.
//
// Buffer size is (64 / numberOfWriters) blocks of 4KB, with a floor of 16
// blocks. Returns the opened writer, or NULL after reporting the open
// failure through errorReporter.
//
// pBreakOnBufferBoundaries is not written by this function.
//
static RChannelBufferWriter*
CreateTidyFSStreamWriter(UInt32 numberOfWriters,
                         RChannelOpenThrottler* openThrottler,
                         const char* streamName,
                         DryadMetaData* metaData,
                         bool* pBreakOnBufferBoundaries,
                         DVErrorReporter* errorReporter)
{
    const UInt32 blockSize = 4*1024;
    const UInt32 minBlocksPerBuffer = 16;

    UInt32 blocksPerBuffer = 64 / numberOfWriters;
    if (blocksPerBuffer < minBlocksPerBuffer)
    {
        blocksPerBuffer = minBlocksPerBuffer;
    }

    RChannelBufferWriterNativeTidyFSStream *streamWriter =
        new RChannelBufferWriterNativeTidyFSStream(blocksPerBuffer*blockSize,
                                                   blockSize, 2, 6,
                                                   g_dryadNativePort, openThrottler);
    if (streamWriter == NULL)
    {
        return NULL;
    }

    DrError openStatus = streamWriter->OpenA(streamName, metaData);
    if (openStatus == DrError_OK)
    {
        return streamWriter;
    }

    delete streamWriter;

    errorReporter->ReportError(openStatus,
                               "Can't open TidyFS stream '%s' to write"
                               " --- %s",
                               streamName, DRERRORSTRING(openStatus));

    return NULL;
}
#endif
pipeWriter; + + errorReporter->ReportError(cse, + "Can't open cosmos pipe '%s' to write" + " --- %s", + pipeName, DRERRORSTRING(cse)); + + return NULL; + } + + return pipeWriter; +} +*/ + +// +// Creates a new fifo holder +// +RChannelFifoHolder::RChannelFifoHolder(const char* channelURI, + UInt32 fifoLength, + bool isReader, + WorkQueue* workQueue) +{ + if (workQueue == NULL) + { + // + // create a work queue if needed + // + m_workQueue = new WorkQueue(4, 2); + m_workQueue->Start(); + workQueue = m_workQueue; + } + else + { + // + // If there is a global work queue, don't create a local one + // + m_workQueue = NULL; + } + + m_fifo = new RChannelFifo(channelURI, fifoLength, workQueue); + + m_discardedReader = false; + m_discardedWriter = false; + m_madeReader = isReader; + m_madeWriter = !isReader; +} + +RChannelFifoHolder::~RChannelFifoHolder() +{ + if (m_workQueue != NULL) + { + m_workQueue->Stop(); + } + delete m_fifo; + delete m_workQueue; +} + +RChannelFifo* RChannelFifoHolder::GetFifo() +{ + return m_fifo; +} + +bool RChannelFifoHolder::MakeReader() +{ + { + AutoCriticalSection acs(&m_atomic); + + if (m_madeReader == false) + { + m_madeReader = true; + return true; + } + else + { + return false; + } + } +} + +bool RChannelFifoHolder::MakeWriter() +{ + { + AutoCriticalSection acs(&m_atomic); + + if (m_madeWriter == false) + { + m_madeWriter = true; + return true; + } + else + { + return false; + } + } +} + +bool RChannelFifoHolder::DiscardReader() +{ + { + AutoCriticalSection acs(&m_atomic); + + LogAssert(m_discardedReader == false); + m_discardedReader = true; + if (m_discardedWriter) + { + m_fifo->GetReader()->Close(); + m_fifo->GetWriter()->Close(); + return true; + } + else + { + return false; + } + } +} + +bool RChannelFifoHolder::DiscardWriter() +{ + { + AutoCriticalSection acs(&m_atomic); + + LogAssert(m_discardedWriter == false); + m_discardedWriter = true; + if (m_discardedReader) + { + m_fifo->GetReader()->Close(); + m_fifo->GetWriter()->Close(); + 
return true; + } + else + { + return false; + } + } +} + +typedef std::map< std::string, RChannelFifoHolder*, std::less > + FifoMap; + +static FifoMap s_fifoMap; +static CRITSEC s_fifoAtomic; +static UInt32 s_fifoUniqueId = 0; + +static UInt32 GetFifoLength(const char* channelURI) +{ + DrStr256 uri( channelURI ); + LogAssert(uri.StartsWithNoCase("fifo://", strlen("fifo://"))); + // BUGBUG: sammck: this code is inefficient, copies entire string when only the first word is used. + DrStr256 lengthField( channelURI + strlen("fifo://") ); + size_t endOfField = lengthField.IndexOfChar('/'); + if (endOfField == DrStr_InvalidIndex) + { + DrLogW( "fifo URI malformed. URI %s", channelURI); + return 0; + } + + lengthField.UpdateLength(endOfField); + + UInt32 length; + DrError err = DrStringToUInt32(lengthField, &length); + if (err == DrError_OK) + { + return length; + } + else + { + DrLogW( "fifo URI malformed. URI %s lengthField %s getlength returned %s", + channelURI, lengthField.GetString(), DRERRORSTRING(err)); + return 0; + } +} + + +RChannelReaderHolder::~RChannelReaderHolder() +{ +} + +RChannelWriterHolder::~RChannelWriterHolder() +{ +} + +// +// Create a new fifo reader holder +// +RChannelFifoReaderHolder:: + RChannelFifoReaderHolder(const char* channelURI, + WorkQueue* workQueue, + DVErrorReporter* errorReporter) +{ + m_reader = NULL; + + UInt32 fifoLength = GetFifoLength(channelURI); + + if (fifoLength == 0) + { + // + // Report an error if any inputs are empty + // + errorReporter->ReportError(DryadError_InvalidChannelURI, + "fifo %s has invalid length", + channelURI); + return; + } + + { + AutoCriticalSection acs(&s_fifoAtomic); + + FifoMap::iterator existing = s_fifoMap.find(channelURI); + if (existing == s_fifoMap.end()) + { + // + // If fifo doesn't already exist, create a new reader + // + DrLogI( "Creating new fifo reader. 
Name %s length %u", channelURI, fifoLength); + RChannelFifoHolder* holder = + new RChannelFifoHolder(channelURI, fifoLength, true, + workQueue); + s_fifoMap.insert(std::make_pair(channelURI, holder)); + m_reader = holder->GetFifo()->GetReader(); + } + else + { + DrLogI( + "Looked up existing fifo for reader. Name %s length %u", channelURI, fifoLength); + if (existing->second->MakeReader()) + { + RChannelFifo* fifo = existing->second->GetFifo(); + m_reader = fifo->GetReader(); + } + else + { + errorReporter->ReportError(DryadError_InvalidChannelURI, + "Duplicate fifo name %s passed " + "to reader create", + channelURI); + } + } + } +} + +RChannelFifoReaderHolder::~RChannelFifoReaderHolder() +{ + Close(); +} + +RChannelReader* RChannelFifoReaderHolder::GetReader() +{ + return m_reader; +} + +void RChannelFifoReaderHolder::FillInStatus(DryadInputChannelDescription* s) +{ + LogAssert(m_reader != NULL); + + RChannelFifoWriterBase* writer = + m_reader->GetParent()->GetWriter(); + LogAssert(writer != NULL); + + s->SetChannelProcessedLength(m_reader->GetDataSizeRead()); + s->SetChannelTotalLength(writer->GetDataSizeWritten()); +} + +void RChannelFifoReaderHolder::Close() +{ + if (m_reader != NULL) + { + RChannelFifo* fifo = m_reader->GetParent(); + const char* name = fifo->GetName(); + + { + AutoCriticalSection acs(&s_fifoAtomic); + + FifoMap::iterator holder = s_fifoMap.find(name); + LogAssert(holder != s_fifoMap.end()); + + if (holder->second->DiscardReader()) + { + DrLogI( "Discarding fifo (reader). 
Name %s", name); + + delete holder->second; + s_fifoMap.erase(holder); + } + } + + m_reader = NULL; + } +} + + +RChannelFifoWriterHolder:: + RChannelFifoWriterHolder(const char* channelURI, + DVErrorReporter* errorReporter) +{ + m_writer = NULL; + + UInt32 fifoLength = GetFifoLength(channelURI); + + if (fifoLength == 0) + { + errorReporter->ReportError(DryadError_InvalidChannelURI, + "fifo %s has invalid length", + channelURI); + return; + } + + { + AutoCriticalSection acs(&s_fifoAtomic); + + FifoMap::iterator existing = s_fifoMap.find(channelURI); + if (existing == s_fifoMap.end()) + { + DrLogI( "Creating new fifo writer. Name %s length %u", channelURI, fifoLength); + RChannelFifoHolder* holder = + new RChannelFifoHolder(channelURI, fifoLength, false, NULL); + s_fifoMap.insert(std::make_pair(channelURI, holder)); + m_writer = holder->GetFifo()->GetWriter(); + } + else + { + DrLogI( + "Looked up existing fifo for writer. Name %s length %u", channelURI, fifoLength); + if (existing->second->MakeWriter()) + { + RChannelFifo* fifo = existing->second->GetFifo(); + m_writer = fifo->GetWriter(); + } + else + { + errorReporter->ReportError(DryadError_InvalidChannelURI, + "Duplicate fifo name %s passed " + "to writer create", + channelURI); + } + } + } +} + +RChannelFifoWriterHolder::~RChannelFifoWriterHolder() +{ + Close(); +} + +RChannelWriter* RChannelFifoWriterHolder::GetWriter() +{ + return m_writer; +} + +void RChannelFifoWriterHolder::FillInStatus(DryadOutputChannelDescription* s) +{ + LogAssert(m_writer != NULL); + s->SetChannelProcessedLength(m_writer->GetDataSizeWritten()); +} + +void RChannelFifoWriterHolder::Close() +{ + if (m_writer != NULL) + { + RChannelFifo* fifo = m_writer->GetParent(); + const char* name = fifo->GetName(); + + { + AutoCriticalSection acs(&s_fifoAtomic); + + FifoMap::iterator holder = s_fifoMap.find(name); + LogAssert(holder != s_fifoMap.end()); + + if (holder->second->DiscardWriter()) + { + DrLogI( "Discarding fifo (writer). 
Name %s", name); + + delete holder->second; + s_fifoMap.erase(holder); + } + } + + m_writer = NULL; + } +} + +RChannelNullWriterHolder::RChannelNullWriterHolder(const char* uri) +{ + m_writer = new RChannelNullWriter(uri); +} + +RChannelNullWriterHolder::~RChannelNullWriterHolder() +{ + Close(); +} + +RChannelWriter* RChannelNullWriterHolder::GetWriter() +{ + return m_writer; +} + +void RChannelNullWriterHolder::FillInStatus(DryadOutputChannelDescription* s) +{ + LogAssert(m_writer != NULL); + s->SetChannelProcessedLength(0); +} + +void RChannelNullWriterHolder::Close() +{ + delete m_writer; + m_writer = NULL; +} + + +// +// Create a buffer reader for the incoming channel +// +RChannelBufferedReaderHolder:: + RChannelBufferedReaderHolder(const char* channelURI, + RChannelOpenThrottler* openThrottler, + DryadMetaData* metaData, + RChannelItemParserBase* parser, + UInt32 numberOfReaders, + UInt32 maxParseBatchSize, + UInt32 maxParseUnitsInFlight, + WorkQueue* workQueue, + DVErrorReporter* errorReporter, + LPDWORD localInputChannels) +{ + m_parser = parser; + m_reader = NULL; + m_bufferReader = NULL; + + // + // Create buffer reader + // + bool lazyStart = + CreateBufferReader(numberOfReaders, openThrottler, workQueue, + channelURI, metaData, errorReporter, localInputChannels); + + // + // If error, just return. 
Caller will see same error and act on it + // + if (errorReporter->GetErrorCode() != DrError_OK) + { + return; + } + + LogAssert(m_bufferReader != NULL); + + // + // Create a serial reader wrapper around the buffer reader + // + RChannelSerializedReader* r = + new RChannelSerializedReader(m_bufferReader, + parser, + maxParseBatchSize, + maxParseUnitsInFlight, + lazyStart, + workQueue); + r->SetURI(channelURI); + + m_reader = r; +} + +// +// Call close on destruction, which handles cleanup +// +RChannelBufferedReaderHolder::~RChannelBufferedReaderHolder() +{ + Close(); +} + +// +// Create a buffer reader, taking the type of input into account +// +bool RChannelBufferedReaderHolder:: + CreateBufferReader(UInt32 numberOfReaders, + RChannelOpenThrottler* openThrottler, + WorkQueue* workQueue, + const char* channelURI, + DryadMetaData* metaData, + DVErrorReporter* errorReporter, + LPDWORD localInputChannels) +{ + bool lazyStart = false; + + if (ConcreteRChannel::IsNTFSFile(channelURI)) + { + // + // If URI is on-premise NTFS file, create the file reader right away + // + m_bufferReader = + CreateNativeFileReader(numberOfReaders, openThrottler, + workQueue, + channelURI + ::strlen(s_filePrefix), + metaData, errorReporter, localInputChannels); + lazyStart = true; + } + else if (ConcreteRChannel::IsHdfsPartition(channelURI)) + { + m_bufferReader = CreateHdfsBlockReader(channelURI); + lazyStart = false; + } + else if (ConcreteRChannel::IsFifo(channelURI)) + { + // + // If URI is a fifo, report it as an error + // + errorReporter->ReportError(DryadError_InvalidChannelURI, + "RChannelBufferReaderFactory passed " + "fifo URI %s in error", + channelURI); + } + else if (ConcreteRChannel::IsNull(channelURI)) + { + // + // If URI is a null reader, create a specialized reader for it + // todo: figure out when this is used + // + m_bufferReader = + CreateNullReader(channelURI, + metaData, errorReporter); + } + // hide azure related +#if 0 + else if 
(ConcreteRChannel::IsUncPath(channelURI)) + { + // + // If URI is a UNC path in azure, copy the file locally and then create reader + // + m_bufferReader = CreateUncFileReader(numberOfReaders, openThrottler, + workQueue, channelURI, metaData, errorReporter, NULL); + } +#endif + else + { + // + // If any other channel URI, give up and report error + // + errorReporter->ReportError(DryadError_InvalidChannelURI, + "Can't open channel '%s' to read --- " + "unknown prefix (must be %s, %s, %s, %s, %s or %s)", + channelURI, + s_filePrefix, s_tidyfsPrefix, s_fifoPrefix, + s_dscPartitionPrefix, + RChannelBufferHdfsReader:: + s_hdfsPartitionPrefix, + s_nullPrefix); + } + + // + // If creating the reader failed, report the error + // + if (m_bufferReader == NULL && errorReporter->NoError()) + { + errorReporter->ReportError(DryadError_ChannelOpenError, + "Can't open channel '%s' to read", + channelURI); + } + + return lazyStart; +} + +// +// Return the reader +// +RChannelReader* RChannelBufferedReaderHolder::GetReader() +{ + return m_reader; +} + +void RChannelBufferedReaderHolder:: + FillInStatus(DryadInputChannelDescription* s) +{ + LogAssert(m_bufferReader != NULL); + m_bufferReader->FillInStatus(s); +} + +// +// Clean up the reader and the buffered reader wrapper +// +void RChannelBufferedReaderHolder::Close() +{ + if (m_reader == NULL) + { + LogAssert(m_bufferReader == NULL); + } + else + { + m_reader->Close(); + delete m_reader; + m_reader = NULL; + LogAssert(m_bufferReader != NULL); + delete m_bufferReader; + m_bufferReader = NULL; + } +} + +RChannelBufferedWriterHolder:: + RChannelBufferedWriterHolder(const char* channelURI, + RChannelOpenThrottler* openThrottler, + DryadMetaData* metaData, + RChannelItemMarshalerBase* marshaler, + UInt32 numberOfWriters, + UInt32 maxMarshalBatchSize, + WorkQueue* workQueue, + DVErrorReporter* errorReporter) +{ + m_marshaler = marshaler; + m_writer = NULL; + m_bufferWriter = NULL; + + bool breakOnBufferBoundaries; + 
CreateBufferWriter(numberOfWriters, openThrottler, + channelURI, metaData, &breakOnBufferBoundaries, + errorReporter); + + if (errorReporter->GetErrorCode() != DrError_OK) + { + return; + } + + LogAssert(m_bufferWriter != NULL); + + RChannelSerializedWriter* w = + new RChannelSerializedWriter(m_bufferWriter, + marshaler, + breakOnBufferBoundaries, + maxMarshalBatchSize, + workQueue); + w->SetURI(channelURI); + + m_writer = w; +} + +RChannelBufferedWriterHolder::~RChannelBufferedWriterHolder() +{ + Close(); +} + +void RChannelBufferedWriterHolder:: + CreateBufferWriter(UInt32 numberOfWriters, + RChannelOpenThrottler* openThrottler, + const char* channelURI, + DryadMetaData* metaData, + bool* pBreakOnBufferBoundaries, + DVErrorReporter* errorReporter) +{ + *pBreakOnBufferBoundaries = false; + + if (ConcreteRChannel::IsNTFSFile(channelURI)) + { + m_bufferWriter = + CreateNativeFileWriter(numberOfWriters, openThrottler, + channelURI + ::strlen(s_filePrefix), + metaData, pBreakOnBufferBoundaries, + errorReporter); + } + else if (ConcreteRChannel::IsHdfsFile(channelURI)) + { + m_bufferWriter = CreateHdfsFileWriter(channelURI); + } + else if (ConcreteRChannel::IsFifo(channelURI)) + { + errorReporter->ReportError(DryadError_InvalidChannelURI, + "RChannelBufferWriterFactory passed " + "fifo URI %s in error", + channelURI); + } + else + { + errorReporter->ReportError(DryadError_InvalidChannelURI, + "Can't open channel '%s' to write --- " + "unknown prefix (must be %s, %s or %s)", + channelURI, s_filePrefix, + s_tidyfsPrefix, s_fifoPrefix); + } + + if (m_bufferWriter == NULL && errorReporter->NoError()) + { + errorReporter->ReportError(DryadError_ChannelOpenError, + "Can't open channel '%s' to write", + channelURI); + } +} + +RChannelWriter* RChannelBufferedWriterHolder::GetWriter() +{ + return m_writer; +} + +void RChannelBufferedWriterHolder:: + FillInStatus(DryadOutputChannelDescription* s) +{ + LogAssert(m_bufferWriter != NULL); + m_bufferWriter->FillInStatus(s); +} + 
+void RChannelBufferedWriterHolder::Close() +{ + if (m_writer == NULL) + { + LogAssert(m_bufferWriter == NULL); + } + else + { + m_writer->Close(); + delete m_writer; + m_writer = NULL; + LogAssert(m_bufferWriter != NULL); + delete m_bufferWriter; + m_bufferWriter = NULL; + } +} + +// +// Open a reader on an input channel +// +DrError RChannelFactory::OpenReader(const char* channelURI, + DryadMetaData* metaData, + RChannelItemParserBase* parser, + UInt32 numberOfReaders, + RChannelOpenThrottler* openThrottler, + UInt32 maxParseBatchSize, + UInt32 maxParseUnitsInFlight, + WorkQueue* workQueue, + DVErrorReporter* errorReporter, + RChannelReaderHolderRef* pHolder, + LPDWORD localInputChannels) +{ + LogAssert(errorReporter->NoError()); + LogAssert(numberOfReaders > 0); + + if (ConcreteRChannel::IsFifo(channelURI)) + { + // + // If the input channel is a FIFO, open a FIFO channel reader + // + pHolder->Attach(new RChannelFifoReaderHolder(channelURI, workQueue, + errorReporter)); + } + else + { + // + // If the input channel is a file or DSC buffer, open a generic buffer reader + // + pHolder->Attach(new RChannelBufferedReaderHolder(channelURI, + openThrottler, + metaData, + parser, + numberOfReaders, + maxParseBatchSize, + maxParseUnitsInFlight, + workQueue, + errorReporter, + localInputChannels)); + } + + return errorReporter->GetErrorCode(); +} + +// +// Open writer for provided output channel +// +DrError RChannelFactory::OpenWriter(const char* channelURI, + DryadMetaData* metaData, + RChannelItemMarshalerBase* marshaler, + UInt32 numberOfWriters, + RChannelOpenThrottler* openThrottler, + UInt32 maxMarshalBatchSize, + WorkQueue* workQueue, + DVErrorReporter* errorReporter, + RChannelWriterHolderRef* pHolder) +{ + LogAssert(errorReporter->NoError()); + LogAssert(numberOfWriters > 0); + + if (ConcreteRChannel::IsFifo(channelURI)) + { + // + // If the output channel is a FIFO, open a FIFO channel writer + // + pHolder->Attach(new RChannelFifoWriterHolder(channelURI, + 
errorReporter)); + } + else if (ConcreteRChannel::IsNull(channelURI)) + { + // + // If the output channel is a null channel, open a null channel writer + // + pHolder->Attach(new RChannelNullWriterHolder(channelURI)); + } + else + { + // + // If the output channel is a file buffer, open a file buffer writer + // + pHolder->Attach(new RChannelBufferedWriterHolder(channelURI, + openThrottler, + metaData, + marshaler, + numberOfWriters, + maxMarshalBatchSize, + workQueue, + errorReporter)); + } + + return errorReporter->GetErrorCode(); +} + +UInt32 RChannelFactory::GetUniqueFifoId() +{ + UInt32 id; + + { + AutoCriticalSection acs(&s_fifoAtomic); + + id = s_fifoUniqueId; + ++s_fifoUniqueId; + } + + return id; +} + +// +// Create a new Throttler +// +RChannelOpenThrottler* RChannelFactory::MakeOpenThrottler(UInt32 maxOpens, + WorkQueue* workQueue) +{ + return new RChannelOpenThrottler(maxOpens, workQueue); +} + +// +// Delete referenced throttler +// +void RChannelFactory::DiscardOpenThrottler(RChannelOpenThrottler* throttler) +{ + delete throttler; +} + diff --git a/DryadVertex/VertexHost/system/channel/src/concreterchannelhelpers.h b/DryadVertex/VertexHost/system/channel/src/concreterchannelhelpers.h new file mode 100644 index 0000000..4c4facf --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/src/concreterchannelhelpers.h @@ -0,0 +1,194 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include "concreterchannel.h" +#include "channelfifo.h" +#include + +class RChannelThrottledStream +{ +public: + ~RChannelThrottledStream(); + virtual void OpenAfterThrottle() = 0; +}; + +class RChannelOpenThrottler +{ +public: + class Dispatch : public WorkRequest + { + public: + Dispatch(RChannelThrottledStream* stream); + + void Process(); + bool ShouldAbort(); + + private: + RChannelThrottledStream* m_stream; + }; + + RChannelOpenThrottler(UInt32 maxOpenFiles, WorkQueue* workQueue); + ~RChannelOpenThrottler(); + + bool QueueOpen(RChannelThrottledStream* stream); + void NotifyFileCompleted(); + +private: + WorkQueue* m_workQueue; + UInt32 m_maxOpenFiles; + UInt32 m_openFileCount; + std::list m_blockedFileList; + CRITSEC m_baseDR; +}; + +class RChannelFifoHolder +{ +public: + RChannelFifoHolder(const char* channelURI, UInt32 fifoLength, + bool isReader, WorkQueue* workQueue); + ~RChannelFifoHolder(); + + RChannelFifo* GetFifo(); + + bool MakeReader(); + bool MakeWriter(); + + bool DiscardReader(); + bool DiscardWriter(); + +private: + RChannelFifo* m_fifo; + WorkQueue* m_workQueue; + bool m_madeReader; + bool m_discardedReader; + bool m_madeWriter; + bool m_discardedWriter; + + CRITSEC m_atomic; +}; + +class RChannelFifoReaderHolder : public RChannelReaderHolder +{ +public: + RChannelFifoReaderHolder(const char* channelURI, + WorkQueue* workQueue, + DVErrorReporter* errorReporter); + ~RChannelFifoReaderHolder(); + + RChannelReader* GetReader(); + void FillInStatus(DryadInputChannelDescription* status); + void Close(); + +private: + RChannelFifoReader* m_reader; +}; + +class RChannelFifoWriterHolder : public RChannelWriterHolder +{ +public: + RChannelFifoWriterHolder(const char* channelURI, + DVErrorReporter* errorReporter); + ~RChannelFifoWriterHolder(); + + RChannelWriter* GetWriter(); + void 
FillInStatus(DryadOutputChannelDescription* status); + void Close(); + +private: + RChannelFifoWriterBase* m_writer; +}; + +class RChannelBufferedReaderHolder : public RChannelReaderHolder +{ +public: + RChannelBufferedReaderHolder(const char* channelURI, + RChannelOpenThrottler* openThrottler, + DryadMetaData* metaData, + RChannelItemParserBase* parser, + UInt32 numberOfReaders, + UInt32 maxParseBatchSize, + UInt32 maxParseUnitsInFlight, + WorkQueue* workQueue, + DVErrorReporter* errorReporter, + LPDWORD localInputChannels); + ~RChannelBufferedReaderHolder(); + + RChannelReader* GetReader(); + void FillInStatus(DryadInputChannelDescription* status); + void Close(); + +private: + bool CreateBufferReader(UInt32 numberOfReaders, + RChannelOpenThrottler* openThrottler, + WorkQueue* workQueue, + const char* channelURI, + DryadMetaData* metaData, + DVErrorReporter* errorReporter, + LPDWORD localInputChannels); + + RChannelItemParserRef m_parser; + RChannelBufferReader* m_bufferReader; + RChannelReader* m_reader; +}; + +class RChannelBufferedWriterHolder : public RChannelWriterHolder +{ +public: + RChannelBufferedWriterHolder(const char* channelURI, + RChannelOpenThrottler* openThrottler, + DryadMetaData* metaData, + RChannelItemMarshalerBase* marshaler, + UInt32 numberOfWriters, + UInt32 maxMarshalBatchSize, + WorkQueue* workQueue, + DVErrorReporter* errorReporter); + ~RChannelBufferedWriterHolder(); + + RChannelWriter* GetWriter(); + void FillInStatus(DryadOutputChannelDescription* status); + void Close(); + +private: + void CreateBufferWriter(UInt32 numberOfWriters, + RChannelOpenThrottler* openThrottler, + const char* channelURI, + DryadMetaData* metaData, + bool* pBreakOnBufferBoundaries, + DVErrorReporter* errorReporter); + + RChannelItemMarshalerRef m_marshaler; + RChannelBufferWriter* m_bufferWriter; + RChannelWriter* m_writer; +}; + +class RChannelNullWriterHolder : public RChannelWriterHolder +{ +public: + RChannelNullWriterHolder(const char* uri); + 
~RChannelNullWriterHolder(); + RChannelWriter* GetWriter(); + void FillInStatus(DryadOutputChannelDescription* s); + void Close(); + +private: + RChannelWriter* m_writer; +}; diff --git a/DryadVertex/VertexHost/system/channel/src/memorybuffers.cpp b/DryadVertex/VertexHost/system/channel/src/memorybuffers.cpp new file mode 100644 index 0000000..d86792e --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/src/memorybuffers.cpp @@ -0,0 +1,415 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include "channelmemorybuffers.h" +#include "channelreader.h" +#include "channelwriter.h" + +#pragma unmanaged + + +RChannelReaderBuffer::Current::Current(Size_t initialStartOffset, + Size_t finalEndOffset, + DryadLockedMemoryBuffer** + bufferArray, + size_t nBuffers, + Size_t offset) +{ + m_currentBuffer = 0; + m_currentBufferBase = 0; + m_currentHeadCutLength = initialStartOffset; + m_currentTailOffset = 0; + + SetTailBufferData(bufferArray, nBuffers, finalEndOffset); + + while (offset >= m_currentTailOffset) + { + ++m_currentBuffer; + LogAssert(m_currentBuffer < nBuffers); + + m_currentBufferBase = m_currentTailOffset; + m_currentHeadCutLength = 0; + + SetTailBufferData(bufferArray, nBuffers, finalEndOffset); + } +} + +void* RChannelReaderBuffer::Current:: + GetDataAddress(DryadLockedMemoryBuffer** bufferArray, + Size_t offset, + Size_t *puSize, + Size_t *puPriorSize) +{ + LogAssert(offset >= m_currentBufferBase && + offset < m_currentTailOffset); + + Size_t offsetInAvailable = offset - m_currentBufferBase; + Size_t offsetInBuffer = offsetInAvailable + m_currentHeadCutLength; + Size_t remaining = m_currentTailOffset - offset; + + void* dataAddress = + bufferArray[m_currentBuffer]->GetDataAddress(offsetInBuffer, + puSize, + puPriorSize); + LogAssert(dataAddress != NULL); + + if (*puSize > remaining) + { + *puSize = remaining; + } + + if (puPriorSize != NULL && *puPriorSize > offsetInAvailable) + { + *puPriorSize = offsetInAvailable; + } + + return dataAddress; +} + +void RChannelReaderBuffer::Current:: + SetTailBufferData(DryadLockedMemoryBuffer** bufferArray, + size_t nBuffers, + Size_t finalEndOffset) +{ + Size_t bSize = bufferArray[m_currentBuffer]->GetAvailableSize(); + + if (m_currentBuffer == nBuffers-1) + { + LogAssert(finalEndOffset >= m_currentHeadCutLength); + LogAssert(finalEndOffset <= bSize); + + m_currentTailOffset += (finalEndOffset - m_currentHeadCutLength); + } + else + { + m_currentTailOffset += (bSize - m_currentHeadCutLength); + } +} 
+ + +RChannelReaderBuffer::RChannelReaderBuffer(DryadLockedBufferList* bufferList, + Size_t startOffset, + Size_t endOffset) +{ + m_nBuffers = bufferList->CountLinks(); + LogAssert(m_nBuffers > 0); + m_bufferArray = new DryadLockedMemoryBuffer *[m_nBuffers]; + + LogAssert(m_uAllocatedSize == 0); + size_t i; + DrBListEntry* listEntry = bufferList->GetHead(); + for (i=0; i < m_nBuffers; ++i) + { + LogAssert(listEntry != NULL); + m_bufferArray[i] = bufferList->CastOut(listEntry); + m_bufferArray[i]->IncRef(); + m_uAllocatedSize += m_bufferArray[i]->GetAvailableSize(); + listEntry = bufferList->GetNext(listEntry); + } + LogAssert(listEntry == NULL); + + Initialise(startOffset, endOffset); +} + +RChannelReaderBuffer::RChannelReaderBuffer(ChannelDataBufferList* bufferList, + Size_t startOffset, + Size_t endOffset) +{ + m_nBuffers = bufferList->CountLinks(); + LogAssert(m_nBuffers > 0); + m_bufferArray = new DryadLockedMemoryBuffer *[m_nBuffers]; + + LogAssert(m_uAllocatedSize == 0); + size_t i; + DrBListEntry* listEntry = bufferList->GetHead(); + for (i=0; i < m_nBuffers; ++i) + { + LogAssert(listEntry != NULL); + m_bufferArray[i] = (bufferList->CastOut(listEntry))->GetData(); + m_bufferArray[i]->IncRef(); + m_uAllocatedSize += m_bufferArray[i]->GetAvailableSize(); + listEntry = bufferList->GetNext(listEntry); + } + LogAssert(listEntry == NULL); + + Initialise(startOffset, endOffset); +} + +void RChannelReaderBuffer::Initialise(Size_t startOffset, Size_t endOffset) +{ + m_fIsGrowable = false; + + LogAssert(startOffset <= m_bufferArray[0]->GetAvailableSize()); + LogAssert(endOffset <= m_bufferArray[m_nBuffers-1]->GetAvailableSize()); + Size_t tailCutLength = + m_bufferArray[m_nBuffers-1]->GetAvailableSize() - endOffset; + LogAssert(m_uAllocatedSize >= (startOffset + tailCutLength)); + m_uAllocatedSize -= (startOffset + tailCutLength); + + SetAvailableSize(m_uAllocatedSize); + + m_initialStartOffset = startOffset; + m_finalEndOffset = endOffset; +} + 
+RChannelReaderBuffer::~RChannelReaderBuffer() +{ + size_t i; + + for (i=0; iDecRef(); + } + delete [] m_bufferArray; +} + +void* RChannelReaderBuffer::GetDataAddress(Size_t uOffset, + Size_t *puSize, + Size_t *puPriorSize) +{ + if (uOffset >= m_uAllocatedSize) + { + return NULL; + } + + Current current(m_initialStartOffset, m_finalEndOffset, + m_bufferArray, m_nBuffers, uOffset); + + void* dataAddress = current.GetDataAddress(m_bufferArray, + uOffset, + puSize, + puPriorSize); + + return dataAddress; +} + +void RChannelReaderBuffer::IncreaseAllocatedSize(Size_t uSize) +{ + LogAssert(uSize <= m_uAllocatedSize); +} + + +RChannelWriterBuffer:: + RChannelWriterBuffer(RChannelBufferWriter* bufferProvider, + DryadFixedBufferList* bufferList) +{ + m_bufferProvider = bufferProvider; + m_bufferList = bufferList; + if (m_bufferList->IsEmpty() == false) + { + LogAssert(m_bufferList->CountLinks() == 1); + DryadFixedMemoryBuffer* currentBuffer = + m_bufferList->CastOut(m_bufferList->GetHead()); + + LogAssert(currentBuffer->GetAllocatedSize() > + currentBuffer->GetAvailableSize()); + + m_baseBufferOffset = currentBuffer->GetAvailableSize(); + m_uAllocatedSize = + currentBuffer->GetAllocatedSize() - m_baseBufferOffset; + } + else + { + m_baseBufferOffset = 0; + m_uAllocatedSize = 0; + } + + m_currentBufferOffset = 0; + m_currentBaseOffset = m_baseBufferOffset; + + m_availableHighWaterMark = 0; + m_availableStartOffset = m_baseBufferOffset; + m_availableBufferOffset = 0; + m_lastAvailableBuffer = m_bufferList->GetHead(); +} + +RChannelWriterBuffer::~RChannelWriterBuffer() +{ +} + +void* RChannelWriterBuffer::GetDataAddress(Size_t uOffset, + Size_t *puSize, + Size_t *puPriorSize) +{ + if (m_bufferList->IsEmpty()) + { + return NULL; + } + else + { + LogAssert(uOffset >= m_currentBufferOffset); + + Size_t offsetInAvailable = uOffset - m_currentBufferOffset; + Size_t offsetInBuffer = offsetInAvailable + m_currentBaseOffset; + + DryadFixedMemoryBuffer* buffer = + 
m_bufferList->CastOut(m_bufferList->GetTail()); + void* dataAddress = + buffer->GetDataAddress(offsetInBuffer, puSize, puPriorSize); + + if (puPriorSize != NULL && *puPriorSize > offsetInAvailable) + { + *puPriorSize = offsetInAvailable; + } + + return dataAddress; + } +} + +void RChannelWriterBuffer::IncreaseAllocatedSize(Size_t uSize) +{ + while (m_uAllocatedSize < uSize) + { + DryadFixedMemoryBuffer* newBuffer = + m_bufferProvider->GetNextWriteBuffer(); + m_bufferList->InsertAsTail(m_bufferList->CastIn(newBuffer)); + m_currentBufferOffset = m_uAllocatedSize; + m_uAllocatedSize += newBuffer->GetAllocatedSize(); + } + + /* this is zero for all but (optionally) the first buffer in the + list */ + m_currentBaseOffset = 0; +} + +void RChannelWriterBuffer::InternalSetAvailableSize(Size_t uSize) +{ + DrMemoryBuffer::InternalSetAvailableSize(uSize); + + LogAssert(uSize >= m_availableHighWaterMark); + + if (m_lastAvailableBuffer == NULL) + { + LogAssert(m_availableBufferOffset == 0); + LogAssert(m_availableHighWaterMark == 0); + LogAssert(m_availableStartOffset == m_baseBufferOffset); + + if (m_bufferList->IsEmpty()) + { + LogAssert(uSize == 0); + return; + } + + m_lastAvailableBuffer = m_bufferList->GetHead(); + } + + DrBListEntry* listEntry = m_lastAvailableBuffer; + if (listEntry == NULL) + { + listEntry = m_bufferList->GetHead(); + LogAssert(listEntry != NULL); + } + + do + { + DryadFixedMemoryBuffer* b = m_bufferList->CastOut(listEntry); + listEntry = m_bufferList->GetNext(listEntry); + + if (listEntry == NULL) + { + LogAssert(uSize >= m_availableBufferOffset); + Size_t thisAvailable = uSize - m_availableBufferOffset; + + LogAssert(m_availableStartOffset + thisAvailable <= + b->GetAllocatedSize()); + b->SetAvailableSize(m_availableStartOffset + thisAvailable); + } + else + { + b->SetAvailableSize(b->GetAllocatedSize()); + m_availableBufferOffset += + b->GetAllocatedSize() - m_availableStartOffset; + LogAssert(m_availableBufferOffset < uSize); + m_lastAvailableBuffer 
= listEntry; + m_availableStartOffset = 0; + } + } while (listEntry != NULL); + + m_availableHighWaterMark = uSize; +} + +ChannelMemoryBufferWriter:: + ChannelMemoryBufferWriter(DrMemoryBuffer* writeBuffer, + DryadFixedBufferList* bufferList) : + DrMemoryBufferWriter(writeBuffer) +{ + m_bufferList = bufferList; + + if (m_bufferList->IsEmpty()) + { + m_initialBoundary = 0; + } + else + { + LogAssert(m_bufferList->CountLinks() == 1); + DryadFixedMemoryBuffer* buffer = + m_bufferList->CastOut(m_bufferList->GetHead()); + m_initialBoundary = buffer->GetAvailableSize(); + LogAssert(m_initialBoundary < buffer->GetAllocatedSize()); + } + + m_lastRecordBoundary = 0; +} + +bool ChannelMemoryBufferWriter::MarkRecordBoundary() +{ + if (m_bufferList->IsEmpty()) + { + LogAssert(m_initialBoundary == 0); + LogAssert(m_lastRecordBoundary == 0); + return false; + } + else if (m_bufferList->CountLinks() == 1) + { + DryadFixedMemoryBuffer* buffer = + m_bufferList->CastOut(m_bufferList->GetHead()); + Size_t offset = GetBufferOffset(); + Size_t allocated = buffer->GetAllocatedSize(); + + if (m_initialBoundary + offset >= allocated) + { + DrError errTmp = FlushMemoryWriter(); + // sammck: should this assert here or return an error + LogAssert(errTmp == DrError_OK); + } + + Size_t boundary = buffer->GetAvailableSize(); + if (boundary == allocated) + { + return true; + } + else + { + m_lastRecordBoundary = offset; + LogAssert(m_initialBoundary + m_lastRecordBoundary < allocated); + return false; + } + } + else + { + return true; + } +} + +Size_t ChannelMemoryBufferWriter::GetLastRecordBoundary() +{ + return m_initialBoundary + m_lastRecordBoundary; +} + diff --git a/DryadVertex/VertexHost/system/channel/src/recorditem.cpp b/DryadVertex/VertexHost/system/channel/src/recorditem.cpp new file mode 100644 index 0000000..1f8cda7 --- /dev/null +++ b/DryadVertex/VertexHost/system/channel/src/recorditem.cpp @@ -0,0 +1,1140 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "recorditem.h" + +#pragma unmanaged + +RecordArrayBase::RecordArrayBase() +{ + m_recordArray = NULL; + m_recordArraySize = 0; + m_numberOfRecords = 0; + m_nextRecord = 0; + m_serializedAny = 0; +} + +RecordArrayBase::~RecordArrayBase() +{ + LogAssert(m_buffer == NULL); + LogAssert(m_recordArray == NULL); +} + +void RecordArrayBase::FreeStorage() +{ + m_buffer = NULL; + if (m_recordArraySize > 0) + { + FreeTypedArray(m_recordArray); + } + m_recordArray = NULL; + m_recordArraySize = 0; + m_numberOfRecords = 0; + m_nextRecord = 0; +} + +void RecordArrayBase::PopRecord() +{ + LogAssert(m_nextRecord > 0); + --m_nextRecord; +} + +UInt32 RecordArrayBase::GetNumberOfRecords() const +{ + return m_numberOfRecords; +} + +UInt64 RecordArrayBase::GetNumberOfSubItems() const +{ + return GetNumberOfRecords(); +} + +void RecordArrayBase::TruncateSubItems(UInt64 numberOfSubItems) +{ + LogAssert(numberOfSubItems < (UInt64) m_numberOfRecords); + SetRecordIndex((UInt32) numberOfSubItems); + Truncate(); + ResetRecordPointer(); +} + +UInt64 RecordArrayBase::GetItemSize() const +{ + return (UInt64) GetRecordSize() * (UInt64) GetNumberOfRecords(); +} + +void* RecordArrayBase::GetRecordArrayUntyped() +{ + return m_recordArray; +} + +void RecordArrayBase::SetNumberOfRecords(UInt32 numberOfRecords) +{ + if (m_recordArraySize < numberOfRecords) + { + if (m_recordArraySize > 0) + { + 
LogAssert(m_buffer == NULL); + FreeTypedArray(m_recordArray); + } + + m_recordArraySize = numberOfRecords; + if (m_recordArraySize > 0) + { + m_recordArray = MakeTypedArray(m_recordArraySize); + } + else + { + m_recordArray = NULL; + } + } + else if (m_recordArraySize == 0) + { + LogAssert(m_buffer != NULL); + LogAssert(m_recordArray != NULL); + m_recordArray = NULL; + } + + m_buffer = NULL; + + m_numberOfRecords = numberOfRecords; + ResetRecordPointer(); +} + +void RecordArrayBase::ResetRecordPointer() +{ + m_nextRecord = 0; +} + +void* RecordArrayBase::GetRecordUntyped(UInt32 index) +{ + LogAssert(index < m_numberOfRecords); + char* cPtr = (char *) m_recordArray; + return &(cPtr[GetRecordSize() * index]); +} + +// +// Return pointer to next record in array +// +void* RecordArrayBase::NextRecordUntyped() +{ + if (m_nextRecord == m_numberOfRecords) + { + // + // If no more records, return null + // + return NULL; + } + else + { + // + // If valid record index, set return value to correct + // offset from array start and increment "next record" index + // + LogAssert(m_nextRecord < m_numberOfRecords); + char* cPtr = (char *) m_recordArray; + void* r = &(cPtr[GetRecordSize() * m_nextRecord]); + ++m_nextRecord; + return r; + } +} + +// +// Return the next record index +// +UInt32 RecordArrayBase::GetRecordIndex() const +{ + return m_nextRecord; +} + +void RecordArrayBase::SetRecordIndex(UInt32 index) +{ + LogAssert(index <= m_numberOfRecords); + m_nextRecord = index; +} + +// +// Return whether there are any additional records available +// +bool RecordArrayBase::AtEnd() +{ + return (m_nextRecord == m_numberOfRecords); +} + +void RecordArrayBase::Truncate() +{ + if (m_nextRecord*2 < m_recordArraySize) + { + void* newArray = TransferTruncatedArray(m_recordArray, m_nextRecord); + FreeTypedArray(m_recordArray); + m_recordArray = newArray; + m_recordArraySize = m_nextRecord; + } + + m_numberOfRecords = m_nextRecord; +} + +void RecordArrayBase::TransferRecord(void* 
dstRecord, void* srcRecord) +{ + TransferRecord(NULL, dstRecord, NULL, srcRecord); +} + +void* RecordArrayBase::TransferTruncatedArray(void* srcArrayUntyped, + UInt32 numberOfRecords) +{ + if (numberOfRecords == 0) + { + return NULL; + } + + void* dstArrayUntyped = MakeTypedArray(numberOfRecords); + LogAssert(dstArrayUntyped != NULL); + + char* srcPtr = (char *) srcArrayUntyped; + char* dstPtr = (char *) dstArrayUntyped; + UInt32 i; + for (i=0; iGetAvailableSize(); + LogAssert(availableSize >= startOffset); + + Size_t remainingSize = buffer->GetAvailableSize() - startOffset; + if (remainingSize < neededSize) + { + SetNumberOfRecords(0); + return 0; + } + else + { + SetNumberOfRecords(nRecords); + buffer->Read(startOffset, m_recordArray, neededSize); + return (UInt32) neededSize; + } +} + +DrError RecordArrayBase::ReadFinalArray(DrMemoryBuffer* buffer) +{ + LogAssert(buffer->GetAvailableSize() < GetRecordSize()); + return DrError_EndOfStream; +} + +UInt32 RecordArrayBase::AttachArray(DryadLockedMemoryBuffer* buffer, + Size_t startOffset) +{ + Size_t remainingSize; + void* recordArray = (void *) buffer->GetReadAddress(startOffset, + &remainingSize); + LogAssert(remainingSize == buffer->GetAvailableSize() - startOffset); + + UInt32 numberOfRecords = (UInt32) (remainingSize / GetRecordSize()); + if (numberOfRecords == 0) + { + SetNumberOfRecords(0); + return 0; + } + else + { + FreeStorage(); + m_buffer = buffer; + m_recordArray = recordArray; + m_numberOfRecords = numberOfRecords; + return (UInt32) (m_numberOfRecords * GetRecordSize()); + } +} + +DrError RecordArrayBase::DeSerialize(DrResettableMemoryReader* reader, + Size_t availableSize) +{ + UInt32 numberOfRecords = (UInt32) (availableSize / GetRecordSize()); + if (numberOfRecords < 1) + { + return DrError_EndOfStream; + } + + if (GetNumberOfRecords() == 0) + { + SetNumberOfRecords(numberOfRecords); + } + + if (numberOfRecords >= GetNumberOfRecords()) + { + numberOfRecords = GetNumberOfRecords(); + } + else + { + 
SetNumberOfRecords(numberOfRecords); + } + + Size_t dataLength = (Size_t) (numberOfRecords * GetRecordSize()); + DrError err = reader->ReadBytes((BYTE *) GetRecordArrayUntyped(), + dataLength); + LogAssert(err == DrError_OK); + + return DrError_OK; +} + +void RecordArrayBase::StartSerializing() +{ + if (m_serializedAny == false) + { + ResetRecordPointer(); + m_serializedAny = true; + } +} + +DrError RecordArrayBase::Serialize(ChannelMemoryBufferWriter* writer) +{ + StartSerializing(); + + void* nextRecord; + bool filledBuffer = writer->MarkRecordBoundary(); + LogAssert(filledBuffer == false); + while (filledBuffer == false && + (nextRecord = NextRecordUntyped()) != NULL) + { + writer->WriteBytes((BYTE *) nextRecord, GetRecordSize()); + filledBuffer = writer->MarkRecordBoundary(); + } + + if (filledBuffer) + { + return DrError_IncompleteOperation; + } + else + { + return DrError_OK; + } +} + +PackedRecordArrayParserBase:: + PackedRecordArrayParserBase(DObjFactoryBase* factory) +{ + m_factory = factory; +} + +PackedRecordArrayParserBase::~PackedRecordArrayParserBase() +{ +} + +RChannelItem* PackedRecordArrayParserBase:: + ParseNextItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + Size_t* pOutLength) +{ + LogAssert(bufferList->IsEmpty() == false); + + RecordArrayBase *item = + (RecordArrayBase *) m_factory->AllocateObjectUntyped(); + + RChannelBufferData* headBuffer = + bufferList->CastOut(bufferList->GetHead()); + + if (bufferList->GetHead() != bufferList->GetTail()) + { + RChannelBufferData* tailBuffer = + bufferList->CastOut(bufferList->GetTail()); + Size_t tailBufferSize = + tailBuffer->GetData()->GetAvailableSize(); + + DrRef combinedBuffer; + combinedBuffer.Attach(new RChannelReaderBuffer(bufferList, + startOffset, + tailBufferSize)); + + *pOutLength = item->ReadArray(combinedBuffer, 0, 1); + } + else + { + *pOutLength = item->AttachArray(headBuffer->GetData(), startOffset); + } + + if (*pOutLength == 0) + { + m_factory->FreeObjectUntyped(item); + 
item = NULL; + } + + return item; +} + +RChannelItem* PackedRecordArrayParserBase:: + ParsePartialItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + RChannelBufferMarker* + markerBuffer) +{ + if (bufferList->IsEmpty()) + { + return NULL; + } + + RChannelBufferData* tailBuffer = + bufferList->CastOut(bufferList->GetTail()); + Size_t tailBufferSize = + tailBuffer->GetData()->GetAvailableSize(); + + DrRef combinedBuffer; + combinedBuffer.Attach(new RChannelReaderBuffer(bufferList, + startOffset, + tailBufferSize)); + + RecordArrayBase *item = + (RecordArrayBase *) m_factory->AllocateObjectUntyped(); + + DrError err = item->ReadFinalArray(combinedBuffer); + + if (err == DrError_OK) + { + return item; + } + + m_factory->FreeObjectUntyped(item); + + if (err == DrError_EndOfStream) + { + return NULL; + } + else + { + return RChannelMarkerItem::CreateErrorItem(RChannelItem_ParseError, + err); + } +} + +PackedRecordArrayParser::PackedRecordArrayParser(DObjFactoryBase* factory) : + PackedRecordArrayParserBase(factory) +{ +} + +PackedRecordArrayParser::~PackedRecordArrayParser() +{ +} + + +/* + * Default constructor. Every member that the destructor or the + * accessors read is given a defined value here, so an object on which + * Initialize() is never called can still be destroyed safely + * (delete [] on NULL is a no-op). + */ +RecordArrayReaderBase::RecordArrayReaderBase() +{ + m_reader = NULL; + m_arrayItem = NULL; + m_cacheSize = 0; + m_currentRecord = NULL; + m_itemCache = NULL; + m_valid = 0; + m_cachedItemCount = 0; +} + +/* + * Construct a reader that is already attached to 'reader'. The array + * pointers are cleared first because Initialize() releases whatever + * they currently hold. + */ +RecordArrayReaderBase::RecordArrayReaderBase(SyncItemReaderBase* reader) +{ + m_currentRecord = NULL; + m_itemCache = NULL; + Initialize(reader); +} + +/* + * Frees the record-pointer array and the item cache. + */ +RecordArrayReaderBase::~RecordArrayReaderBase() +{ + delete [] m_currentRecord; + delete [] m_itemCache; +} + +/* + * Attach to 'reader' and reset the caches to their default size of 32 + * slots. Arrays left over from an earlier call are freed first, so + * calling Initialize() more than once does not leak them. + */ +void RecordArrayReaderBase::Initialize(SyncItemReaderBase* reader) +{ + delete [] m_currentRecord; + delete [] m_itemCache; + m_reader = reader; + m_arrayItem = NULL; + m_cacheSize = 32; + m_currentRecord = new void* [m_cacheSize]; + m_itemCache = new RChannelItemRef [m_cacheSize]; + m_valid = 0; + m_cachedItemCount = 0; +} + +void RecordArrayReaderBase::DiscardCachedItems() +{ + UInt32 i; + for (i=0; iReadItemSync(&m_item); + LogAssert(status == DrError_OK); + LogAssert(m_item != NULL); + RChannelItemType itemType = m_item->GetType(); + if (itemType == RChannelItem_Data) + { + m_arrayItem = (RecordArrayBase *) (m_item.Ptr()); + } + else if 
(RChannelItem::IsTerminationItem(itemType) == false) + { + /* this is a marker; skip it and fetch the next one */ + m_item = NULL; + } + } + + if (m_arrayItem != NULL) + { + m_currentRecord[slotNumber] = m_arrayItem->NextRecordUntyped(); + LogAssert(m_currentRecord[slotNumber] != NULL); + if (m_arrayItem->AtEnd()) + { + m_itemCache[m_cachedItemCount].TransferFrom(m_item); + LogAssert(m_item == NULL); + m_arrayItem = NULL; + ++m_cachedItemCount; + } + return true; + } + else + { + return false; + } +} + +void RecordArrayReaderBase::PushBack() +{ + PushBack(true); +} + +void RecordArrayReaderBase::PushBack(bool pushValid) +{ + if (pushValid) + { + LogAssert(m_valid > 0); + } + + if (m_arrayItem == NULL) + { + LogAssert(m_item == NULL); + LogAssert(m_cachedItemCount > 0); + --m_cachedItemCount; + m_item.TransferFrom(m_itemCache[m_cachedItemCount]); + m_arrayItem = (RecordArrayBase *) (m_item.Ptr()); + } + + if (m_arrayItem->GetRecordIndex() == 0) + { + DrLogA("Can't push back more than one record"); + } + + m_arrayItem->PopRecord(); + if (pushValid) + { + --m_valid; + } +} + +bool RecordArrayReaderBase::Advance() +{ + DiscardCachedItems(); + if (AdvanceInternal(0)) + { + m_valid = 1; + return true; + } + else + { + m_valid = 0; + return false; + } +} + +UInt32 RecordArrayReaderBase::AdvanceBlock(UInt32 validEntriesRequested) +{ + DiscardCachedItems(); + if (validEntriesRequested > m_cacheSize) + { + delete [] m_currentRecord; + delete [] m_itemCache; + m_cacheSize = validEntriesRequested; + m_currentRecord = new void* [m_cacheSize]; + m_itemCache = new RChannelItemRef [m_cacheSize]; + } + + UInt32 i; + for (i=0; iGetErrorFromItem(); + } +} + +RChannelItem* RecordArrayReaderBase::GetTerminationItem() +{ + if (m_item == NULL || + RChannelItem::IsTerminationItem(m_item->GetType()) == false) + { + return NULL; + } + else + { + return m_item; + } +} + +UInt32 RecordArrayReaderBase::GetValidCount() const +{ + return m_valid; +} + + 
+RecordArrayWriterBase::RecordArrayWriterBase() +{ + m_currentRecord = NULL; + m_itemCache = NULL; + m_cachedItemCount = 0; +} + +RecordArrayWriterBase::RecordArrayWriterBase(SyncItemWriterBase* writer, + DObjFactoryBase* factory) +{ + m_currentRecord = NULL; + m_itemCache = NULL; + m_cachedItemCount = 0; + Initialize(writer, factory); +} + +RecordArrayWriterBase::~RecordArrayWriterBase() +{ + Destroy(); +} + +void RecordArrayWriterBase::Destroy() +{ + Flush(); + delete [] m_currentRecord; + delete [] m_itemCache; +} + +void RecordArrayWriterBase::Initialize(SyncItemWriterBase* writer, + DObjFactoryBase* factory) +{ + Destroy(); + m_factory = factory; + m_writer = writer; + m_cacheSize = 32; + m_currentRecord = new void* [m_cacheSize]; + m_itemCache = new DrRef [m_cacheSize]; + m_valid = 0; + m_pushBackIndex = 0; + m_cachedItemCount = 0; +} + +void RecordArrayWriterBase::SetWriter(SyncItemWriterBase* writer) +{ + m_writer = writer; +} + +DrError RecordArrayWriterBase::GetWriterStatus() +{ + return m_writer->GetWriterStatus(); +} + +// +// Write out any cached items +// +void RecordArrayWriterBase::SendCachedItems() +{ + UInt32 i; + for (i=0; iResetRecordPointer(); + m_writer->WriteItemSync(m_itemCache[i]); + m_itemCache[i] = NULL; + } + + m_cachedItemCount = 0; +} + +// +// Move to next available object if one is available. +// +void RecordArrayWriterBase::AdvanceInternal(UInt32 slotNumber) +{ + if (m_item == NULL) + { + // + // If no current record array, generate one + // + RecordArrayBase *item = (RecordArrayBase *) + m_factory->AllocateObjectUntyped(); + m_item.Attach(item); + } + + // + // Set current record to the next available record + // + m_currentRecord[slotNumber] = m_item->NextRecordUntyped(); + LogAssert(m_currentRecord[slotNumber] != NULL); + + // + // If there are no more records in this record array, transfer the + // record array into the item cache to remember that all items are cached. 
+ // + if (m_item->AtEnd()) + { + m_itemCache[m_cachedItemCount].TransferFrom(m_item); + LogAssert(m_item == NULL); + ++m_cachedItemCount; + } +} + +// +// +// +void RecordArrayWriterBase::MakeValid() +{ + // + // Clear out any cached items to reset + // + SendCachedItems(); + + // + // Get the index of the next available record + // + if (m_item == NULL) + { + m_pushBackIndex = 0; + } + else + { + m_pushBackIndex = m_item->GetRecordIndex(); + } + + // + // Move to next available object if one is available and mark as valid + // + AdvanceInternal(0); + m_valid = 1; +} + +void RecordArrayWriterBase::MakeValidBlock(UInt32 validEntriesRequested) +{ + SendCachedItems(); + + if (m_item == NULL) + { + m_pushBackIndex = 0; + } + else + { + m_pushBackIndex = m_item->GetRecordIndex(); + } + + if (validEntriesRequested > m_cacheSize) + { + delete [] m_currentRecord; + delete [] m_itemCache; + m_cacheSize = validEntriesRequested; + m_currentRecord = new void* [m_cacheSize]; + m_itemCache = new DrRef [m_cacheSize]; + } + + UInt32 i; + for (i=0; i 0); + if (m_cachedItemCount > 0) + { + /* throw away back to the item we were using when we started */ + m_item.TransferFrom(m_itemCache[0]); + UInt32 i; + for (i=1; iSetRecordIndex(m_pushBackIndex); + m_pushBackIndex = 0; + m_valid = 0; +} + +void* RecordArrayWriterBase::ReadValidUntyped(UInt32 index) +{ + LogAssert(index < m_valid); + return m_currentRecord[index]; +} + +void RecordArrayWriterBase::Flush() +{ + if (m_item != NULL) + { + m_item->Truncate(); + m_itemCache[m_cachedItemCount].TransferFrom(m_item); + LogAssert(m_item == NULL); + ++m_cachedItemCount; + } + SendCachedItems(); + m_valid = 0; +} + +void RecordArrayWriterBase::Terminate() +{ + Flush(); + RChannelItemRef item; + item.Attach(RChannelMarkerItem::Create(RChannelItem_EndOfStream, false)); + m_writer->WriteItemSync(item); +} + +AlternativeRecordParserBase:: + AlternativeRecordParserBase(DObjFactoryBase* factory) +{ + m_factory = factory; + m_function = NULL; +} + 
+AlternativeRecordParserBase:: + AlternativeRecordParserBase(DObjFactoryBase* factory, + RecordDeSerializerFunction* function) +{ + m_factory = factory; + m_function = function; +} + +AlternativeRecordParserBase::~AlternativeRecordParserBase() +{ +} + +void AlternativeRecordParserBase::ResetParser() +{ + m_pendingErrorItem = NULL; +} + +DrError AlternativeRecordParserBase:: + DeSerializeArray(RecordArrayBase* array, + DrResettableMemoryReader* reader, + Size_t availableSize) +{ + if (array->GetNumberOfRecords() == 0) + { + array->SetNumberOfRecords(RChannelItem::s_defaultRecordBatchSize); + } + + Size_t sizeUsed = 0; + Size_t remainingSize = availableSize; + DrError err = DrError_OK; + + array->ResetRecordPointer(); + void* nextRecord; + while (remainingSize > 0 && + (nextRecord = array->NextRecordUntyped()) != NULL) + { + err = DeSerializeUntyped(nextRecord, reader, remainingSize, false); + if (err != DrError_OK) + { + array->PopRecord(); + reader->ResetToBufferOffset(sizeUsed); + break; + } + + sizeUsed = reader->GetBufferOffset(); + LogAssert(sizeUsed <= availableSize); + remainingSize = availableSize - sizeUsed; + } + + array->Truncate(); + array->ResetRecordPointer(); + + if (err == DrError_EndOfStream && array->AtEnd() == false) + { + /* AtEnd() == false after Truncate+Reset means there's at + least one item */ + err = DrError_OK; + } + + return err; +} + +RChannelItem* AlternativeRecordParserBase:: + ParseNextItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + Size_t* pOutLength) +{ + if (m_pendingErrorItem != NULL) + { + return m_pendingErrorItem.Detach(); + } + + LogAssert(bufferList->IsEmpty() == false); + + RChannelBufferData* tailBuffer = + bufferList->CastOut(bufferList->GetTail()); + Size_t tailBufferSize = + tailBuffer->GetData()->GetAvailableSize(); + + DrRef buffer; + buffer.Attach(new RChannelReaderBuffer(bufferList, + startOffset, + tailBufferSize)); + + DrResettableMemoryReader reader(buffer); + + RecordArrayBase* recordArray = + 
(RecordArrayBase *) m_factory->AllocateObjectUntyped(); + DrError err = DeSerializeArray(recordArray, &reader, + buffer->GetAvailableSize()); + + if (err == DrError_OK) + { + *pOutLength = reader.GetBufferOffset(); + return recordArray; + } + + m_factory->FreeObjectUntyped(recordArray); + + if (err == DrError_EndOfStream) + { + return NULL; + } + else + { + return RChannelMarkerItem::CreateErrorItem(RChannelItem_ParseError, + err); + } +} + +RChannelItem* AlternativeRecordParserBase:: + ParsePartialItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + RChannelBufferMarker* + markerBuffer) +{ + if (m_pendingErrorItem != NULL) + { + return m_pendingErrorItem.Detach(); + } + + if (bufferList->IsEmpty()) + { + return NULL; + } + + RChannelBufferData* tailBuffer = + bufferList->CastOut(bufferList->GetTail()); + Size_t tailBufferSize = + tailBuffer->GetData()->GetAvailableSize(); + + DrRef buffer; + buffer.Attach(new RChannelReaderBuffer(bufferList, + startOffset, + tailBufferSize)); + + DrResettableMemoryReader reader(buffer); + + RecordArrayBase* recordArray = + (RecordArrayBase *) m_factory->AllocateObjectUntyped(); + recordArray->SetNumberOfRecords(1); + void* nextRecord = recordArray->NextRecordUntyped(); + DrError err = DeSerializeUntyped(nextRecord, &reader, + buffer->GetAvailableSize(), true); + + if (err == DrError_OK) + { + recordArray->ResetRecordPointer(); + return recordArray; + } + + m_factory->FreeObjectUntyped(recordArray); + + if (err == DrError_EndOfStream) + { + return NULL; + } + else + { + return RChannelMarkerItem::CreateErrorItem(RChannelItem_ParseError, + err); + } +} + +DrError AlternativeRecordParserBase:: + DeSerializeUntyped(void* record, + DrMemoryBufferReader* reader, + Size_t availableSize, bool lastRecordInStream) +{ + return (*m_function)(record, reader, availableSize, lastRecordInStream); +} + +UntypedAlternativeRecordParser:: + UntypedAlternativeRecordParser(DObjFactoryBase* factory, + RecordDeSerializerFunction* function) : + 
AlternativeRecordParserBase(factory, function) +{ +} + +StdAlternativeRecordParserFactory:: + StdAlternativeRecordParserFactory(DObjFactoryBase* factory, + RecordDeSerializerFunction* function) +{ + m_factory = factory; + m_function = function; +} + +void StdAlternativeRecordParserFactory:: + MakeParser(RChannelItemParserRef* pParser, + DVErrorReporter* errorReporter) +{ + pParser->Attach(new UntypedAlternativeRecordParser(m_factory, + m_function)); +} + + +AlternativeRecordMarshalerBase::AlternativeRecordMarshalerBase() +{ + m_function = NULL; +} + +AlternativeRecordMarshalerBase:: + AlternativeRecordMarshalerBase(RecordSerializerFunction* function) +{ + m_function = function; +} + +AlternativeRecordMarshalerBase::~AlternativeRecordMarshalerBase() +{ +} + +void AlternativeRecordMarshalerBase:: + SetFunction(RecordSerializerFunction* function) +{ + m_function = function; +} + +DrError AlternativeRecordMarshalerBase:: + MarshalItem(ChannelMemoryBufferWriter* writer, + RChannelItem* item, + bool flush, + RChannelItemRef* pFailureItem) +{ + DrError err = DrError_OK; + + if (item->GetType() != RChannelItem_Data) + { + return err; + } + + RecordArrayBase* arrayItem = (RecordArrayBase *) item; + + arrayItem->StartSerializing(); + + Size_t currentPosition = 0; + + void* nextRecord; + bool filledBuffer = writer->MarkRecordBoundary(); + LogAssert(filledBuffer == false); + while (err == DrError_OK && + filledBuffer == false && + (nextRecord = arrayItem->NextRecordUntyped()) != NULL) + { + currentPosition = writer->GetBufferOffset(); + err = SerializeUntyped(nextRecord, writer); + if (err == DrError_OK) + { + filledBuffer = writer->MarkRecordBoundary(); + } + } + + if (err != DrError_OK) + { + LogAssert(writer->GetStatus() == DrError_OK); + writer->SetBufferOffset(currentPosition); + + pFailureItem->Attach(RChannelMarkerItem:: + CreateErrorItem(RChannelItem_MarshalError, + err)); + return DryadError_ChannelAbort; + } + + if (filledBuffer) + { + return 
DrError_IncompleteOperation; + } + + return DrError_OK; +} + +DrError AlternativeRecordMarshalerBase:: + SerializeUntyped(void* record, + DrMemoryBufferWriter* writer) +{ + return (*m_function)(record, writer); +} + +UntypedAlternativeRecordMarshaler:: + UntypedAlternativeRecordMarshaler(RecordSerializerFunction* function) : + AlternativeRecordMarshalerBase(function) +{ +} + +StdAlternativeRecordMarshalerFactory:: + StdAlternativeRecordMarshalerFactory(RecordSerializerFunction* function) +{ + m_function = function; +} + +void StdAlternativeRecordMarshalerFactory:: + MakeMarshaler(RChannelItemMarshalerRef* pMarshaler, + DVErrorReporter* errorReporter) +{ + pMarshaler->Attach(new UntypedAlternativeRecordMarshaler(m_function)); +} diff --git a/DryadVertex/VertexHost/system/classlib/classlib.vcxproj b/DryadVertex/VertexHost/system/classlib/classlib.vcxproj new file mode 100644 index 0000000..08e7ef5 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/classlib.vcxproj @@ -0,0 +1,205 @@ + + + + + Debug + Win32 + + + Debug + x64 + + + Release + Win32 + + + Release + x64 + + + + {016E71D3-9A6F-425C-AB4F-8C5EDEFFE7FA} + classlib + Win32Proj + + + + StaticLibrary + + + StaticLibrary + + + StaticLibrary + Unicode + false + + + StaticLibrary + Unicode + false + + + + + + + + + + + + + + + + + + + <_ProjectFileVersion>10.0.40219.1 + Debug\ + Debug\ + $(Platform)\$(Configuration)\ + $(Platform)\$(Configuration)\ + Release\ + Release\ + $(Platform)\$(Configuration)\ + $(Platform)\$(Configuration)\ + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + + + + Disabled + WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions) + true + EnableFastChecks + MultiThreadedDebugDLL + + + Level3 + EditAndContinue + + + + + X64 + + + Disabled + ..\..\common\include;..\..\classlib\include;include;%(AdditionalIncludeDirectories) + WIN32;_DEBUG;_LIB;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + false + Default + MultiThreadedDebugDLL + + + Level3 + 
ProgramDatabase + + + + + WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + + + + + X64 + + + WIN32;NDEBUG;_LIB;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + ..\..\common\include;..\..\classlib\include;include;%(AdditionalIncludeDirectories) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/DryadVertex/VertexHost/system/classlib/include/DrBList.h b/DryadVertex/VertexHost/system/classlib/include/DrBList.h new file mode 100644 index 0000000..4eee8da --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrBList.h @@ -0,0 +1,411 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +/* + * DrBList + * + * Declares a utility class for managing bi-linked listed + */ + +#ifndef __DRYADBLIST_H__ +#define __DRYADBLIST_H__ + + +/* + * DrBListEntry + * The list entry class is intentionally opaque. Operations should always + * be done using the list object. 
The list entry class should be embedded + * in the object you want to store, and the DR_GET_CONTAINER macro + * used to map back to the actual object + */ + +class DrBListEntry +{ +public: + + //Standard c'tor + //A new entry is linked to itself, which is the "not in a list" state + inline DrBListEntry(); + + //Return TRUE if this entry is in a list. + //N.B. Remember to use the locking conventions of the list object! + inline BOOL IsInList(); + + //Remove this entry from whatever list it's currently in + //Asserts if this entry is not in a list + //N.B. Remember to use the locking conventions of the list object! + inline void Remove(); + +private: + friend class DrBList; + DrBListEntry * m_pNext, * m_pPrev; +}; + +/* + * DrBList + * Manages a bi-linked list of DrBListEntry's + */ + +class DrBList +{ +public: + + /* + * Lifecycle Management + */ + + //Standard c'tor + inline DrBList(); + + + /* + * Inserting entries to list + */ + + //Insert specified entry as head of list + //Asserts if pEntry is already in an existing list + inline void InsertAsHead(DrBListEntry * pEntry); + + //Insert specified entry as tail of list + //Asserts if pEntry is already in an existing list + inline void InsertAsTail(DrBListEntry * pEntry); + + //Insert pNewEntry as the next entry after pCurrentEntry + //Asserts if pCurrentEntry is not in a list + //Asserts if pNewEntry is already in an existing list + inline void InsertAsNext(DrBListEntry * pCurrentEntry, DrBListEntry * pNewEntry); + + //Insert pNewEntry as the previous entry before pCurrentEntry + //Asserts if pCurrentEntry is not in a list + //Asserts if pNewEntry is already in an existing list + inline void InsertAsPrev(DrBListEntry * pCurrentEntry, DrBListEntry * pNewEntry); + + //Remove pEntry from a list and transition it to the head of this one + //pEntry can either be in this list or an unrelated one + //Asserts if pEntry is not already in an existing list + inline void TransitionToHead(DrBListEntry * pEntry); + + //Remove pEntry from a list and transition it to the tail of this one + //pEntry can either be in this list or an unrelated one + //Asserts if pEntry is not already in an existing list + inline void 
TransitionToTail(DrBListEntry * pEntry); + + //Remove all the entries from pList and transition them to the + //head of this one. pList must not be this list + inline void TransitionToHead(DrBList * pList); + + //Remove all the entries from pList and transition them to the + //tail of this one. pList must not be this list + inline void TransitionToTail(DrBList * pList); + + + /* + * Removing entries from list + */ + + //Remove entry from head of list and return it + //Returns NULL if list is empty + inline DrBListEntry * RemoveHead(); + + //Remove entry from tail of list and return it + //Returns NULL if list is empty + inline DrBListEntry * RemoveTail(); + + //Remove specified entry from list and return it + //Asserts if pEntry is not in a list or is this list's own dummy entry + inline DrBListEntry * Remove(DrBListEntry * pEntry); + + + /* + * Accessing entries in the list + */ + + //Scan the list and if pEntry is found return it + //Returns NULL if pEntry is not found + inline DrBListEntry * Find(DrBListEntry * pEntry); + + //Return the head entry in the list or NULL if list is empty + inline DrBListEntry * GetHead(); + + //Return the tail entry in the list or NULL if list is empty + inline DrBListEntry * GetTail(); + + //Return the next entry after pEntry, or NULL if pEntry is the tail + inline DrBListEntry * GetNext(DrBListEntry * pEntry); + + //Return the previous entry before pEntry, or NULL if pEntry is the head + inline DrBListEntry * GetPrev(DrBListEntry * pEntry); + + + /* + * Retrieving state of list + */ + + //Scan the list and count the total number of entries + inline DWORD CountLinks(); + + //Return TRUE if the list is empty + inline BOOL IsEmpty(); + +private: + + //Placeholder entry that is never returned to callers: its next link + //is the head and its prev link is the tail, and it links to itself + //when the list is empty + DrBListEntry m_dummyEntry; + +}; + + + +/* + * Inline methods for DrBListEntry + */ + + +//An entry that is not in any list has both links pointing at itself +DrBListEntry::DrBListEntry() +{ + m_pNext=this; + m_pPrev=this; +} + +//In a list exactly when the next link points somewhere other than itself +BOOL DrBListEntry::IsInList() +{ + return (m_pNext!=this); +} + +//Unlink from both neighbours, then restore the self-linked state so +//that IsInList() returns FALSE afterwards +void DrBListEntry::Remove() +{ + LogAssert(IsInList()); + + m_pNext->m_pPrev = m_pPrev; + m_pPrev->m_pNext = m_pNext; + m_pNext=this; + m_pPrev=this; +} 
+ +/* + * Inline methods for DrBList + */ + + +DrBList::DrBList() +{ + //this space intentionally left blank + //(m_dummyEntry's own constructor already links it to itself) +} + +// +// Link pEntry in between the dummy entry and the current head +// +void DrBList::InsertAsHead(DrBListEntry * pEntry) +{ + LogAssert(pEntry->IsInList()==FALSE); + + pEntry->m_pNext=m_dummyEntry.m_pNext; + pEntry->m_pPrev=&m_dummyEntry; + m_dummyEntry.m_pNext->m_pPrev=pEntry; + m_dummyEntry.m_pNext=pEntry; +} + +// +// Unlink pEntry from the list it is in and relink it as head of this list +// +void DrBList::TransitionToHead(DrBListEntry * pEntry) +{ + LogAssert(pEntry->IsInList()); + + //Pull entry from existing list + pEntry->m_pNext->m_pPrev = pEntry->m_pPrev; + pEntry->m_pPrev->m_pNext = pEntry->m_pNext; + //Insert into this list + pEntry->m_pNext=m_dummyEntry.m_pNext; + pEntry->m_pPrev=&m_dummyEntry; + m_dummyEntry.m_pNext->m_pPrev=pEntry; + m_dummyEntry.m_pNext=pEntry; +} + +// +// Prepend list to existing list's head (uses m_dummyEntry to move links) +// Does nothing when pList is empty; pList is left empty afterwards +// +void DrBList::TransitionToHead(DrBList * pList) +{ + LogAssert(pList != this); + + if (!pList->IsEmpty()) + { + DrBListEntry* otherHead = pList->GetHead(); + DrBListEntry* otherTail = pList->GetTail(); + + // + // link other tail's next to this list's old head and other + // head's prev to the placeholder + // + otherTail->m_pNext=m_dummyEntry.m_pNext; + otherHead->m_pPrev=&m_dummyEntry; + + // + // Link old head's prev to other tail, then make other head the + // new head + // + m_dummyEntry.m_pNext->m_pPrev=otherTail; + m_dummyEntry.m_pNext=otherHead; + + // + // Update pList's placeholder next and previous to self + // + pList->m_dummyEntry.m_pNext=&(pList->m_dummyEntry); + pList->m_dummyEntry.m_pPrev=&(pList->m_dummyEntry); + } +} + +// +// Link pEntry in between the current tail and the dummy entry +// +void DrBList::InsertAsTail(DrBListEntry * pEntry) +{ + LogAssert(pEntry->IsInList()==FALSE); + + pEntry->m_pNext=&m_dummyEntry; + pEntry->m_pPrev=m_dummyEntry.m_pPrev; + m_dummyEntry.m_pPrev->m_pNext=pEntry; + m_dummyEntry.m_pPrev=pEntry; +} + +// +// Unlink pEntry from the list it is in and relink it as tail of this list +// +void DrBList::TransitionToTail(DrBListEntry * pEntry) +{ + LogAssert(pEntry->IsInList()); + //Pull entry from existing list + pEntry->m_pNext->m_pPrev = pEntry->m_pPrev; + pEntry->m_pPrev->m_pNext = pEntry->m_pNext; + //Insert into this list + pEntry->m_pNext=&m_dummyEntry; + pEntry->m_pPrev=m_dummyEntry.m_pPrev; + m_dummyEntry.m_pPrev->m_pNext=pEntry; + m_dummyEntry.m_pPrev=pEntry; +} + +// +// Append list to existing list's tail (uses m_dummyEntry to move links) +// +void DrBList::TransitionToTail(DrBList * pList) +{ 
+ LogAssert(pList != this); + + if (!pList->IsEmpty()) + { + DrBListEntry* otherHead = pList->GetHead(); + DrBListEntry* otherTail = pList->GetTail(); + + // + // link tail's next and head's prev to placeholder's prev + // + otherTail->m_pNext=&m_dummyEntry; + otherHead->m_pPrev=m_dummyEntry.m_pPrev; + + // + // Link placeholder's prev's next to head + // + m_dummyEntry.m_pPrev->m_pNext=otherHead; + + // + // Link placeholder's prev to tail + // + m_dummyEntry.m_pPrev=otherTail; + + // + // Update placeholder next and previous to self + // + pList->m_dummyEntry.m_pNext=&(pList->m_dummyEntry); + pList->m_dummyEntry.m_pPrev=&(pList->m_dummyEntry); + } +} + + +void DrBList::InsertAsNext(DrBListEntry * pCurrentEntry, DrBListEntry * pNewEntry) +{ + LogAssert(pCurrentEntry->IsInList()); + LogAssert(pNewEntry->IsInList()==FALSE); + + pNewEntry->m_pNext=pCurrentEntry->m_pNext; + pNewEntry->m_pPrev=pCurrentEntry; + pCurrentEntry->m_pNext->m_pPrev=pNewEntry; + pCurrentEntry->m_pNext=pNewEntry; +} + +void DrBList::InsertAsPrev(DrBListEntry * pCurrentEntry, DrBListEntry * pNewEntry) +{ + LogAssert(pCurrentEntry->IsInList()); + LogAssert(pNewEntry->IsInList()==FALSE); + + pNewEntry->m_pNext=pCurrentEntry; + pNewEntry->m_pPrev=pCurrentEntry->m_pPrev; + pCurrentEntry->m_pPrev->m_pNext=pNewEntry; + pCurrentEntry->m_pPrev=pNewEntry; +} + +DrBListEntry * DrBList::RemoveHead() +{ + if (m_dummyEntry.m_pNext==&m_dummyEntry) + return NULL; + else + return Remove(m_dummyEntry.m_pNext); +} + +DrBListEntry * DrBList::RemoveTail() +{ + if (m_dummyEntry.m_pPrev==&m_dummyEntry) + return NULL; + else + return Remove(m_dummyEntry.m_pPrev); +} + +DrBListEntry * DrBList::Remove(DrBListEntry * pEntry) +{ + LogAssert(pEntry->IsInList()); + LogAssert(pEntry!=&m_dummyEntry); + + pEntry->m_pNext->m_pPrev = pEntry->m_pPrev; + pEntry->m_pPrev->m_pNext = pEntry->m_pNext; + pEntry->m_pNext=pEntry; + pEntry->m_pPrev=pEntry; + return pEntry; +} + + +DrBListEntry * DrBList::Find(DrBListEntry * pEntry) +{ + 
LogAssert(pEntry); + DrBListEntry * pScan=m_dummyEntry.m_pNext; + while (pScan!=&m_dummyEntry) + { + if (pScan==pEntry) + return pEntry; + pScan=pScan->m_pNext; + } + return NULL; +} + +DrBListEntry * DrBList::GetHead() +{ + return (m_dummyEntry.m_pNext==&m_dummyEntry) ? NULL : m_dummyEntry.m_pNext; +} + +DrBListEntry * DrBList::GetTail() +{ + return (m_dummyEntry.m_pPrev==&m_dummyEntry) ? NULL : m_dummyEntry.m_pPrev; +} + +DrBListEntry * DrBList::GetNext(DrBListEntry * pEntry) +{ + return (pEntry->m_pNext==&m_dummyEntry) ? NULL : pEntry->m_pNext; +} + +DrBListEntry * DrBList::GetPrev(DrBListEntry * pEntry) +{ + return (pEntry->m_pPrev==&m_dummyEntry) ? NULL : pEntry->m_pPrev; +} + +// +// Count the number of links in the list and return count +// +DWORD DrBList::CountLinks() +{ + DWORD dwCount=0; + DrBListEntry * pScan=m_dummyEntry.m_pNext; + while (pScan!=&m_dummyEntry) + { + dwCount++; + pScan=pScan->m_pNext; + } + return dwCount; +} + +BOOL DrBList::IsEmpty() +{ + return (m_dummyEntry.m_pNext==&m_dummyEntry); +} + +#endif //end if not defined __DRYADBLIST_H__ diff --git a/DryadVertex/VertexHost/system/classlib/include/DrCommon.h b/DryadVertex/VertexHost/system/classlib/include/DrCommon.h new file mode 100644 index 0000000..c97156e --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrCommon.h @@ -0,0 +1,47 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#pragma warning (disable: 4100) // unreferenced formal parameter +#pragma warning (disable: 4995) // deprecated function + +#include "DrTypes.h" +#include "Dryad.h" +#include "DrErrorDef.h" +#include "DrLogging.h" +#include "DrFunctions.h" +#include "DrHash.h" +#include "MSMutex.h" +#include "DrGuid.h" +#include "DrStringUtil.h" +#include "DrExitCodesDef.h" +#include "DrRefCounter.h" +#include "DrMemory.h" +#include "DrMemoryStream.h" +#include "DrCriticalSection.h" +#include "DrHeap.h" +#include "DrThread.h" +#include "propertyids.h" +#include "DrPropertiesDef.h" +#include "DrMemoryStream.h" +#include "DrNodeAddress.h" +#include "DrTagsDef.h" +#include "DrPropertyDumper.h" diff --git a/DryadVertex/VertexHost/system/classlib/include/DrCriticalSection.h b/DryadVertex/VertexHost/system/classlib/include/DrCriticalSection.h new file mode 100644 index 0000000..2e219e7 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrCriticalSection.h @@ -0,0 +1,520 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +/* + * DrCriticalSection + * + * Defines a utility class for wrapping a critical section + */ +#pragma once + + +#define DEFAULT_SPIN_COUNT 4000u +#define DEFAULT_DR_LOG_HELD_TOO_LONG_TIMEOUT (500 * DrTimeInterval_Millisecond) + +class DrCriticalSectionBase : public CRITICAL_SECTION +{ +public: + void Init( + __in PCSTR name = NULL, + DWORD spinCount = DEFAULT_SPIN_COUNT, + bool logUsage = false, + DrTimeInterval logHeldTooLongTimeout = DEFAULT_DR_LOG_HELD_TOO_LONG_TIMEOUT + ); + + void Uninit() + { + DebugLogAssert( RecursionCount==0 && OwningThread == 0); + DeleteCriticalSection( this ); + } + + void Enter( PCSTR functionName = NULL, PCSTR fileName = NULL, UINT lineNumber = 0 ); + + void Leave( PCSTR functionName = NULL, PCSTR fileName = NULL, UINT lineNumber = 0 ); + + bool Acquired() const + { + return RecursionCount > 0 && (DWORD)(DWORD_PTR)OwningThread == GetCurrentThreadId(); + } + + // TODO: deprecate this, misspelled + bool Aquired() const + { + return RecursionCount > 0 && (DWORD)(DWORD_PTR)OwningThread == GetCurrentThreadId(); + } + + // Returns 0 if this thread does not own the critical section + UInt32 GetRecursionCountIfAcquired() + { + return (RecursionCount > 0 && (DWORD)(DWORD_PTR)OwningThread == GetCurrentThreadId()) ? 
(UInt32)RecursionCount : 0 ;
+    }
+
+    // also returns true if n is 0 and the thread does not own the critical section
+    bool AcquiredExactNumberOfTimes(UInt32 n)
+    {
+        return (GetRecursionCountIfAcquired() == n);
+    }
+
+    // Returns true only when the recursion count is exactly 1 and the recorded
+    // owning thread id equals the calling thread's id.
+    bool AcquiredExactlyOnce()
+    {
+        return RecursionCount == 1 && (DWORD)(DWORD_PTR)OwningThread == GetCurrentThreadId();
+    }
+
+    // Returns true when the recorded owning thread id differs from the calling thread's id.
+    bool NotAcquired() const
+    {
+        return (DWORD)(DWORD_PTR)OwningThread != GetCurrentThreadId();
+    }
+
+    // TODO: deprecate this, misspelled
+    // (Body is identical to NotAcquired().)
+    bool NotAquired() const
+    {
+        return (DWORD)(DWORD_PTR)OwningThread != GetCurrentThreadId();
+    }
+
+    // returns the old spin count
+    // (forwards to SetCriticalSectionSpinCount on this object)
+    DWORD SetSpinCount(DWORD spinCount)
+    {
+        return SetCriticalSectionSpinCount(this, spinCount);
+    }
+
+    // Could cause inconsistent settings if called concurrently. If you might be calling this
+    // method concurrently, you should claim the lock before calling. Since in most cases
+    // this is called before the lock is actually used, this precaution is
+    // not taken by default.
+    void SetCriticalSectionLoggingParameters(
+        bool logUsage = true,
+        DrTimeInterval logHeldTooLongTimeout = DEFAULT_DR_LOG_HELD_TOO_LONG_TIMEOUT
+        );
+
+    // Declared here, defined out of line (definition not visible in this part of the file).
+    void SetCriticalSectionLogging(
+        bool logUsage = true
+        );
+
+    // Declared here, defined out of line (definition not visible in this part of the file).
+    void SetCriticalSectionLogHeldTooLongTimeout(
+        DrTimeInterval logHeldTooLongTimeout = DEFAULT_DR_LOG_HELD_TOO_LONG_TIMEOUT
+        );
+
+
+    // Stores the pointer only (no copy is made), so the string must stay valid for as
+    // long as this object may print it in its log lines.
+    void SetCriticalSectionName(__in PCSTR name)
+    {
+        _name = name;
+    }
+
+    // Non-blocking acquire via TryEnterCriticalSection.
+    //   functionName / fileName / lineNumber: optional call-site information; recorded in
+    //   the _last* members when the lock is obtained and included in the log line when
+    //   fileName is not NULL.
+    // Returns true if the critical section was entered, false otherwise.
+    // When _logUsage is set, success is logged with DrLogD and failure with DrLogW.
+    bool TryEnter( PCSTR functionName = NULL, PCSTR fileName = NULL, UINT lineNumber = 0)
+    {
+        BOOL entered = TryEnterCriticalSection(this);
+
+        if ( entered )
+        {
+            // Only touch the bookkeeping fields once we actually hold the lock.
+            _lastFunctionName = functionName;
+            _lastFileName = fileName;
+            _lastLineNumber = lineNumber;
+        }
+
+        if ( _logUsage )
+        {
+            if ( entered )
+            {
+                if (fileName != NULL) {
+                    DrLogD( "CritSect TRY ENTER OK, %s at %s %s(%u), addr=%08Ix",
+                        _name, functionName, fileName, lineNumber, this );
+                } else {
+                    DrLogD( "CritSect TRY ENTER OK, %s, addr=%08Ix",
+                        _name, this );
+                }
+            }
+            else
+            {
+                if (fileName != NULL) {
+                    DrLogW( "CritSect TRY ENTER FAILED, %s at %s %s(%u), addr=%08Ix",
+                        _name, functionName, fileName, lineNumber, this );
+                } else {
+                    DrLogW( "CritSect TRY ENTER FAILED, %s, addr=%08Ix",
+                        _name, this );
+                }
+            }
+        }
+
+        // Convert the Win32 BOOL to a C++ bool.
+        return entered != FALSE;
+    }
+
+    // Logging state. _name is printed in the log lines above; _logUsage gates them;
+    // the _last* members hold the call site of the most recent successful TryEnter.
+    // NOTE(review): _logHeldTooLongTimeoutMs and _enterTimeMs are not used in the code
+    // visible here; presumably maintained by Enter/Leave -- confirm.
+    PCSTR _name;
+    bool _logUsage;
+    DWORD _logHeldTooLongTimeoutMs;
+    PCSTR _lastFunctionName;
+    PCSTR _lastFileName;
+    UINT _lastLineNumber;
+    DWORD _enterTimeMs;
+};
+
+
+// Critical section that initializes itself on construction (Init) and tears itself
+// down on destruction (Uninit). All arguments are forwarded to Init unchanged.
+class DrCriticalSection : public DrCriticalSectionBase
+{
+public:
+
+    DrCriticalSection(
+        PCSTR name = NULL,
+        DWORD spinCount = DEFAULT_SPIN_COUNT,
+        bool logUsage = false,
+        UINT logHeldTooLongTimeout = DEFAULT_DR_LOG_HELD_TOO_LONG_TIMEOUT
+        )
+    {
+        Init( name, spinCount, logUsage, logHeldTooLongTimeout );
+    }
+
+    ~DrCriticalSection()
+    {
+        Uninit();
+    }
+
+};
+
+
+// Scope guard for a DrCriticalSectionBase: the lock-taking constructors call Enter(),
+// and the destructor calls Leave() if a lock is still attached. Copying is disabled
+// (copy constructor and assignment are private and left undefined).
+class AutoLock
+{
+    friend class AutoLockLogged;
+
+public:
+
+    // Creates a guard with no lock attached; attach one later with Enter(newLock).
+    AutoLock()
+    {
+        _lock = NULL;
+    }
+
+    // Attaches to *lock and enters it. lock is dereferenced without a NULL check.
+    AutoLock( DrCriticalSectionBase* lock )
+    {
+        _lock = lock;
+        _lock->Enter();
+    }
+
+    // Attaches to lock and enters it.
+    AutoLock( DrCriticalSectionBase& lock )
+    {
+        _lock = &lock;
+        _lock->Enter();
+    }
+
+    // Same as the pointer form, but passes the call-site information on to Enter().
+    AutoLock( DrCriticalSectionBase* lock, PCSTR functionName, PCSTR fileName, UINT lineCount )
+    {
+        _lock = lock;
+        _lock->Enter( functionName, fileName, lineCount );
+    }
+
+    // Releases the attached lock, if any.
+    ~AutoLock()
+    {
+        if ( _lock != NULL )
+        {
+            _lock->Leave();
+        }
+    }
+
+    // Enters the attached lock. If newLock is given, the guard must currently have no
+    // lock attached (asserted) and newLock becomes the attached lock first.
+    void Enter( DrCriticalSectionBase* newLock = NULL )
+    {
+        if ( newLock != NULL )
+        {
+            LogAssert( _lock == NULL );
+            _lock = newLock;
+        }
+        LogAssert( _lock != NULL );
+        _lock->Enter();
+    }
+
+    // Leaves the attached lock and detaches it, so the destructor will not leave it again.
+    void Leave()
+    {
+        LogAssert( _lock != NULL );
+        _lock->Leave();
+        _lock = NULL;
+    }
+
+private:
+
+    // Not implemented: AutoLock is non-copyable.
+    AutoLock( const AutoLock& );
+    AutoLock& operator=( const AutoLock& );
+
+    DrCriticalSectionBase* _lock;
+
+};
+
+// Alternate names for AutoLock.
+typedef AutoLock DrScopedCritSec;
+typedef AutoLock DrAutoCriticalSection;
+
+// Enter/leave a lock object; the non-NOLOG forms pass the call site (function, file, line).
+#define DR_ENTER(lock) (lock).Enter( __FUNCTION__, __FILE__, __LINE__ )
+#define DR_LEAVE(lock) (lock).Leave( __FUNCTION__, __FILE__, __LINE__ )
+#define DR_ENTER_NOLOG(lock) (lock).Enter()
+#define DR_LEAVE_NOLOG(lock) (lock).Leave() + +#define EnterAndLog() Enter( __FUNCTION__, __FILE__, __LINE__ ) +#define LeaveAndLog() Leave( __FUNCTION__, __FILE__, __LINE__ ) + +#define LOCK_JOIN_NAME2(a, b, c) a##_line_##b##_counter_##c +#define LOCK_JOIN_NAME(a, b, c) LOCK_JOIN_NAME2(a, b, c) +#define LOCK(lock) AutoLock LOCK_JOIN_NAME(autoLock, __LINE__, __COUNTER__) (&lock, __FUNCTION__, __FILE__, __LINE__); +#define LOCK_NOLOG(lock) AutoLock LOCK_JOIN_NAME(autoLock, __LINE__, __COUNTER__) (&lock); + +#define DR_IN_CRITSEC(cs) LOCK(cs) + +#define DRLOCKABLENOIMPL \ + public: \ + virtual void Lock() override = 0; \ + virtual void Lock() const override = 0; \ + virtual void Unlock() override = 0; \ + virtual void Unlock() const override = 0; \ + +#define DRLOCKABLEIMPL_NO_LOCK \ + public: \ + virtual void Lock() override\ + { \ + } \ + void Lock() const override \ + { \ + } \ + virtual void Unlock() override \ + { \ + } \ + virtual void Unlock() const override \ + { \ + } + +#define DRLOCKABLEIMPL_PROTO \ + public: \ + virtual void Lock(); \ + virtual void Lock() const; \ + virtual void Unlock(); \ + virtual void Unlock() const; \ + +#define DRLOCKABLEIMPL_BASE(InterfaceClass) \ + protected: \ + mutable DrCriticalSection m_cs; \ + public: \ + virtual void Lock() override \ + { \ + m_cs.Enter(); \ + } \ + virtual void Lock() const override \ + { \ + m_cs.Enter(); \ + } \ + virtual void Unlock() override \ + { \ + m_cs.Leave(); \ + } \ + virtual void Unlock() const override \ + { \ + m_cs.Leave(); \ + } + +#define DRLOCKABLEIMPL_DELEGATE(pImpl) \ + public: \ + virtual void Lock() override \ + { \ + pImpl->Lock(); \ + } \ + virtual void Lock() const override \ + { \ + pImpl->Lock(); \ + } \ + virtual void Unlock() override \ + { \ + pImpl->Unlock(); \ + } \ + virtual void Unlock() const override \ + { \ + pImpl->Unlock(); \ + } + +#define DRLOCKABLEIMPL_DELEGATE_OUTSIDE_CLASS(classname, pImpl) \ + void classname::Lock() \ + { \ + pImpl->Lock(); \ + } \ + void 
classname::Lock() const \ + { \ + pImpl->Lock(); \ + } \ + void classname::Unlock() \ + { \ + pImpl->Unlock(); \ + } \ + void classname::Unlock() const \ + { \ + pImpl->Unlock(); \ + } + + +#define DRLOCKABLEIMPL DRLOCKABLEIMPL_BASE(IDrLockable) + +class IDrLockable +{ +public: + virtual void Lock() = 0; + virtual void Lock() const = 0; + virtual void Unlock() = 0; + virtual void Unlock() const = 0; + virtual ~IDrLockable() + { + } +}; + +class DrLockable : public IDrLockable +{ + DRLOCKABLEIMPL +}; + +template class DrScopedLock +{ +public: + DrScopedLock(T *pT = NULL) + { + m_pT = pT; + if (pT != NULL) { + pT->Lock(); + } + } + + ~DrScopedLock() + { + if (m_pT != NULL) { + m_pT->Unlock(); + } + } + +private: + T *m_pT; + +}; + +typedef DrScopedLock DrScopedLockable; + +extern DrCriticalSection *g_pDrGlobalCritSec; // A general shared critical section that is intialized on first use or at static constructor time + +extern DrCriticalSection *DrInitializeGlobalCritSec(); + +__inline DrCriticalSection *DrGetGlobalCritSec() +{ + if (g_pDrGlobalCritSec == NULL) { + DrInitializeGlobalCritSec(); + } + return g_pDrGlobalCritSec; +} + +__inline void DrEnterGlobalCritSec() +{ + if (g_pDrGlobalCritSec == NULL) { + DrInitializeGlobalCritSec(); + } +#pragma prefast(push) +#pragma prefast(disable:6011) + g_pDrGlobalCritSec->Enter(); +#pragma prefast(pop) +} + +__inline void DrLeaveGlobalCritSec() +{ + g_pDrGlobalCritSec->Leave(); +} + +// A scoped lock of the singleton global cosmos critical section +class DrScopedGlobalLock +{ +public: + DrScopedGlobalLock() + { + DrEnterGlobalCritSec(); + } + + ~DrScopedGlobalLock() + { + DrLeaveGlobalCritSec(); + } +}; + +// put this in your header file to declare a static singleton lock that is never destructed and can be +// used within static initializers +#define DECLARE_STATIC_LOCK(name) \ + class name \ + { \ + public: \ + static DrCriticalSection *InitialGetCritSec() \ + { \ + DrScopedGlobalLock glock; \ + if (s_pCritSec == NULL) { \ 
+ s_pCritSec = new DrCriticalSection(); \ + } \ + return s_pCritSec; \ + } \ + __forceinline static DrCriticalSection *GetCritSec() \ + { \ + if (s_pCritSec == NULL) { \ + InitialGetCritSec(); \ + } \ + return s_pCritSec; \ + } \ + static void Lock() \ + { \ + GetCritSec()->Enter(); \ + } \ + static void Unlock() \ + { \ + GetCritSec()->Enter(); \ + } \ + private: \ + static DrCriticalSection *s_pCritSec; \ + }; \ + +// Put this in your cpp file to provide a definition for a static singleton lock that is never destructed and can be +// used within static initializers. +#define DEFINE_STATIC_LOCK(name) \ + DrCriticalSection * name::s_pCritSec = NULL; + +// A scoped lock of a singleton static critical section that is never destructed and can be +// used within static initializers. +template class DrScopedStaticLock +{ +public: + DrScopedStaticLock() + { + LockName::Lock(); + } + + ~DrScopedStaticLock() + { + LockName::Unlock(); + } +}; + +#define DRLOCKABLEIMPL_DELEGATE_TO_STATIC(LockName) \ + public: \ + virtual void Lock() override \ + { \ + LockName::Lock(); \ + } \ + virtual void Lock() const override \ + { \ + LockName::Lock(); \ + } \ + virtual void Unlock() override \ + { \ + LockName::Unlock(); \ + } \ + virtual void Unlock() const override \ + { \ + LockName::Unlock(); \ + } + + diff --git a/DryadVertex/VertexHost/system/classlib/include/DrError.h b/DryadVertex/VertexHost/system/classlib/include/DrError.h new file mode 100644 index 0000000..a032d90 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrError.h @@ -0,0 +1,557 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +//JC Check this file for redundant information + +// DrError.h +// +// This file must contain *only* DEFINE_DR_ERROR directives! +// +// +// It is included multiple times with different macro definitions. + +#ifndef COMMON_DR_ERRORS_DEFINED + +// The operation succeeded +DEFINE_DR_ERROR (DrError_OK, S_OK, "The operation succeeded") + +// The operation failed +DEFINE_DR_ERROR (DrError_Fail, E_FAIL, "The operation failed") + +#endif + +// Out of memory +DEFINE_DR_ERROR (DrError_OutOfMemory, DR_ERROR (0x0001), "Out of memory") + +// The operation has not yet completed +DEFINE_DR_ERROR (DrError_IncompleteOperation, DR_ERROR (0x0002), "The operation has not yet completed") + +// Out of Range +DEFINE_DR_ERROR (DrError_OutOfRange, DR_ERROR (0x0003), "Out of Range") + +#ifndef COMMON_DR_ERRORS_DEFINED + +// Invalid Parameter +DEFINE_DR_ERROR (DrError_InvalidParameter, DR_ERROR (0x0004), "Invalid Parameter") + +#endif + +// Operation timed out +DEFINE_DR_ERROR (DrError_Timeout, DR_ERROR (0x0005), "Operation timed out") + +// Remote end of the connection disconnected +DEFINE_DR_ERROR (DrError_RemoteDisconnected, DR_ERROR (0x0006), "Remote end of the connection disconnected") + +// Local end of the connection disconnected +DEFINE_DR_ERROR (DrError_LocalDisconnected, DR_ERROR (0x0007), "Local end of the connection disconnected") + +// Listen operation was stopped +DEFINE_DR_ERROR (DrError_ListenStopped, DR_ERROR (0x0008), "Listen operation was stopped ") + +// Stream already exists +DEFINE_DR_ERROR 
(DrError_StreamAlreadyExists, DR_ERROR (0x0009), "Stream already exists") + +// Stream not found +DEFINE_DR_ERROR (DrError_StreamNotFound, DR_ERROR (0x000a), "Stream not found") + +#ifndef COMMON_DR_ERRORS_DEFINED + +// End of stream +DEFINE_DR_ERROR (DrError_EndOfStream, DR_ERROR (0x000b), "End of stream") + +// Invalid property +DEFINE_DR_ERROR (DrError_InvalidProperty, DR_ERROR (0x000c), "Invalid property") + +#endif + +// Extent not found +DEFINE_DR_ERROR (DrError_ExtentNotFound, DR_ERROR (0x000d), "Extent not found") + +#ifndef COMMON_DR_ERRORS_DEFINED + +// Unable to peek this far forward +DEFINE_DR_ERROR (DrError_PeekTooFar, DR_ERROR (0x000e), "Unable to peek this far forward") + +// String is too long +DEFINE_DR_ERROR (DrError_StringTooLong, DR_ERROR (0x000f), "String is too long") + +#endif + +// operation aborted +DEFINE_DR_ERROR (DrError_Aborted, DR_ERROR (0x0010), "The operation was aborted") + +// Directory does not exist +DEFINE_DR_ERROR (DrError_DirectoryDoesNotExist, DR_ERROR (0x0011), "Directory does not exist") + +// Invalid version +DEFINE_DR_ERROR (DrError_InvalidVersion, DR_ERROR (0x0012), "Invalid version") + +// Host not found or name resolution error +DEFINE_DR_ERROR (DrError_HostNotFound, DR_ERROR (0x0013), "Host not found") + +// Parameter not found +DEFINE_DR_ERROR (DrError_ParameterNotFound, DR_ERROR (0x0014), "Parameter not found") + +// Bad offset +DEFINE_DR_ERROR (DrError_BadOffset, DR_ERROR (0x0015), "Bad offset") + +// No available extent instances +DEFINE_DR_ERROR (DrError_NoExtentInstances, DR_ERROR (0x0016), "No available extent instances") + +// Invalid pathname +DEFINE_DR_ERROR (DrError_InvalidPathname, DR_ERROR (0x0017), "Invalid pathname") + +// Not implemented +DEFINE_DR_ERROR (DrError_NotImplemented, DR_ERROR (0x0018), "Not implemented") + +// Improper protocol response received +DEFINE_DR_ERROR (DrError_ImproperProtocolResponse, DR_ERROR (0x0019), "Improper Protocol Response Received") + +// The connection failed 
+DEFINE_DR_ERROR (DrError_ConnectionFailed, DR_ERROR (0x001a), "The connection failed") + +// Extent cannot be appended as it has reached its size limit +DEFINE_DR_ERROR(DrError_ExtentFull, DR_ERROR (0x001b), "The extent is full") + +// Extent has been sealed +DEFINE_DR_ERROR(DrError_ExtentSealed, DR_ERROR (0x001c), "The extent is sealed") + +// Extent failed to seal after reaching its size limit on an append +DEFINE_DR_ERROR(DrError_ExtentSealFailOnFull, DR_ERROR (0x001d), "The extent failed to seal as full") + +// Extent operation that can only be done on primary attempted on a non-primary EN +DEFINE_DR_ERROR(DrError_EnPrimaryOnly, DR_ERROR (0x001e), "The extent instance is not the primary instance") + +// Timeout waiting for a protocol response +DEFINE_DR_ERROR(DrError_ResponseTimeout, DR_ERROR (0x001f), "The Dryad server did not respond to the request") + +// Timeout waiting for a protocol send +DEFINE_DR_ERROR (DrError_SendTimeout, DR_ERROR (0x0020), "Unable to reach the Dryad node") + +// Name server unavailable +DEFINE_DR_ERROR (DrError_NameServerUnavailable, DR_ERROR (0x0021), "Dryad Name server unavailable") + +#ifndef COMMON_DR_ERRORS_DEFINED + +// Line Too Long +DEFINE_DR_ERROR (DrError_LineTooLong, DR_ERROR (0x0022), "Line is too long") + +#endif + +// Process Already Exists +DEFINE_DR_ERROR (DrError_ProcessAlreadyExists, DR_ERROR (0x0023), "Process already exists") + +// Extent operation that can only be done on a non-primary attempted on the primary +DEFINE_DR_ERROR (DrError_EnNonPrimaryOnly, DR_ERROR (0x0024), "The extent instance is the primary instance") + +// Unknown process +DEFINE_DR_ERROR (DrError_UnknownProcess, DR_ERROR (0x0025), "Unknown process") + +// Extent is incomplete (and therefore whatever operation was being attempted is invalid) +DEFINE_DR_ERROR (DrError_ExtentIncomplete, DR_ERROR (0x0026), "The extent is incomplete") + +// Length specified is invalid +DEFINE_DR_ERROR (DrError_InvalidLength, DR_ERROR (0x0028), "Invalid
length") + +// Retry limit reached (client lib append) +DEFINE_DR_ERROR (DrError_RetryLimit, DR_ERROR (0x0029), "Retry limit reached") + +// Previous operation in queue failed (client lib fixed extent read/write queues) +DEFINE_DR_ERROR (DrError_PreviousFailed, DR_ERROR (0x002a), "Previous operation in queue failed") + +// This handle is not open for read +DEFINE_DR_ERROR (DrError_HandleInvalidModeRead, DR_ERROR (0x002b), "This handle is not open for read") + +// This handle is not open for append +DEFINE_DR_ERROR (DrError_HandleInvalidModeAppend, DR_ERROR (0x002c), "This handle is not open for append") + +// client lib Initialization failed +DEFINE_DR_ERROR (DrError_InitializationFailed, DR_ERROR (0x002d), "Initialization failed") + +// invalid handle +DEFINE_DR_ERROR (DrError_InvalidHandle, DR_ERROR (0x002e), "Invalid handle value") + +// Disconnect waiting for a protocol response +DEFINE_DR_ERROR(DrError_ResponseDisconnect, DR_ERROR (0x002f), "Connection lost while waiting for a response") + +// File-based stream allows only 0 extent index. 
+DEFINE_DR_ERROR(DrError_InvalidFileExtentIndex, DR_ERROR (0x0030), "Local file-based Dryad stream allows only 0 extent index") + +// Attempt to append beyond an unsealed extent in a stream +DEFINE_DR_ERROR (DrError_OffsetIsUnsealed, DR_ERROR (0x0031), "Offset is unsealed") + +// Request/call/operation unexpected +DEFINE_DR_ERROR(DrError_Unexpected, DR_ERROR (0x0032), "Unexpected call or request") + +#ifndef COMMON_DR_ERRORS_DEFINED + +// Invalid time interval +DEFINE_DR_ERROR(DrError_InvalidTimeInterval, DR_ERROR (0x0033), "Invalid time interval") + +#endif + +// Too many redirects +DEFINE_DR_ERROR(DrError_TooManyRedirects, DR_ERROR (0x0034), "Too many redirects") + +// Incompatible rendezvous parties +DEFINE_DR_ERROR(DrError_IncompatibleRendezvousParties, DR_ERROR (0x0035), "Incompatible rendezvous parties") + +// No matching rendezvous part +DEFINE_DR_ERROR(DrError_NoMatchingRendezvousParty, DR_ERROR (0x0036), "No matching rendezvous party") + +// Failed to seal an extent (used in failure mode sealing) +DEFINE_DR_ERROR (DrError_FailedToSeal, DR_ERROR (0x0037), "Failed to seal extent") + +//Stream option id is invalid +DEFINE_DR_ERROR(DrError_InvalidOption, DR_ERROR (0x0038), "Stream option id is invalid") + +//Stream option value size is invalid +DEFINE_DR_ERROR(DrError_InvalidOptionSize, DR_ERROR (0x0039), "Stream option value size is invalid") + +//Item not found in lookup table, typically converted to specific item code +DEFINE_DR_ERROR(DrError_ItemNotFound, DR_ERROR (0x003a), "Item not found") + +// The pipe was redirected +DEFINE_DR_ERROR(DrError_InternalPipeRedirected, DR_ERROR (0x003b), "The pipe was redirected (internal)") + +// request did not succeed, please talk to the primary Server +DEFINE_DR_ERROR(DrError_TalkToPrimaryServer, DR_ERROR (0x003c), "Please talk to primary server for this request") + +// Invalid pipe URI +DEFINE_DR_ERROR(DrError_InvalidPipeUri, DR_ERROR (0x003d), "Invalid Pipe URI") + +// Artificial fault error 
+DEFINE_DR_ERROR(DrError_ArtificialError, DR_ERROR (0x003e), "Artificial test error; Please test again") + +// Both streams should be on the same volume. +DEFINE_DR_ERROR(DrError_SameVolumeRequired, DR_ERROR (0x003f), "Both streams should be on the same volume") + +// Append block cannot be larger than the extent size +DEFINE_DR_ERROR(DrError_AppendBlockTooLarge, DR_ERROR (0x0040), "Append block is larger then the local limit for append or extent size") + +// Message was inconsistent +DEFINE_DR_ERROR(DrError_InconsistentMessage, DR_ERROR (0x0041), "Inconsistent message") + +// Value was out of the given limits +DEFINE_DR_ERROR(DrError_OutOfLimits, DR_ERROR (0x0042), "The value was out of given limits") + +// Client specified a volume ID on the DRM that does not match the DRM's volume ID +DEFINE_DR_ERROR(DrError_InvalidVolumeId, DR_ERROR (0x0043), "The volume id was invalid") + +// Unsupported version +DEFINE_DR_ERROR(DrError_UnsupportedVersion, DR_ERROR (0x0044), "This version is not supported") + +// Client lib is not initialized +DEFINE_DR_ERROR(DrError_ClientNotInit, DR_ERROR (0x0045), "Dryad Client Library is not initialized") + +// Inconsistent configuration +DEFINE_DR_ERROR(DrError_InconsistentConfig, DR_ERROR (0x0046), "Inconsistent INI Configuration") + +// One of hosts in chain failed with network error +DEFINE_DR_ERROR(DrError_ChainedNetworkError, DR_ERROR (0x0047), "One of hosts in chain failed with network error") + +// Process Property Version Mismatch +DEFINE_DR_ERROR(DrError_ProcessPropertyVersionMismatch, DR_ERROR (0x0048), "Process property version mismatch") + +// Current unsealed extent read from primary failed +DEFINE_DR_ERROR(DrError_PrimaryFailed, DR_ERROR (0x0049), "Current unsealed extent read from primary failed") + +// Async operation already completed +DEFINE_DR_ERROR(DrError_AlreadyCompleted, DR_ERROR (0x004a), "Async operation already completed") + +// Response message is invalid.
+DEFINE_DR_ERROR(DrError_InvalidReply, DR_ERROR (0x004b), "Response message is invalid") + +// Internal error: required metadata is missing. +DEFINE_DR_ERROR(DrError_MissingMetadata, DR_ERROR (0x004c), "Internal error: required metadata is missing") + +// Internal error: required metadata is in inconsistent state. +DEFINE_DR_ERROR(DrError_CorruptMetadata, DR_ERROR (0x004d), "Internal error: required metadata is in inconsistent state") + +// Unable to seek over unsealed extent. +DEFINE_DR_ERROR(DrError_NoSeekOverUnsealed, DR_ERROR (0x004e), "Unable to seek over unsealed extent") + +// Too many files being retrieved by the cache manager +DEFINE_DR_ERROR(DrError_TooManyFilesBeingRetrieved, DR_ERROR (0x004f), "Too many files being retrieved by the cache manager") + +// Primary extent instance is not present in retrieved metadata +DEFINE_DR_ERROR(DrError_NoPrimary, DR_ERROR (0x0050), "Primary extent instance is not present in retrieved metadata") + +// Stream is read-only +DEFINE_DR_ERROR(DrError_StreamReadOnly, DR_ERROR (0x0051), "Stream is read-only") + +// Async operation was queued up successfully and it's now pending execution +DEFINE_DR_ERROR(DrError_Pending, DR_ERROR (0x0052), "Operation is in progress") + +// Attempted operation is invalid in the given context +DEFINE_DR_ERROR(DrError_InvalidOperation, DR_ERROR (0x0053), "Executing operation is invalid in the given context") + +// Extent is not sealed +DEFINE_DR_ERROR(DrError_ExtentNotSealed, DR_ERROR (0x0054), "Extent is not sealed") + +// Duplicated resources +DEFINE_DR_ERROR(DrError_DuplicatedResources, DR_ERROR (0x0055), "This file exists; possible duplicated resources specified") + +// Unhandled STL exception +DEFINE_DR_ERROR(DrError_InternalException, DR_ERROR (0x0056), "Unhandled internal exception") + +// Pn Queue is full and can not admit a new process +DEFINE_DR_ERROR(DrError_PnQueueFull, DR_ERROR (0x0057), "Pn Queue Full") + +// Append request not at end of data 
+DEFINE_DR_ERROR(DrError_AppendNotAtEnd, DR_ERROR (0x0058), "Append request not at end of data") + +// Supplied password for Pn accounts is not correct +DEFINE_DR_ERROR(DrError_IncorrectPnPassword, DR_ERROR (0x0059), "Incorrect PN password") + +// Extent is not in active state. It's either deleted, recycled, incomplete or corrupt. +DEFINE_DR_ERROR(DrError_ExtentNotActive, DR_ERROR (0x0060), "Extent is not in active state; it is either deleted, recycled, incomplete or corrupt") + +// Seal operation is in progress. +DEFINE_DR_ERROR(DrError_SealInProgress, DR_ERROR (0x0061), "Seal operation is in progress") + +// Extent needs failure mode seal. +DEFINE_DR_ERROR(DrError_ExtentNeedsSeal, DR_ERROR (0x0062), "Extent needs failure mode seal") + +// Unknown request or method. +DEFINE_DR_ERROR(DrError_UnknownMethod, DR_ERROR (0x0063), "Unknown request or method") + +// not authorized +DEFINE_DR_ERROR(DrError_NotAuthorized, DR_ERROR (0x0064), "Client request cannot be authorized") + +// Syntax error +DEFINE_DR_ERROR(DrError_SyntaxError, DR_ERROR (0x0065), "Syntax error") + +// Attempt to suspend a task that was already suspended +DEFINE_DR_ERROR(DrError_TaskAlreadySuspended, DR_ERROR (0x0066), "Task already suspended") + +// Attempt to unsuspend a task that was already unsuspended +DEFINE_DR_ERROR(DrError_TaskAlreadyUnsuspended, DR_ERROR (0x0067), "Task already unsuspended") + +// Task does not exist +DEFINE_DR_ERROR(DrError_TaskNotFound, DR_ERROR (0x0068), "Task does not exist") + +// Task exists but it is in the deleted state +DEFINE_DR_ERROR(DrError_TaskDeleted, DR_ERROR (0x0069), "Task is in the deleted state") + +// Task already completed (e.g.
you tried to suspend it when it was already completed) +DEFINE_DR_ERROR(DrError_TaskAlreadyCompleted, DR_ERROR (0x006A), "Task already completed") + +// Scheduler already paused +DEFINE_DR_ERROR(DrError_SchedulerAlreadyPaused, DR_ERROR (0x006B), "Scheduler already paused") + +// Scheduler already unpaused +DEFINE_DR_ERROR(DrError_SchedulerAlreadyUnpaused, DR_ERROR (0x006C), "Scheduler already unpaused") + +// DrServiceDescriptor name resolving errors +DEFINE_DR_ERROR(DrError_ClusterNotFound, DR_ERROR (0x006D), "Cluster not found") +DEFINE_DR_ERROR(DrError_ServiceTypeNotFound, DR_ERROR (0x006E), "Service type not found") +DEFINE_DR_ERROR(DrError_NamespaceNotFound, DR_ERROR (0x006F), "Namespace/volume not found") +DEFINE_DR_ERROR(DrError_ServiceInstanceNotFound, DR_ERROR (0x0070), "Service instance not found") + +// Job ticket is invalid +DEFINE_DR_ERROR(DrError_JobTicketInvalid, DR_ERROR (0x0071), "Job ticket is invalid") + +// Job ticket has expired +DEFINE_DR_ERROR(DrError_JobTicketExpired, DR_ERROR (0x0072), "Job ticket has expired") + +// Another authentication is in process +DEFINE_DR_ERROR(DrError_AuthenticationInProcess, DR_ERROR (0x0073), "Another authentication is in process") + + +// i-th readahead request was satisfied with (i-1)-th +DEFINE_DR_ERROR(DrError_ReadaheadAleadySatisfied, DR_ERROR (0x0074), "Readahead satisfied by previous request") + +// +// Cache service errors +// +DEFINE_DR_ERROR(DrError_CacheItemFetchIncomplete, DR_ERROR (0x0075), "Cache Item is still being fetched from source") +DEFINE_DR_ERROR(DrError_CacheItemFetchFailed, DR_ERROR (0x0076), "Failed to fetch the cache item from source") +DEFINE_DR_ERROR(DrError_CacheItemSourceInvalid, DR_ERROR (0x0077), "The source of cache item is invalid") + +// PN process states +DEFINE_DR_ERROR(DrError_PnProcessInitializing, DR_ERROR (0x0078), "The process is being initialized") +DEFINE_DR_ERROR(DrError_PnProcessCreated, DR_ERROR (0x0079), "The process has been created") 
+DEFINE_DR_ERROR(DrError_PnProcessRunning, DR_ERROR (0x007A), "The process is running") +DEFINE_DR_ERROR(DrError_PnProcessCreateFailed, DR_ERROR (0x007B), "Process create failed") + +// VCWS errors +DEFINE_DR_ERROR(DrError_VcwsVirtualClusterNotFound, DR_ERROR (0x007C), "Virtual cluster not found") +DEFINE_DR_ERROR(DrError_VcwsMountPointNotFound, DR_ERROR (0x007D), "Mount point not found") +DEFINE_DR_ERROR(DrError_VcwsPermissionDenied, DR_ERROR (0x007E), "Permission denied") +DEFINE_DR_ERROR(DrError_VcwsInvalidMountPoint, DR_ERROR (0x007F), "Invalid mount point") +DEFINE_DR_ERROR(DrError_VcwsInvalidVirtualCluster, DR_ERROR (0x0080), "Invalid virtual cluster") +DEFINE_DR_ERROR(DrError_VcwsIncompleteMountPoint, DR_ERROR (0x0081), "Incomplete mount point") +DEFINE_DR_ERROR(DrError_DmRequestFailed, DR_ERROR (0x0082), "Failed to send command to DM") +DEFINE_DR_ERROR(DrError_CjsNoMaster, DR_ERROR (0x0083), "No master of CJS machines from DM") + +//PN process states continued +DEFINE_DR_ERROR(DrError_PnProcessSystemProcessRunning, DR_ERROR (0x0084), "Could not create process due to System priority process.") + +DEFINE_DR_ERROR(DrError_XmlAttributeNotFound, DR_ERROR (0x0085), "Attribute cannot be found in XML") +DEFINE_DR_ERROR(DrError_VcwsShareNameAlreadyExists, DR_ERROR (0x0086), "Share name has already exists in the virtual cluster") +DEFINE_DR_ERROR(DrError_VcwsShareNotFound, DR_ERROR (0x0087), "Share not found") +DEFINE_DR_ERROR(DrError_VcwsMultipleShareNames, DR_ERROR (0x0088), "This directory has already been shared with another share name") +DEFINE_DR_ERROR(DrError_VcwsJobResourceTooLarge, DR_ERROR (0x0089), "The size of some job resource is too large") +DEFINE_DR_ERROR(DrError_VcwsTooManyLogAppendRequestsInProgress, DR_ERROR (0x008a), "There are too many log append requests in progress and please re-try later.") +DEFINE_DR_ERROR(DrError_VcwsInvalidVcPath, DR_ERROR (0x008b), "The VC path format is invalid.") + +//PN process states continued 
+DEFINE_DR_ERROR(DrError_PnDeploymentInProgress, DR_ERROR (0x008C), "Could not create process due to deployment of cosmos code.") + +// Mirroring +DEFINE_DR_ERROR (DrError_MirrorPolicyNotFound, DR_ERROR (0x0090), "MirrorPolicyNotFound") + +// Returned when you try to perform some operation on a mirrored object that is not allowed; e.g. deleting a mirrored +// stream or directory, or setting up mirroring for a directory that is already mirrored. +DEFINE_DR_ERROR (DrError_ObjectIsMirrored, DR_ERROR (0x0091), "ObjectIsMirrored") + +DEFINE_DR_ERROR (DrError_DirectoryAlreadyMirrored, DR_ERROR (0x0092), "DirectoryAlreadyMirrored") +DEFINE_DR_ERROR (DrError_MirrorInstanceNotFound, DR_ERROR (0x0093), "MirrorInstanceNotFound") + +// Returned when you try to perform an illegal operation on a mirror instance that is the master; e.g. trying to delete it +DEFINE_DR_ERROR (DrError_MirrorInstanceIsMaster, DR_ERROR (0x0094), "MirrorInstanceIsMaster") + +DEFINE_DR_ERROR (DrError_MirrorInstanceAlreadyExists, DR_ERROR (0x0095), "MirrorInstanceAlreadyExists") +DEFINE_DR_ERROR (DrError_DirectoryNotMirrored, DR_ERROR (0x0096), "DirectoryNotMirrored") + +// Returned when mirror policy name does not start with the "cluster.volume." prefix of the master +DEFINE_DR_ERROR (DrError_InvalidMirrorPolicyName, DR_ERROR (0x0097), "InvalidMirrorPolicyName") + +DEFINE_DR_ERROR (DrError_MirrorPolicyAlreadyExists, DR_ERROR (0x0098), "MirrorPolicyAlreadyExists") + +DEFINE_DR_ERROR (DrError_UnalbeToAllocateExtentInstances, DR_ERROR (0x0099), "Unable to allocate extent instances") + +// VCWS errors continued +DEFINE_DR_ERROR(DrError_VcwsDrJobTimeout, DR_ERROR (0x009a), "Your stream request could not be completed in time because of unusually high latency to the volume of your requested resource. Please try again later.") +DEFINE_DR_ERROR(DrError_VcwsDrJobReject, DR_ERROR (0x009b), "Your stream request is being rejected because too many timeouts to the volume of your requested resource have been observed. 
Please try again later.") + +// network send abort: map the same error in DrNetlib +DEFINE_DR_ERROR (DrError_SendAborted, DR_ERROR (0x009c), "Send to cosmos node aborted") + +// PN memory, execution time and max execution time exceeded, speculative process killed +DEFINE_DR_ERROR (DrError_ExecutionQuotaExceeded, DR_ERROR (0x009d), "Process exceeded execution time lease") +DEFINE_DR_ERROR (DrError_MaxExecutionQuotaExceeded, DR_ERROR (0x009e), "Process execution time exceeded upper threshold") +DEFINE_DR_ERROR (DrError_MemoryQuotaExceeded, DR_ERROR (0x009f), "Process exceeded memory quota") +DEFINE_DR_ERROR (DrError_SpeculativeProcessAborted, DR_ERROR (0x0110), "Speculative process aborted") +DEFINE_DR_ERROR (DrError_ProcessTerminated, DR_ERROR (0x0111), "Process terminated by client") + +// service is too busy processing requests on the stream +DEFINE_DR_ERROR (DrError_StreamTooBusy, DR_ERROR (0x0112), "Service is too busy processing requests of this stream") + +// VCWS errors continued +DEFINE_DR_ERROR(DrError_VcwsInvalidShareName, DR_ERROR (0x0120), "There are invalid characters in share name and the list of all invalid characters is \"ASCII 1-31, \", *, :, <, >, ?, \\, , %%, [, ], (, ), &, ;, /, {, }, #\"") +DEFINE_DR_ERROR(DrError_VcwsPermissionDeniedByMountPointSecurityGroups, DR_ERROR (0x0121), "You are not in the security groups associated with this mount point. 
Please ask your VC admin for specific information.") + +// StreamSet errors +DEFINE_DR_ERROR(DrError_StreamSetTemplateEmpty, DR_ERROR (0x0130), "[StreamSet:Template] in StreamSet.ini is of zero length.") +DEFINE_DR_ERROR(DrError_InvalidStreamSetHotSpot, DR_ERROR (0x0131), "[StreamSet:HotSpot] in StreamSet.ini is invalid.") +DEFINE_DR_ERROR(DrError_InvalidStreamSetExpireAfter, DR_ERROR (0x0132), "[StreamSet:ExpireAfter] in StreamSet.ini is invalid.") +DEFINE_DR_ERROR(DrError_MultipleStreamsForAppend, DR_ERROR (0x0133), "A streamset being mapped to multiple streams is specified for appending.") + +// OpenFromBinary error (clientlib) +DEFINE_DR_ERROR(DrError_CannotCreateFromBinary, DR_ERROR (0x0134), "A stream cannot be created from binary stream info.") + +// APX Client errors +DEFINE_DR_ERROR(DrError_ApxAddressNotFound, DR_ERROR (0x0140), "No available Autopilot authentication proxy IP address could be found.") +DEFINE_DR_ERROR(DrError_ApxUnexpectedHttpStatusCode, DR_ERROR (0x0141), "An unexpected HTTP status code received from Autopilot authentication proxy.") + +// Dryad Partition Manager +DEFINE_DR_ERROR(DrError_InvalidPartitionEntry, DR_ERROR (0x0150), "Invalid Dryad partition entry.") +DEFINE_DR_ERROR(DrError_InvalidPartitionTable, DR_ERROR (0x0151), "Invalid Dryad partition table.") +DEFINE_DR_ERROR(DrError_PartitionNotFound, DR_ERROR (0x0152), "Dryad partition not found.") +DEFINE_DR_ERROR(DrError_PartitionKeyNotFound, DR_ERROR (0x0153), "Dryad partition key not found.") +DEFINE_DR_ERROR(DrError_PartitionServerUnavailable, DR_ERROR (0x0154), "Dryad Partition Manager unavailable.") + +// +// Cache service errors +// +DEFINE_DR_ERROR(DrError_CacheItemSizeExceedCacheSize, DR_ERROR (0x0160), "The size of the cache item requested is greater than the stable size of the cache") + +// +// HPC Errors +// +DEFINE_DR_ERROR(DrError_InvalidJob, DR_ERROR (0x0170), "Invalid HPC Job ID.") + + +// The error that should never be returned! 
+// This is useful for initializing error variables so we can test if they +// got set correctly at some point. +// ! DONT RETURN THIS AS AN ERROR EVER ! + +// An error code was not initialized +DEFINE_DR_ERROR(DrError_Impossible, DR_ERROR (0xFFFF), "An error code was not initialized") + + + + + +/* ------------------------------------ DrIo errors (should eventually go away) --------------------------------- */ + + +// Internal Error +DEFINE_DR_ERROR (DrError_IoInternalError, DR_ERROR (0x0100), "Internal Error") + +// Extent already Exists +DEFINE_DR_ERROR (DrError_IoExtentAlreadyExists, DR_ERROR (0x0101), "Extent already Exists") + +// Extent instance already deleted +DEFINE_DR_ERROR (DrError_IoExtentAlreadyDeleted, DR_ERROR (0x0102), "Extent instance already deleted") + +// Extent instance is corrupted +DEFINE_DR_ERROR (DrError_IoExtentCorrupted, DR_ERROR (0x0103), "Extent instance is corrupted") + +// Cannot open extent instance +DEFINE_DR_ERROR (DrError_IoExtentCannotOpen, DR_ERROR (0x0104), "Cannot open extent instance") + +// The volume is full +DEFINE_DR_ERROR (DrError_IoVolumeIsFull, DR_ERROR (0x0105), "The volume is full") + +// Dryad volume is corrupted +DEFINE_DR_ERROR (DrError_IoVolumeCorrupted, DR_ERROR (0x0106), "Dryad volume is corrupted") + +// Unknown Dryad volume +DEFINE_DR_ERROR (DrError_IoVolumeUnknown, DR_ERROR (0x0107), "Unknown Dryad volume") + +// An I/O operation is pending +DEFINE_DR_ERROR (DrError_IoPending, DR_ERROR (0x0108), "An I/O operation is pending") + +// An I/O error occurred +DEFINE_DR_ERROR (DrError_IoReadWriteError, DR_ERROR (0x0109), "An I/O error occurred") + +// CRC mismatch +DEFINE_DR_ERROR (DrError_IoCorruptedData, DR_ERROR (0x010a), "CRC mismatch") + +#ifndef COMMON_DR_ERRORS_DEFINED +// Extent instance not found +DEFINE_DR_ERROR (DrError_IoExtentNotFound, DR_ERROR (0x010b), "Extent instance not found") +#endif + +// Stream is sealed +DEFINE_DR_ERROR(DrError_StreamSealed, DR_ERROR (0x0200), "Stream is sealed") + +// 
#pragma once

// NOTE(review): the linkage block below is guarded for C++ only, but the
// declarations inside it (class DrStr, DrStr& parameters) are C++-only
// regardless, so this header presumably cannot be consumed from C - confirm.
#if defined(__cplusplus)
extern "C" {
#endif

// Dryad error codes are plain HRESULT values.
typedef HRESULT DrError;

class DrStr;

// Appends a description of 'err' to 'strOut'; the return type is DrStr&.
extern DrStr& DrAppendErrorDescription(DrStr& strOut, DrError err);
//JC extern DrWStr& DrAppendErrorDescription(DrWStr& strOut, DrError err);

// Error-description table management. Only the signatures are visible here;
// the implementation lives elsewhere.
extern void DrInitErrorTable(void);
extern DrError DrAddErrorDescription(DrError code, const char *pszDescription);

// The returned error string should be freed with free();
//
// If the error code is unknown, a generic error description is returned.
extern char *DrGetErrorText(DrError err);

// The returned error string should be freed with free();
//
// If the error code is unknown, a generic error description is returned.
//JC extern WCHAR *DrGetErrorTextW(DrError err);

// The buffer must be at least 64 bytes long to guarantee a result. If the result won't fit in the buffer, a generic
// error message is generated.
extern const char *DrGetErrorDescription(DrError err, char *pBuffer, int buffLen);

// The buffer must be at least 64 chars long to guarantee a result. If the result won't fit in the buffer, a generic
// error message is generated.
//JC extern const WCHAR *DrGetErrorDescription(DrError err, WCHAR *pBuffer, int buffLen);

// NOTE(review): presumably an upper bound on error description length; its
// use is not visible in this header - confirm against the callers.
static const int k_DrMaxErrorLength = 256;

// HRESULT facility code used for all Dryad errors (overridable by defining it earlier).
#ifndef FACILITY_DR
#define FACILITY_DR 0x309
#endif

// Builds a Dryad HRESULT: severity "error", facility FACILITY_DR, code n.
#ifndef DR_ERROR
#define DR_ERROR(n) MAKE_HRESULT(SEVERITY_ERROR, FACILITY_DR, n)
#endif

// Drop any DEFINE_DR_ERROR left over from a previous inclusion of DrError.h
// so the definition below is the one in effect.
#ifdef DEFINE_DR_ERROR
#undef DEFINE_DR_ERROR
#endif

// With this definition each entry of DrError.h becomes a constant:
//     static const DrError <name> = <number>;
// The description argument is ignored here.
#define DEFINE_DR_ERROR(name, number, description) static const DrError name = number;

#include "DrError.h"

// Leave the macro undefined so DrError.h can be re-included with another expansion.
#undef DEFINE_DR_ERROR

#if defined(__cplusplus)
}
#endif
+ +*/ + +#ifndef _DrExecution_h_ +#define _DrExecution_h_ + +#pragma once + + +/* + DrExecution.h - main header file for Dryad execution +*/ + +#include "DrCommon.h" + +//JC#include "DrFileCache.h" +//JC#include "DrPipe.h" +//JC#include "DrPipeLoopback.h" +//JC#include "DrNativeHandlePipe.h" +//JC#include "DrTcpPipe.h" +//JC#include "DrPipeRendezvous.h" +//JC#include "DrJobStatistics.h" + + +//JC#define DR_ENVSTR_PN_QUOTA "PN_QUOTA" + +DrError DrInitExecution(); + +class DrExecutionEnvironment +{ + friend DrError DrInitExecution(); + +public: + + void Lock() + { + m_cs.Enter(); + } + + void Unlock() + { + m_cs.Leave(); + } + + /* JC + void GetCurrentJobDescriptor(DrJobDescriptorEx& jdOut) + { + Lock(); + jdOut = g_pDryadConfig->CurrentJobDescriptor(); + Unlock(); + } + + const DrProcessDescriptor& GetCurrentProcessDescriptor() + { + return g_pDryadConfig->GetCurrentProcessDescriptor(); + } + + const DrProcessDescriptor& GetRootProcessDescriptor() + { + return g_pDryadConfig->GetRootProcessDescriptor(); + } + + bool IsUnderProcessNode() + { + return g_pDryadConfig->IsUnderProcessNode(); + } + + bool JobWasInherited() + { + return g_pDryadConfig->JobWasInherited(); + } +*/ +private: + DrExecutionEnvironment() + { + } + + DrError Initialize(); + +private: + DrCriticalSection m_cs; +}; + +extern DrExecutionEnvironment *g_pDryadExecution; + +#endif // end if not defined _DrExecution_h_ diff --git a/DryadVertex/VertexHost/system/classlib/include/DrExitCodes.h b/DryadVertex/VertexHost/system/classlib/include/DrExitCodes.h new file mode 100644 index 0000000..565352f --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrExitCodes.h @@ -0,0 +1,72 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
// DrExitCodes.h
//
// This file must contain *only* declarations of the form:
//
// DEFINE_DREXITCODE(name, value, description)
// DEFINE_DREXITCODE_NO_DESC(name, value)
// DECLARE_DREXITCODE(valname, description)
// DECLARE_DREXITCODE_NO_DESC(valname)
//
// It is included multiple times with different macro definitions.
//
// DEFINE_* entries introduce a new DrExitCode_* name with the given value;
// DECLARE_* entries refer to a STATUS_* macro that already exists.

// Values from ntstatus.h:
// NOTE(review): these two #defines are not wrapped in #ifndef. Re-including
// this file is harmless (identical redefinition), but a translation unit that
// also gets these names from ntstatus.h would presumably see a conflicting
// redefinition - confirm whether that combination occurs.
#define STATUS_NOT_SUPPORTED ((DWORD)0xC00000BBL)
#define STATUS_UNHANDLED_EXCEPTION ((DWORD)0xC0000144L)
//#define STATUS_ASSERTION_FAILURE ((NTSTATUS)0xC0000420L)


// Exit codes defined by Dryad itself.
DEFINE_DREXITCODE_NO_DESC(DrExitCode_OK, 0)
DEFINE_DREXITCODE_NO_DESC(DrExitCode_Fail, 1)
DEFINE_DREXITCODE_NO_DESC(DrExitCode_StillActive, STILL_ACTIVE)
DEFINE_DREXITCODE_NO_DESC(DrExitCode_Killed, (DrExitCode) STATUS_CONTROL_C_EXIT)
DEFINE_DREXITCODE_NO_DESC(DrExitCode_JobQuotaExceeded, ERROR_NOT_ENOUGH_QUOTA)

// Existing STATUS_* values that are listed as exit codes.
// STATUS_ASSERTION_FAILURE is listed even though its #define above is
// commented out, so it must be supplied by another header.
DECLARE_DREXITCODE_NO_DESC(STATUS_INVALID_HANDLE)
DECLARE_DREXITCODE_NO_DESC(STATUS_ACCESS_VIOLATION)
DECLARE_DREXITCODE_NO_DESC(STATUS_ARRAY_BOUNDS_EXCEEDED)
DECLARE_DREXITCODE_NO_DESC(STATUS_BREAKPOINT)
DECLARE_DREXITCODE_NO_DESC(STATUS_DATATYPE_MISALIGNMENT)
DECLARE_DREXITCODE_NO_DESC(STATUS_FLOAT_DENORMAL_OPERAND)
DECLARE_DREXITCODE_NO_DESC(STATUS_FLOAT_DIVIDE_BY_ZERO)
DECLARE_DREXITCODE_NO_DESC(STATUS_FLOAT_INEXACT_RESULT)
DECLARE_DREXITCODE_NO_DESC(STATUS_FLOAT_INVALID_OPERATION)
DECLARE_DREXITCODE_NO_DESC(STATUS_FLOAT_MULTIPLE_FAULTS)
DECLARE_DREXITCODE_NO_DESC(STATUS_FLOAT_MULTIPLE_TRAPS)
DECLARE_DREXITCODE_NO_DESC(STATUS_FLOAT_OVERFLOW)
DECLARE_DREXITCODE_NO_DESC(STATUS_FLOAT_STACK_CHECK)
DECLARE_DREXITCODE_NO_DESC(STATUS_FLOAT_UNDERFLOW)
DECLARE_DREXITCODE_NO_DESC(STATUS_GUARD_PAGE_VIOLATION)
DECLARE_DREXITCODE_NO_DESC(STATUS_ILLEGAL_INSTRUCTION)
DECLARE_DREXITCODE_NO_DESC(STATUS_IN_PAGE_ERROR)
DECLARE_DREXITCODE_NO_DESC(STATUS_INVALID_DISPOSITION)
DECLARE_DREXITCODE_NO_DESC(STATUS_INTEGER_DIVIDE_BY_ZERO)
DECLARE_DREXITCODE_NO_DESC(STATUS_INTEGER_OVERFLOW)
DECLARE_DREXITCODE_NO_DESC(STATUS_NONCONTINUABLE_EXCEPTION)
DECLARE_DREXITCODE_NO_DESC(STATUS_NOT_SUPPORTED)
DECLARE_DREXITCODE_NO_DESC(STATUS_PRIVILEGED_INSTRUCTION)
DECLARE_DREXITCODE_NO_DESC(STATUS_REG_NAT_CONSUMPTION)
DECLARE_DREXITCODE_NO_DESC(STATUS_SINGLE_STEP)
DECLARE_DREXITCODE_NO_DESC(STATUS_STACK_OVERFLOW)
DECLARE_DREXITCODE_NO_DESC(STATUS_ASSERTION_FAILURE)
DECLARE_DREXITCODE_NO_DESC(STATUS_UNHANDLED_EXCEPTION)
#pragma once

// Process exit codes are 32-bit unsigned values.
typedef UInt32 DrExitCode;

// Exit-code description table API. Only the signatures are visible here.
void DrInitExitCodeTable();
DrError DrAddExitCodeDescription(DrExitCode code, const char *pszDescription);
DrStr& DrAppendExitCodeDescription(DrStr& strOut, DrExitCode code);
DrStr& DrGetExitCodeDescription(DrStr& strOut, DrExitCode code);

// A DrStr128 that fills itself with the description of an exit code by
// calling DrGetExitCodeDescription(*this, code) in its constructor.
class DrExitCodeString : public DrStr128
{
public:
    DrExitCodeString(DrExitCode code)
    {
        DrGetExitCodeDescription(*this, code);
    }
};

// This macro can be used to obtain a temporary "const char *" error description for an exit code. It can be used
// as the parameter to a method call; the pointer will become invalid after the function returns
// (it points into a temporary DrExitCodeString, so never store it).
#define DREXITCODESTRING(code) (DrExitCodeString(code).GetString())
// Same as above, passed through DRUTF8TOWSTRING (defined elsewhere).
#define DREXITCODEWSTRING(code) (DRUTF8TOWSTRING(DREXITCODESTRING(code)))



// The *_NO_DESC forms forward to the full forms, using the stringized
// name as the description.
#define DEFINE_DREXITCODE_NO_DESC(name, value) DEFINE_DREXITCODE(name, (DrExitCode)(value), #name)
#define DECLARE_DREXITCODE_NO_DESC(valname) DECLARE_DREXITCODE(valname, #valname)

// Drop any expansion left over from another inclusion of DrExitCodes.h.
#ifdef DEFINE_DREXITCODE
#undef DEFINE_DREXITCODE
#endif
#ifdef DECLARE_DREXITCODE
#undef DECLARE_DREXITCODE
#endif
// For this inclusion: DEFINE_* entries become DrExitCode constants and
// DECLARE_* entries expand to nothing.
#define DEFINE_DREXITCODE(name, value, description) const DrExitCode name = (DrExitCode)(value);
#define DECLARE_DREXITCODE(valname, description)

#include "DrExitCodes.h"

// Leave both macros undefined so DrExitCodes.h can be included again with
// different expansions. The *_NO_DESC wrappers stay defined.
#undef DEFINE_DREXITCODE
#undef DECLARE_DREXITCODE
#pragma once
// NOTE(review): the operand of the next #include is missing in this copy of
// the file (possibly an angle-bracket header name that was lost) - restore it
// from the original source.
#include
#include "basic_types.h"

// Route IndexAssert to LogAssert, replacing any earlier definition.
#undef IndexAssert
#define IndexAssert(expr) LogAssert(expr)
typedef __int64 Dryad_dupelim_fprint_int64_t;
typedef UInt64 Dryad_dupelim_fprint_uint64_t;

/* the type of a 64-bit fingerprint */
typedef Dryad_dupelim_fprint_uint64_t Dryad_dupelim_fprint_t;

/* Handle types for the data structures needed to compute fingerprints.
   (The structure itself is defined at the end of this header, so it is not
   actually opaque to includers.) */

typedef struct Dryad_dupelim_fprint_data_s *Dryad_dupelim_fprint_data_t;
typedef const struct Dryad_dupelim_fprint_data_s *Dryad_dupelim_fprint_data_tc;

/* hash lengths */
enum HashPolyLength{
    Poly8bit,
    Poly16bit,
    Poly32bit,
    Poly64bit} ;

/* Initialise a Rabin fingerprint (hash) function.

   Rabin fingerprints belong to the family of CRC hashes. Their collision
   probability is bounded by a very small number, and because they use
   polynomials over a Galois field they are very efficient for computing
   recursive hashes.

   For straightforward applications use Dryad_dupelim_rabinhash_init() and
   Dryad_dupelim_rabinhash_process(). Unless you understand exactly what the
   other functions do, refrain from using them.

   Dryad_dupelim_rabinhash_init()----------------------
     pHashData  fingerprint data structure that receives the hash function;
                it must already have been allocated by the caller
     hashLen    the order of the polynomials to be used for the hash function
     seed       the index of the polynomial to be used in the hash function.
                The original note says seed has to be "less than or equal to"
                cbPolysN {N = 8, 16, 32, 64}.
                NOTE(review): as an array index it presumably has to be
                strictly less than cbPolysN - confirm against the implementation.
     returns    true if a hash function was created, false otherwise
*/
bool Dryad_dupelim_rabinhash_init (Dryad_dupelim_fprint_data_s* pHashData,
                                   HashPolyLength hashLen,
                                   UInt32 seed);
/* if pHashData was generated with polynomial P and bytes
   "data[0, ..., len-1]" contain string A, return the fingerprint under P of A.
   Strings are treated as polynomials. The low-order bit in the first
   byte is the highest degree coefficient in the polynomial.
*/
Dryad_dupelim_fprint_t Dryad_dupelim_rabinhash_process(Dryad_dupelim_fprint_data_s* pHashData,
                                                       const unsigned char *data, unsigned len);

/* if pHashFunction was generated with polynomial P, bytes
   "data[0, ..., len-1]" contain string B, and initialHash contains the hash
   value for string A, return initialHash extended by B: the output value is
   the hash of string A concatenated with string B.
   Strings are treated as polynomials. The low-order bit in the first
   byte is the highest degree coefficient in the polynomial.
*/

Dryad_dupelim_fprint_t Dryad_dupelim_rabinhash_add(Dryad_dupelim_fprint_data_s* pHashFunction, Dryad_dupelim_fprint_t initialHash,
                                                   const unsigned char *data, unsigned len);


/* Allocate and return a new fingerprint function.

   Computes the tables needed for fingerprint manipulations.
   Requires that "poly" be the binary representation
   of an irreducible polynomial in GF(2) of degree 64. The X^64 term
   is not represented. The X^0 term is the high order bit, and the
   X^63 term is the low-order bit.

   span is used in later calls to Dryad_dupelim_fprint_slideword().
   If Dryad_dupelim_fprint_slideword() is not to be called, span
   should be set to zero. */
Dryad_dupelim_fprint_data_t Dryad_dupelim_fprint_new (Dryad_dupelim_fprint_t poly,
                                                      unsigned span);

/* Like "new" above, except that the degree can be any value between 1
   and 64. Return 0 if that's not true.

   The X^(degree-1) term is in the low-order bit of poly.
*/
Dryad_dupelim_fprint_data_t Dryad_dupelim_fprint_new2 (Dryad_dupelim_fprint_t poly,
                                                       unsigned span, int degree);

/* returns the seeded polynomial, i.e. the fingerprint of an empty element under this fp */
Dryad_dupelim_fprint_t Dryad_dupelim_fprint_empty (Dryad_dupelim_fprint_data_tc fp);

/* if fp was generated with polynomial P, "a" is the fingerprint under
   P of string A, and bytes "data[0, ..., len-1]" contain string B,
   return the fingerprint under P of the concatenation of A and B.
   Strings are treated as polynomials. The low-order bit in the first
   byte is the highest degree coefficient in the polynomial. This
   routine differs from Dryad_dupelim_fprint_extend_word() in that it
   will read bytes in increasing address order, regardless of the
   endianness of the machine.
   len is the number of unsigned chars in data.
*/
Dryad_dupelim_fprint_t Dryad_dupelim_fprint_extend (Dryad_dupelim_fprint_data_tc fp,
                                                    Dryad_dupelim_fprint_t a,
                                                    const unsigned char *data, unsigned len);

/* If fp was generated with polynomial P, "a" is the fingerprint under
   P of string A, and 64-bit words "data[0, ..., len-1]" contain
   string B, return the fingerprint under P of the concatenation of A
   and B. Arrays of words are treated as polynomials. The low-order
   bit in the first word is the highest degree coefficient in the
   polynomial. This routine differs from Dryad_dupelim_fprint_extend()
   on bigendian machines, where the byte order within each word is
   backwards. */
Dryad_dupelim_fprint_t
Dryad_dupelim_fprint_extend_word (Dryad_dupelim_fprint_data_tc fp,
                                  Dryad_dupelim_fprint_t a,
                                  const Dryad_dupelim_fprint_uint64_t *data,
                                  unsigned len);

/* if fp was generated with polynomial P, "a" is the fingerprint under
   P of string A, and "b" is the fingerprint under P of string B,
   which has length "blen" bytes, return the fingerprint under P of
   the concatenation of A and B */
Dryad_dupelim_fprint_t
Dryad_dupelim_fprint_concat(Dryad_dupelim_fprint_data_tc fp,
                            Dryad_dupelim_fprint_t a, Dryad_dupelim_fprint_t
                            b, Dryad_dupelim_fprint_t blen);


/* Turn fingerprint "f" into a hexadecimal, ascii-zero-filled
   printable string S of length 16, and place the characters in
   buf[0,...,15]. No null terminator is written by the routine. */
void Dryad_dupelim_fprint_toascii (Dryad_dupelim_fprint_t f, char *buf);

/* if "fp" was generated with polynomial P, X is some string of length
   "(span-1)*sizeof (Dryad_dupelim_fprint_uint64_t)" bytes (see
   Dryad_dupelim_fprint_new()), "f" is the fingerprint under P of word
   "a" concatenated with X, return the fingerprint under P of X
   concatenated with word "b". The words "a" and "b" represent
   polynomials whose X^0 term is in the high-order bit, and whose X^63
   term is in the low order bit. */
Dryad_dupelim_fprint_t
Dryad_dupelim_fprint_slideword (Dryad_dupelim_fprint_data_tc fp,
                                Dryad_dupelim_fprint_t f,
                                Dryad_dupelim_fprint_uint64_t a,
                                Dryad_dupelim_fprint_uint64_t b);

/* discard the data associated with "fp" */
void Dryad_dupelim_fprint_close (Dryad_dupelim_fprint_data_t fp);

/* fprint struct: the tables behind one fingerprint function */
struct Dryad_dupelim_fprint_data_s {
    Dryad_dupelim_fprint_t poly[2];
    /* poly[0] = 0; poly[1] = polynomial */
    Dryad_dupelim_fprint_t empty;
    /* fingerprint of the empty string */
    Dryad_dupelim_fprint_t bybyte[8][256];
    /* bybyte[b][i] is i*X^(degree+8*b) mod poly[1] */
    Dryad_dupelim_fprint_t powers[64];
    /* powers[i] is X^(8*2^i) mod poly[1] */
    static const UInt32 LOGZEROBLOCK = 8;
    static const UInt32 ZEROBLOCK = (1 << LOGZEROBLOCK);
    /* ZEROBLOCK-byte buffer; the double member forces its alignment */
    union {
        double align;
        unsigned char zeroes[ZEROBLOCK];
    } zeroes;
    Dryad_dupelim_fprint_t bybyte_out[8][256];
    /* bybyte_out[b][i] is i*X^(degree+8*(b+span)) mod poly[1] */
    unsigned span;
};
+ +*/ + +// This file defines 4 sets of GF[2] irreducible polynomials +// These polynomials are of the order 8, 16, 32 and 64 +// You can use these to initialize fprint_data + +static const Dryad_dupelim_fprint_t polys8[] = { + (Dryad_dupelim_fprint_t)0x00d4UL, + (Dryad_dupelim_fprint_t)0x00b4UL, + (Dryad_dupelim_fprint_t)0x00b1UL, + (Dryad_dupelim_fprint_t)0x00b2UL, + (Dryad_dupelim_fprint_t)0x0095UL, + (Dryad_dupelim_fprint_t)0x00afUL, + (Dryad_dupelim_fprint_t)0x00b2UL, + (Dryad_dupelim_fprint_t)0x00a6UL, + (Dryad_dupelim_fprint_t)0x0096UL +}; + +static const Dryad_dupelim_fprint_t polys16[] = { + (Dryad_dupelim_fprint_t)0x00009ee6UL, + (Dryad_dupelim_fprint_t)0x0000dfb5UL, + (Dryad_dupelim_fprint_t)0x0000e95dUL, + (Dryad_dupelim_fprint_t)0x0000ab23UL, + (Dryad_dupelim_fprint_t)0x00009566UL, + (Dryad_dupelim_fprint_t)0x0000d5e9UL, + (Dryad_dupelim_fprint_t)0x000086c1UL, + (Dryad_dupelim_fprint_t)0x000082c3UL, + (Dryad_dupelim_fprint_t)0x0000a485UL, + (Dryad_dupelim_fprint_t)0x00008b55UL, + (Dryad_dupelim_fprint_t)0x00008b4dUL, + (Dryad_dupelim_fprint_t)0x0000883fUL, + (Dryad_dupelim_fprint_t)0x0000bd5cUL, + (Dryad_dupelim_fprint_t)0x0000e87dUL, + (Dryad_dupelim_fprint_t)0x000082b4UL, + (Dryad_dupelim_fprint_t)0x0000c036UL, + (Dryad_dupelim_fprint_t)0x000097e9UL, + (Dryad_dupelim_fprint_t)0x00009e98UL, + (Dryad_dupelim_fprint_t)0x000099f9UL, + (Dryad_dupelim_fprint_t)0x0000fa93UL, +}; + +static const Dryad_dupelim_fprint_t polys32[] = { + (Dryad_dupelim_fprint_t)0x8b950699UL, + (Dryad_dupelim_fprint_t)0xf8c45e2aUL, + (Dryad_dupelim_fprint_t)0xdfdac578UL, + (Dryad_dupelim_fprint_t)0x896f6717UL, + (Dryad_dupelim_fprint_t)0xb2ab5f5dUL, + (Dryad_dupelim_fprint_t)0xece51013UL, + (Dryad_dupelim_fprint_t)0xc9ed9c7bUL, + (Dryad_dupelim_fprint_t)0xb2b28a80UL, + (Dryad_dupelim_fprint_t)0xb03c9ed2UL, + (Dryad_dupelim_fprint_t)0x85cd5087UL, + (Dryad_dupelim_fprint_t)0xcb7d544eUL, + (Dryad_dupelim_fprint_t)0xf090b664UL, + (Dryad_dupelim_fprint_t)0xfe442fe2UL, + 
(Dryad_dupelim_fprint_t)0x80a0adc0UL, + (Dryad_dupelim_fprint_t)0x9132521fUL, + (Dryad_dupelim_fprint_t)0xeca10123UL, + (Dryad_dupelim_fprint_t)0xf06b52c3UL, + (Dryad_dupelim_fprint_t)0x87b146b5UL, + (Dryad_dupelim_fprint_t)0xc6b63122UL, + (Dryad_dupelim_fprint_t)0xaa109fabUL +}; + +const Dryad_dupelim_fprint_t polys64[] = { + (((Dryad_dupelim_fprint_t)0xb40ab24eUL) << 32) | (Dryad_dupelim_fprint_t)0x49737109UL, + (((Dryad_dupelim_fprint_t)0xc0398760UL) << 32) | (Dryad_dupelim_fprint_t)0xd3108fd6UL, + (((Dryad_dupelim_fprint_t)0xd869093fUL) << 32) | (Dryad_dupelim_fprint_t)0x2ebec587UL, + (((Dryad_dupelim_fprint_t)0xa6ab08f8UL) << 32) | (Dryad_dupelim_fprint_t)0x00c128c9UL, + (((Dryad_dupelim_fprint_t)0xa629a9c4UL) << 32) | (Dryad_dupelim_fprint_t)0x60a8edfbUL, + (((Dryad_dupelim_fprint_t)0xd422e286UL) << 32) | (Dryad_dupelim_fprint_t)0x78b47614UL, + (((Dryad_dupelim_fprint_t)0x93facdf9UL) << 32) | (Dryad_dupelim_fprint_t)0xbc1363a2UL, + (((Dryad_dupelim_fprint_t)0x93caa3c5UL) << 32) | (Dryad_dupelim_fprint_t)0xdd40d768UL, + (((Dryad_dupelim_fprint_t)0xaa53204aUL) << 32) | (Dryad_dupelim_fprint_t)0x7969914eUL, + (((Dryad_dupelim_fprint_t)0xe2415fb3UL) << 32) | (Dryad_dupelim_fprint_t)0x440a16bbUL, + (((Dryad_dupelim_fprint_t)0xa05f3d02UL) << 32) | (Dryad_dupelim_fprint_t)0x95be208fUL, + (((Dryad_dupelim_fprint_t)0xb1e61188UL) << 32) | (Dryad_dupelim_fprint_t)0x6ec27c88UL, + (((Dryad_dupelim_fprint_t)0xd6d2bc63UL) << 32) | (Dryad_dupelim_fprint_t)0xc91d290eUL, + (((Dryad_dupelim_fprint_t)0xf80f25b8UL) << 32) | (Dryad_dupelim_fprint_t)0xc1930eccUL, + (((Dryad_dupelim_fprint_t)0x97dc1fd1UL) << 32) | (Dryad_dupelim_fprint_t)0x15e0e70eUL, + (((Dryad_dupelim_fprint_t)0xe17f23cdUL) << 32) | (Dryad_dupelim_fprint_t)0x55fe08aeUL, + (((Dryad_dupelim_fprint_t)0xd309c54aUL) << 32) | (Dryad_dupelim_fprint_t)0xe0d66600UL, + (((Dryad_dupelim_fprint_t)0xb55bd691UL) << 32) | (Dryad_dupelim_fprint_t)0x17e20f21UL, + (((Dryad_dupelim_fprint_t)0x9b19a5d4UL) << 32) | 
(Dryad_dupelim_fprint_t)0xd4f5ccbeUL, + (((Dryad_dupelim_fprint_t)0xcbca35d9UL) << 32) | (Dryad_dupelim_fprint_t)0xab901b9bUL, + (((Dryad_dupelim_fprint_t)0x889417edUL) << 32) | (Dryad_dupelim_fprint_t)0x965534ddUL, + (((Dryad_dupelim_fprint_t)0x8f27c100UL) << 32) | (Dryad_dupelim_fprint_t)0xbd898837UL, + (((Dryad_dupelim_fprint_t)0x930fc2d3UL) << 32) | (Dryad_dupelim_fprint_t)0x4cc207e3UL, + (((Dryad_dupelim_fprint_t)0xba0920c3UL) << 32) | (Dryad_dupelim_fprint_t)0xf1c7b364UL, + (((Dryad_dupelim_fprint_t)0x80d46b49UL) << 32) | (Dryad_dupelim_fprint_t)0xcfadf5ccUL, + (((Dryad_dupelim_fprint_t)0xb45b9d25UL) << 32) | (Dryad_dupelim_fprint_t)0x2b5d6071UL, + (((Dryad_dupelim_fprint_t)0x9fe4d82fUL) << 32) | (Dryad_dupelim_fprint_t)0x5fd432d2UL, + (((Dryad_dupelim_fprint_t)0xa97d6763UL) << 32) | (Dryad_dupelim_fprint_t)0xd5f818b3UL, + (((Dryad_dupelim_fprint_t)0xe8d6b0beUL) << 32) | (Dryad_dupelim_fprint_t)0x7c43649dUL, + (((Dryad_dupelim_fprint_t)0xbc673c33UL) << 32) | (Dryad_dupelim_fprint_t)0xfbe55129UL, + (((Dryad_dupelim_fprint_t)0xec03ce27UL) << 32) | (Dryad_dupelim_fprint_t)0xf7509ae5UL, + (((Dryad_dupelim_fprint_t)0x808401d4UL) << 32) | (Dryad_dupelim_fprint_t)0x40abf627UL, + (((Dryad_dupelim_fprint_t)0x95c51b3dUL) << 32) | (Dryad_dupelim_fprint_t)0x387ce64bUL, + (((Dryad_dupelim_fprint_t)0xa5a59bd2UL) << 32) | (Dryad_dupelim_fprint_t)0x7d3f452dUL, + (((Dryad_dupelim_fprint_t)0xe429f8beUL) << 32) | (Dryad_dupelim_fprint_t)0x22291027UL, + (((Dryad_dupelim_fprint_t)0xe4764c26UL) << 32) | (Dryad_dupelim_fprint_t)0x913308e0UL, + (((Dryad_dupelim_fprint_t)0xafd52ea1UL) << 32) | (Dryad_dupelim_fprint_t)0x35797bdaUL, + (((Dryad_dupelim_fprint_t)0xeb04bdfeUL) << 32) | (Dryad_dupelim_fprint_t)0xa0163482UL, + (((Dryad_dupelim_fprint_t)0x9e81f8b8UL) << 32) | (Dryad_dupelim_fprint_t)0xd63a6b87UL, + (((Dryad_dupelim_fprint_t)0xd320f803UL) << 32) | (Dryad_dupelim_fprint_t)0x485563aeUL, + (((Dryad_dupelim_fprint_t)0x8af88fe4UL) << 32) | (Dryad_dupelim_fprint_t)0x09983363UL, + 
(((Dryad_dupelim_fprint_t)0xd66102feUL) << 32) | (Dryad_dupelim_fprint_t)0xf6ccfe37UL, + (((Dryad_dupelim_fprint_t)0xa93e4704UL) << 32) | (Dryad_dupelim_fprint_t)0x3985cda0UL, + (((Dryad_dupelim_fprint_t)0x88bf43afUL) << 32) | (Dryad_dupelim_fprint_t)0x43565fa7UL, + (((Dryad_dupelim_fprint_t)0xbebb7241UL) << 32) | (Dryad_dupelim_fprint_t)0x360adb47UL, + (((Dryad_dupelim_fprint_t)0xd399e12dUL) << 32) | (Dryad_dupelim_fprint_t)0xea25d131UL, + (((Dryad_dupelim_fprint_t)0xd03a3d3cUL) << 32) | (Dryad_dupelim_fprint_t)0x20aa87f4UL, + (((Dryad_dupelim_fprint_t)0x8111202dUL) << 32) | (Dryad_dupelim_fprint_t)0x77c4b0a8UL, + (((Dryad_dupelim_fprint_t)0xc62d960fUL) << 32) | (Dryad_dupelim_fprint_t)0xccc5ba7fUL, + (((Dryad_dupelim_fprint_t)0x9d94edd9UL) << 32) | (Dryad_dupelim_fprint_t)0xe31c0833UL, + (((Dryad_dupelim_fprint_t)0xa926bc80UL) << 32) | (Dryad_dupelim_fprint_t)0x10d838e0UL, + (((Dryad_dupelim_fprint_t)0xf3c8b809UL) << 32) | (Dryad_dupelim_fprint_t)0x89f6395aUL, + (((Dryad_dupelim_fprint_t)0x99824e83UL) << 32) | (Dryad_dupelim_fprint_t)0xb5562fbaUL, + (((Dryad_dupelim_fprint_t)0xd87d11f3UL) << 32) | (Dryad_dupelim_fprint_t)0xa1ae7f31UL, + (((Dryad_dupelim_fprint_t)0xadb9b99eUL) << 32) | (Dryad_dupelim_fprint_t)0x5d44d4eaUL, + (((Dryad_dupelim_fprint_t)0xaef654bbUL) << 32) | (Dryad_dupelim_fprint_t)0x644fe26aUL, + (((Dryad_dupelim_fprint_t)0xcbf16d7aUL) << 32) | (Dryad_dupelim_fprint_t)0xc4a259e8UL, + (((Dryad_dupelim_fprint_t)0x8a1a38ceUL) << 32) | (Dryad_dupelim_fprint_t)0x068a8e79UL, + (((Dryad_dupelim_fprint_t)0xfc5207dcUL) << 32) | (Dryad_dupelim_fprint_t)0x711c0a9fUL, + (((Dryad_dupelim_fprint_t)0xd30ddb1bUL) << 32) | (Dryad_dupelim_fprint_t)0xa0f02884UL, + (((Dryad_dupelim_fprint_t)0xd48fc688UL) << 32) | (Dryad_dupelim_fprint_t)0x376f2998UL, + (((Dryad_dupelim_fprint_t)0xa79f0024UL) << 32) | (Dryad_dupelim_fprint_t)0xe168fb6eUL, + (((Dryad_dupelim_fprint_t)0x80709fe6UL) << 32) | (Dryad_dupelim_fprint_t)0xa7dd8d6fUL, + (((Dryad_dupelim_fprint_t)0xc8771453UL) 
<< 32) | (Dryad_dupelim_fprint_t)0xabb9e8e3UL, + (((Dryad_dupelim_fprint_t)0xc9e8268eUL) << 32) | (Dryad_dupelim_fprint_t)0xfb9fd8a3UL, + (((Dryad_dupelim_fprint_t)0xc994dbf7UL) << 32) | (Dryad_dupelim_fprint_t)0xc566278eUL, + (((Dryad_dupelim_fprint_t)0xddd80109UL) << 32) | (Dryad_dupelim_fprint_t)0xc37bd67bUL, + (((Dryad_dupelim_fprint_t)0xa9cc5534UL) << 32) | (Dryad_dupelim_fprint_t)0x8f13c673UL, + (((Dryad_dupelim_fprint_t)0xa36d7a45UL) << 32) | (Dryad_dupelim_fprint_t)0xd27bc907UL, + (((Dryad_dupelim_fprint_t)0xd7e2a78cUL) << 32) | (Dryad_dupelim_fprint_t)0x66663257UL, + (((Dryad_dupelim_fprint_t)0xdd426ee6UL) << 32) | (Dryad_dupelim_fprint_t)0x7c908039UL, + (((Dryad_dupelim_fprint_t)0xc80996c7UL) << 32) | (Dryad_dupelim_fprint_t)0x916f5fc8UL, + (((Dryad_dupelim_fprint_t)0xf9a6c515UL) << 32) | (Dryad_dupelim_fprint_t)0x3d62dc96UL, + (((Dryad_dupelim_fprint_t)0x8267aaa0UL) << 32) | (Dryad_dupelim_fprint_t)0xc80b20a6UL, + (((Dryad_dupelim_fprint_t)0xdeb59e2dUL) << 32) | (Dryad_dupelim_fprint_t)0xb3e430a8UL, + (((Dryad_dupelim_fprint_t)0xa03fa280UL) << 32) | (Dryad_dupelim_fprint_t)0x2d0318a9UL, + (((Dryad_dupelim_fprint_t)0x83b7afb5UL) << 32) | (Dryad_dupelim_fprint_t)0xc47e0dfcUL, + (((Dryad_dupelim_fprint_t)0x8752b710UL) << 32) | (Dryad_dupelim_fprint_t)0xe740bfa9UL, + (((Dryad_dupelim_fprint_t)0xa6ee843cUL) << 32) | (Dryad_dupelim_fprint_t)0x1df1006eUL, + (((Dryad_dupelim_fprint_t)0x814705bfUL) << 32) | (Dryad_dupelim_fprint_t)0x21a7a80eUL, + (((Dryad_dupelim_fprint_t)0xf3feedbaUL) << 32) | (Dryad_dupelim_fprint_t)0x611a554dUL, + (((Dryad_dupelim_fprint_t)0xdbe78addUL) << 32) | (Dryad_dupelim_fprint_t)0xf2daa748UL, + (((Dryad_dupelim_fprint_t)0x961e7a41UL) << 32) | (Dryad_dupelim_fprint_t)0x615851ccUL, + (((Dryad_dupelim_fprint_t)0xdb85afd5UL) << 32) | (Dryad_dupelim_fprint_t)0x496a1c1dUL, + (((Dryad_dupelim_fprint_t)0xbadd6e78UL) << 32) | (Dryad_dupelim_fprint_t)0x2e2ba8ceUL, + (((Dryad_dupelim_fprint_t)0xaf93ef6dUL) << 32) | 
(Dryad_dupelim_fprint_t)0x2abed356UL, + (((Dryad_dupelim_fprint_t)0xc645141aUL) << 32) | (Dryad_dupelim_fprint_t)0xd5794d6cUL, + (((Dryad_dupelim_fprint_t)0xd86e9600UL) << 32) | (Dryad_dupelim_fprint_t)0x582cb555UL, + (((Dryad_dupelim_fprint_t)0xc39d12b4UL) << 32) | (Dryad_dupelim_fprint_t)0x25fe98a3UL, + (((Dryad_dupelim_fprint_t)0x8c346762UL) << 32) | (Dryad_dupelim_fprint_t)0x9a5f7296UL, + (((Dryad_dupelim_fprint_t)0x9f373e3cUL) << 32) | (Dryad_dupelim_fprint_t)0x90100d71UL, + (((Dryad_dupelim_fprint_t)0xb00c9e7bUL) << 32) | (Dryad_dupelim_fprint_t)0x68d20287UL, + (((Dryad_dupelim_fprint_t)0x9f6f838bUL) << 32) | (Dryad_dupelim_fprint_t)0x293b2e4aUL, + (((Dryad_dupelim_fprint_t)0xcbd55e6bUL) << 32) | (Dryad_dupelim_fprint_t)0xb5990fdcUL, + (((Dryad_dupelim_fprint_t)0xc9ca494cUL) << 32) | (Dryad_dupelim_fprint_t)0x50fcc7c8UL, + (((Dryad_dupelim_fprint_t)0xe7e36ad9UL) << 32) | (Dryad_dupelim_fprint_t)0x68b357d0UL, + (((Dryad_dupelim_fprint_t)0x88f27f83UL) << 32) | (Dryad_dupelim_fprint_t)0xc0204576UL, + (((Dryad_dupelim_fprint_t)0x9b17ad6fUL) << 32) | (Dryad_dupelim_fprint_t)0x4c8a74b2UL, + (((Dryad_dupelim_fprint_t)0xe0cfbf08UL) << 32) | (Dryad_dupelim_fprint_t)0x5660db1cUL, + (((Dryad_dupelim_fprint_t)0x982f1507UL) << 32) | (Dryad_dupelim_fprint_t)0x9f214ce0UL +}; + +const UInt64 cbPolys8 =(sizeof(polys8) / sizeof(polys8[0])); +const UInt64 cbPolys16 =(sizeof(polys16) / sizeof(polys16[0])); +const UInt64 cbPolys32 =(sizeof(polys32) / sizeof(polys32[0])); +const UInt64 cbPolys64 = sizeof (polys64) / sizeof (polys64[0]); + diff --git a/DryadVertex/VertexHost/system/classlib/include/DrFunctions.h b/DryadVertex/VertexHost/system/classlib/include/DrFunctions.h new file mode 100644 index 0000000..ee40fd9 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrFunctions.h @@ -0,0 +1,119 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once +#define _ATL_CSTRING_EXPLICIT_CONSTRUCTORS // some CString constructors will be explicit + + +#include +#include + +DrError DrStringToSignedOrUnsignedInt64(const char *psz, UInt64 *pResult, bool fSigned); +DrError DrStringToFloat(const char *psz, float *pResult); +DrError DrStringToInt64(const char *psz, Int64 *pResult); +DrError DrStringToInt32(const char *psz, Int32 *pResult); +DrError DrStringToUInt64(const char *psz, UInt64 *pResult); +DrError DrStringToUInt32(const char *psz, UInt32 *pResult); +DrError DrStringToUInt16(const char *psz, UInt16 *pResult); +DrError DrStringToPortNumber(const char *psz, DrPortNumber *pResult); +DrError DrStringToUInt(const char *psz, unsigned int *pResult); +DrError DrStringToInt(const char *psz, int *pResult); +DrError DrStringToDouble(const char *psz, double *pResult); +DrError DrStringToBool(const char *psz, bool *pResult); + +// Convert string to size +// String may contain KB, MB, GB, TB, PB at the end, i.e. 12KB. It's case-insensitive. +// If suffix is present then fractions are allowed , i.e. 20.5MB +DrError DrStringToSize(PCSTR psz, UInt64* result); +DrError DrStringToIntegerSize(PCSTR psz, Int64* result); //parses negative sizes as well +DrError DrStringToSizeEx(PCSTR psz, UInt64* result, bool allowNegative); + +/* Returns a timestamp representing the current date and time (UTC time). 
Note that this time may change + suddenly due to system clock updates and may even move backwards. */ +DrTimeStamp DrGetCurrentTimeStamp(); + +// Returns the elapsed time between two Dryad timestamps. May be negative. +inline DrTimeInterval DrGetElapsedTime(DrTimeStamp tStart, DrTimeStamp tEnd) +{ + return (DrTimeInterval)(tEnd - tStart); +} + +// Reads an environment variable +// the returned value must be freed with free() +// Note that the varname and returned value are both UTF-8 encoded +DrError DrGetEnvironmentVariable(const char *pszVarName, /* out */ const char **ppszValue); + +// +// Get the environment variable using WCHARs +// +DrError DrGetEnvironmentVariable(const WCHAR *pszVarName, WCHAR ppszValue[]); + +// +// Get the SID for a user +// +DrError DrGetSidForUser(LPCWSTR domainUserName, PSID* ppSid); + +// +// Get the computer name whether it's on-premises or in Azure +// +DrError DrGetComputerName(WCHAR ppszValue[]); + +__inline DWORD DrGetTimerMsFromInterval(DrTimeInterval t) +{ + if (t == DrTimeInterval_Infinite) { + return INFINITE; + } + LogAssert (t >= DrTimeInterval_Zero && t < (DrTimeInterval_Millisecond * 0x7fffffff)); + + // round up since timers are a minimum time + return (DWORD)((t + DrTimeInterval_Millisecond - DrTimeInterval_Quantum) / DrTimeInterval_Millisecond); +} + +// Same as ExitProcess, but flushes logging and stdout/stderr first... +void DrExitProcess(UInt32 exitCode); + +// Converts a Win32 SYSTEMTIME structure, either in UTC or local time, to a Dryad timestamp +DrError DrSystemTimeToTimeStamp(const SYSTEMTIME *pSystemTime, /* out */ DrTimeStamp *pTimeStamp, bool fFromLocalTimeZone = false); + +// Generates a string to append to a timestamp string to identify a local time zone ("form "Z" or "L+7h") +DrError DrGenerateTimeZoneBiasSuffix(DrTimeInterval bias, char *szBuff, size_t nbBuff); + +// Converts a time interval string to a DrTimeInterval. +// If len is -1, it is computed with strlen. 
+// Strings must include units; e.g., "105.42s" or "12d5h10m". +DrError DrStringToTimeInterval(const char *pszString, DrTimeInterval *pTimeInterval, int len = -1); + +// Converts a Dryad timestamp to a human-readable string, either in UTC (if bias is 0) or local time (according to bias). If bias is given the +// special value DrTimeInterval_Infinite, the default local bias is used. +// nFracDig Is the number of fractional second digits to include. If -1, either 0 or 3 digits will be included depending on whether the time represents an integral second. +DrError DrTimeStampToString(DrTimeStamp timeStamp, char *pBuffer, int buffLen, DrTimeInterval bias = DrTimeInterval_Zero, Int32 nFracDig=-1); +DrError DrTimeStampToString(DrTimeStamp timeStamp, char *pBuffer, int buffLen, bool fToLocalTimeZone, Int32 nFracDig=-1); + +static const size_t k_DrTimeIntervalStringBufferSize = 64; // this size guarrantees successful completion of DrTimeIntervalToString + +// Converts a cosmos timeinterval to a human-readable string. +// The generated string may be fed back into DrStringToTimeInterval +DrError DrTimeIntervalToString(DrTimeInterval timeInterval, char *pBuffer, size_t buffLen); + +DrError DrStringToTimeStamp(const char *pszTime, DrTimeStamp *pTimeStampOut, bool fDefaultLocalTimeZone = true); +DrError DrStringToTimeStamp(const char *pszTime, DrTimeStamp *pTimeStampOut, DrTimeInterval defaultTimeZoneBias); + +// Converts a Dryad timestamp to a Win32 SYSTEMTIME structure, either in UTC or local time +DrError DrTimeStampToSystemTime(DrTimeStamp timeStamp, /* out */ SYSTEMTIME *pSystemTime, bool fToLocalTimeZone = false); diff --git a/DryadVertex/VertexHost/system/classlib/include/DrGuid.h b/DryadVertex/VertexHost/system/classlib/include/DrGuid.h new file mode 100644 index 0000000..59879e6 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrGuid.h @@ -0,0 +1,155 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include +#include "DrHash.h" + +#pragma once + +class DrStr; +//JC class DrWStr; + +extern const GUID g_DrInvalidGuid; +extern const GUID g_DrNullGuid; + +#define DrInvalidGuid ((const DrGuid &)g_DrInvalidGuid) +#define DrNullGuid ((const DrGuid &)g_DrNullGuid) + +// A class that encapsulates guids used within Dryad +// +// NOTE: This class has no constructor, so that it can be used in unions. Use +// DrInitializedGuid if you want the constructor version. +// +// This class must add no new member state since we cast from arbitrary GUIDs to this class. 
+class DrGuid : public GUID +{ +public: + DrGuid& Set(const GUID& other) + { + *(GUID *)this = other; + return *this; + } + + DrGuid& operator=(const GUID& other) + { + return Set(other); + } + + bool operator==(const GUID& other) const + { + return memcmp(this, &other, sizeof(DrGuid)) == 0; + } + + bool operator!=(const GUID& other) const + { + return memcmp(this, &other, sizeof(DrGuid)) != 0; + } + + bool operator>(const DrGuid& other) const + { + const unsigned int *pCur = (const unsigned int*) this; + const unsigned int *pOther = (const unsigned int*) &other; + + for(int i = 0; i < sizeof(DrGuid)/sizeof(int); i++) { + if(pCur[i] > pOther[i]) + return true; + else if(pCur[i] < pOther[i]) + return false; + } + + return false; // They are equal + } + + // Return a 32-bit hash for this guid + DWORD Hash() const + { + return DrHash32::Guid( this ); + } + + // Store the invalid guid as our guid + void Invalidate() + { + Set( g_DrInvalidGuid ); + } + + // Returns whether this is a valid guid + bool IsValid() const + { + return (*this != g_DrInvalidGuid) && !IsNull(); + } + + // Store the null guid as our guid + void SetToNull() + { + Set( g_DrNullGuid ); + } + + // Returns whether this is a valid guid + bool IsNull() const + { + return *this == g_DrNullGuid; + } + + // Generate a guid + void Generate(); + + // Parse a guid from a string. Acceptes guids either with or without braces + BOOL Parse(const char *string); + + // Parses guid from a string, return NULL on failure or pointer to next char after GUID in given string on success + // allowBraces - string may contain optional braces + // requireBraces - braces are required (allowBraces should be true) + // requireEOL - string should contain nothing after GUID (i.e. 
zero terminator should be present right after GUID) + const char* Parse(const char *string, bool allowBraces, bool requireBraces, bool requireEOL); + + const static size_t GuidStringLength = 39; //38 + null terminator + + // Output the guid in string form; {EFF6744C-7143-11cf-A51B-080036F12502}. If fBraces + // is false, the braces are omitted. + // String must be able to hold DrGuid::GuidStringLength (39 characters = 38 + null terminator) + char *ToString(char *string, bool fBraces=true) const; + + // appends the guid in string form; {EFF6744C-7143-11cf-A51B-080036F12502}. If fBraces + // is false, the braces are omitted. + DrStr& AppendToString(DrStr& strOut, bool fBraces=true) const; + + // appends the guid in string form; {EFF6744C-7143-11cf-A51B-080036F12502}. If fBraces + // is false, the braces are omitted. +//JC DrWStr& AppendToString(DrWStr& strOut, bool fBraces=true) const; + +protected: + // Helper function to convert a byte into 2 hex digits + static void ByteToHex(BYTE input, char *output); +}; + +class DrInitializedGuid : public DrGuid +{ +public: + DrInitializedGuid() + { + Set(g_DrInvalidGuid); + } + + DrInitializedGuid(const GUID& other) + { + Set(other); + } +}; + diff --git a/DryadVertex/VertexHost/system/classlib/include/DrHash.h b/DryadVertex/VertexHost/system/classlib/include/DrHash.h new file mode 100644 index 0000000..df8e513 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrHash.h @@ -0,0 +1,231 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// Fast, thorough hash functions returning 32 or 64 bit results. +// +// You can use hash table sizes that are a power of 2, and use & to +// trim values, for example: +// LOGSIZE=10; +// Buckets *hash_table[(1<> 32), &uHash1, &uHash2); + return uHash1 | (((UInt64)uHash2) << 32); + } + + // + // case-insensitive hash of null terminated string + // produce the same hash as Compute on an uppercased string + // + static const UInt64 StringI ( + const char *pString, + Size_t uSize, + UInt64 uSeed) + { + UInt32 uHash1, uHash2; + DrHash32::StringI2( pString, uSize, (UInt32) uSeed, (UInt32) (uSeed >> 32), &uHash1, &uHash2); + return uHash1 | (((UInt64)uHash2) << 32); + } + + // Compute a 64-bit hash of a GUID + static const UInt64 Guid ( const GUID *pGuid ) + { + UInt32 a = ((UInt32 *)pGuid)[0]; + UInt32 b = ((UInt32 *)pGuid)[1]; + UInt32 c = ((UInt32 *)pGuid)[2]; + DrHash32::Mix(a,b,c); + a ^= ((UInt32 *)pGuid)[3]; + DrHash32::Final(a,b,c); + return c | (((UInt64)b) << 32); + } +}; + +#pragma pack (pop) + +//JC} // namespace apsdk + +#ifdef USING_APSDK_NAMESPACE +using namespace apsdk; +#endif diff --git a/DryadVertex/VertexHost/system/classlib/include/DrHeap.h b/DryadVertex/VertexHost/system/classlib/include/DrHeap.h new file mode 100644 index 0000000..5f3c847 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrHeap.h @@ -0,0 +1,104 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#ifndef __DRYADHEAP_H__ +#define __DRYADHEAP_H__ + + +// Heap grows by 2*oldsize + c_heapGrowAmount +static const int c_heapGrowAmount = 100; + +class DryadHeapItem +{ +public: + // This is where the item is in the index + DWORD m_heapIndex; + + // Returns TRUE if we are supposed to be dequeued before "other" + virtual bool IsHigherPriorityThan(DryadHeapItem *other) = NULL; +}; + +// Implements a binary heap priority queue +class DryadHeap +{ +public: + DryadHeap() + { + m_entries = NULL; + m_numEntries = 0; + m_heapAllocSize = 0; + } + + // Initialize heap with the initial count of elements + void Initialize(int initialCount); + + // Dequeue the heap root; returns NULL if empty + DryadHeapItem* DequeueHeapRoot(); + + // Peek at the heap root without dequeueing; returns NULL if empty + DryadHeapItem* PeekHeapRoot(); + + // Insert a new entry into the heap + void InsertHeapEntry(DryadHeapItem *entry); + + // Remove the heap item at the given index + void RemoveHeapEntry(DWORD index); + +protected: + void DownHeapify(DWORD index); + void UpHeapify(DWORD index); + void Heapify(DWORD index); + + // Increases the size of the heap, when it needs more memory + void GrowHeap(); + + inline void HeapSwap(DWORD index1, DWORD index2); + + // Return whether an index exists + // Indices are 1-based, so 1 is the root and m_numEntries is the last entry + // e.g. 
ParentOf(1) will return FALSE + bool Exists(DWORD index) + { + return (index > 0 && index <= m_numEntries); + } + + // Note, ParentOf(root) will return root, so your code may need to check for this + static DWORD ParentOf(DWORD index) + { return ((index) >> 1); }; + + static DWORD LeftChild(DWORD index) + { return ((index) << 1); }; + + static DWORD RightChild(DWORD index) + { return (((index) << 1)+1); }; + + +public: + DryadHeapItem** m_entries; + + // Current # entries + DWORD m_numEntries; + + // # entries allocated + size_t m_heapAllocSize; +}; + + +#endif //end if not defined __DRYADHEAP_H__ diff --git a/DryadVertex/VertexHost/system/classlib/include/DrList.h b/DryadVertex/VertexHost/system/classlib/include/DrList.h new file mode 100644 index 0000000..02d75e3 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrList.h @@ -0,0 +1,606 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#ifndef __DRLIST_H__ +#define __DRLIST_H__ + +#pragma once + +// Simple Growable vector of arbitrary assignable typed values +// Insertions and deletions can be performed both at the head and at the tail of the list, making it suitable for queues. 
+// The underlying type must either be a simple scalar, a struct/class that is clonable with memcpy, or a class that implements copy constructor and assignment operator +// TODO: This version constructs all objects on reallocation, and does not destruct unused items until the list is destroyed. +// TODO: That is OK for simple types, but can waste space for types that allocate additional storage. +// TODO: to do it correctly, fix this class to use in-place contructor/destructor; construct object when it is added to list, destruct when it is removed. +template class DrValList sealed +{ +public: + DrValList(UInt32 uFirstAllocSize = 20) + { + m_nEntries = 0; + m_uFirstEntry = 0; + m_nAllocated = 0; + m_prgEntries = NULL; + m_uFirstAllocSize = uFirstAllocSize; + } + + ~DrValList() + { + if (m_prgEntries != NULL) { + delete[] m_prgEntries; + m_prgEntries = NULL; + } + m_nEntries = 0; + m_nAllocated = 0; + } + + // forces the buffer to be reallocated with the given size, even if it + // is the same as the current size. + // the requested size must be big enough to hold the valid entries. 
+ // On exit, the valid entries are always contiguous starting at offset 0 + // if n==0, frees the buffer + void ForceRealloc(UInt32 n) + { + LogAssert(n >= m_nEntries); + t *pnew = NULL; + if (n != 0) { + pnew = new t[n]; + LogAssert(pnew != NULL); + UInt32 uFrom = m_uFirstEntry; + for (UInt32 i = 0; i < m_nEntries; i++) { + pnew[i]=m_prgEntries[uFrom++]; + if (uFrom >= m_nAllocated) { + uFrom = 0; + } + } + } + if (m_prgEntries != NULL) { + delete[] m_prgEntries; + } + m_prgEntries = pnew; + m_nAllocated = n; + m_uFirstEntry = 0; + } + + // reallocates the buffer if there are not at least n elements allocated + void GrowTo(UInt32 n) + { + if (n > m_nAllocated) { + if (n < 2 * m_nAllocated) { + n = 2 * m_nAllocated; + } + if (n < m_uFirstAllocSize) { + n = m_uFirstAllocSize; + } + ForceRealloc(n); + } + } + + // Converts a potentially wrapped list (if you have moved the head) into a + // contiguous list, and returns a pointer to the first item in the contiguous list + // If possible, nothing is moved. If the list is wrapped, a new buffer is allocated (easier than moving everything in a full list) + // This operation is always cheap if you never remove from or add to the head. 
+ t *MakeContiguous() + { + if (m_uFirstEntry + m_nEntries > m_nAllocated) { + ForceRealloc(m_nEntries); + } + + return m_prgEntries + m_uFirstEntry; + } + + UInt32 NumEntries() const + { + return m_nEntries; + } + + bool IsEmpty() + { + return m_nEntries == 0; + } + + t& EntryAt(UInt32 index) + { + LogAssert(index < m_nEntries); + return m_prgEntries[NormalizeEntryIndex(index)]; + } + + const t& EntryAt(UInt32 index) const + { + LogAssert(index < m_nEntries); + return m_prgEntries[NormalizeEntryIndex(index)]; + } + + t& operator[](UInt32 index) + { + return EntryAt(index); + } + + const t& operator[](UInt32 index) const + { + return EntryAt(index); + } + + t& Head() + { + LogAssert(m_nEntries != 0); + return m_prgEntries[m_uFirstEntry]; + } + + const t& Head() const + { + LogAssert(m_nEntries != 0); + return m_prgEntries[m_uFirstEntry]; + } + + t& Tail() + { + LogAssert(m_nEntries != 0); + return EntryAt(m_nEntries-1); + } + + const t& Tail() const + { + LogAssert(m_nEntries != 0); + return EntryAt(m_nEntries-1); + } + + // LIFO-style top of stack + t& TopOfStack() + { + return Tail(); + } + + const t& TopOfStack() const + { + return Tail(); + } + + + // Invalidates all entry references previously returned + // This is the typical method used to implement Enqueue for FIFO queues, or Push for stacks + t& AddEntryToTail(const t& val) + { + GrowTo(m_nEntries+1); + t& newEntry = m_prgEntries[NormalizeEntryIndex(m_nEntries++)]; + newEntry = val; + return newEntry; + } + + // FIFO-style queueing + t& Enqueue(const t& val) + { + return AddEntryToTail(val); + } + + // LIFO-style stack push + t& Push(const t& val) + { + return AddEntryToTail(val); + } + + // Invalidates all entry references previously returned + t& AddEntryToHead(const t& val) + { + GrowTo(m_nEntries+1); + if (m_uFirstEntry == 0) { + m_uFirstEntry = m_nAllocated - 1; + } else { + m_uFirstEntry--; + } + m_nEntries++; + t& newEntry = m_prgEntries[m_uFirstEntry]; + newEntry = val; + return newEntry; + } + + 
// Invalidates all entry references previously returned + t& AddEntry(const t& val) + { + return AddEntryToTail(val); + } + + // This is the typical method used to emplement Pop for stacks + t& RemoveEntryFromTail(__out t *pValOut) + { + LogAssert(m_nEntries != 0); + *pValOut = m_prgEntries[NormalizeEntryIndex(--m_nEntries)]; + } + + t RemoveEntryFromTail() + { + LogAssert(m_nEntries != 0); + // NOTE: following code depends on not destructing the returned value until after we return. + // if we add in-place destructors, this has to change + const t& retVal = m_prgEntries[NormalizeEntryIndex(--m_nEntries)]; + return retVal; + } + + // LIFO-style stack pop + t Pop() + { + return RemoveEntryFromTail(); + } + + t& Pop(__out t *pValOut) + { + return RemoveEntryFromTail(pValOut); + } + + // This is the typical method used to emplement Pop for stacks + t& RemoveEntryFromHead(__out t *pValOut) + { + LogAssert(m_nEntries != 0); + *pValOut = m_prgEntries[m_uFirstEntry++]; + if (m_uFirstEntry >= m_nAllocated) { + m_uFirstEntry = 0; + } + m_nEntries--; + return *pValOut; + } + + t RemoveEntryFromHead() + { + LogAssert(m_nEntries != 0); + // NOTE: following code depends on not destructing the returned value until after we return. + // if we add in-place destructors, this has to change + const t& retVal = m_prgEntries[m_uFirstEntry++]; + if (m_uFirstEntry >= m_nAllocated) { + m_uFirstEntry = 0; + } + m_nEntries--; + return retVal; + } + + // FIFO-style dequeueing + t Dequeue() + { + return RemoveEntryFromHead(); + } + + t& Dequeue(__out t *pValOut) + { + return RemoveEntryFromHead(pValOut); + } + + + // does not shrink the allocated list or destruct existing entries. To do that, use ForceRealloc(0). + void Clear() + { + m_nEntries = 0; + m_uFirstEntry = 0; + } + + + // returns NULL if not in the list. 
+ t *FindVal(const t& val) + { + for (UInt32 i = 0; i < m_nEntries; i++) { + t& entry = EntryAt(i); + if (entry == val) { + return &entry; + } + } + return NULL; + } + + // returns NULL if not in the list. + const t *FindVal(const t& val) const + { + for (UInt32 i = 0; i < m_nEntries; i++) { + const t& entry = EntryAt(i); + if (entry == val) { + return &entry; + } + } + return NULL; + } + + bool ContainsVal(const t& val) const + { + return FindVal(val) != NULL; + } + + static int __cdecl InternalDrValListEntryPointerCompare(void *context, const void *p1, const void *p2) + { + if (**( const t**)p1 > **(const t**)p2) { + return 1; + } else if (**( const t**)p1 ==**(const t**)p2) { + return 0; + } + return -1; + } + + // performs a quicksort on the list. + // To sort, the base type must suport the ">" and "==" operators + // As a side-effect, truncates the allocated size to the actual size. + void Sort() + { + // We cannot sort directly with C's quicksort, since it uses memmove. + // So we will sort a pointer list, and then reallocate + if (m_nEntries != 0) + { + const t **rgpEntries = new const t *[m_nEntries]; + LogAssert(rgpEntries != NULL); + for (UInt32 i = 0; i < m_nEntries; i++) { + rgpEntries[i] = &(EntryAt(i)); + } + qsort(rgpEntries, m_nEntries, sizeof(rgpEntries[0]), InternalDrValListEntryPointerCompare); + // Generally, noone ever sorts when they plan to grow the list, so we can truncate to actual size. 
+ t *pNew = new t[m_nEntries]; + LogAssert(pNew != NULL); + for (UInt32 i = 0; i < m_nEntries; i++) + { + pNew[i] = *(rgpEntries[i]); + } + delete[] m_prgEntries; + m_prgEntries = pNew; + m_nAllocated = m_nEntries; + m_uFirstEntry = 0; + } + else + { + ForceRealloc(0); + } + } + +protected: + UInt32 NormalizeEntryIndex(UInt32 index) const + { + LogAssert(index < m_nAllocated); + if (m_uFirstEntry != 0) { + index += m_uFirstEntry; + if (index >= m_nAllocated) { + index -= m_nAllocated; + } + } + return index; + } + +private: + UInt32 m_uFirstEntry; // Normally 0, the index of the first entry (for circular buffers) + + UInt32 m_uFirstAllocSize; + UInt32 m_nEntries; + UInt32 m_nAllocated; + t *m_prgEntries; +}; + + + +// a List template that only grows :) +// You are recommended to use DrPtrList unless you are aware of the potential complicated issues behind it +// when copying/assigning values. . :) + +template +class DrList{ +private: + //basically not supported + void Merge(DrList &other){ + for(unsigned int i = 0; i +class DrInternalPtrList : public DrList{ +public: + DrInternalPtrList(unsigned int numAlloc) : DrList(numAlloc) {} + // allow pointer expressions to be passed as argument + void Push(T *ptr){ + DrList::Push(ptr); + } +}; + +// managed means the pointers give to the list are owned by the list and therefore they will be freed by the list. 
+ +template +class DrPtrList : public DrInternalPtrList{ +public: + DrPtrList(unsigned int numAlloc = 10) : DrInternalPtrList(numAlloc) {} +}; + +template +class DrUnmanagedPtrList : public DrInternalPtrList{ +public: + DrUnmanagedPtrList(unsigned int numAlloc = 10) : DrInternalPtrList(numAlloc) {} +}; + +template +class DrUnmanagedArrList : public DrInternalPtrList{ +public: + DrUnmanagedArrList(unsigned int numAlloc = 10) : DrInternalPtrList(numAlloc) {} +}; + +template +class DrManagedPtrList : public DrInternalPtrList{ +public: + DrManagedPtrList(unsigned int numAlloc = 10) : DrInternalPtrList(numAlloc) {} +protected: + virtual void FreeElements(){ + if(m_p){ + for(unsigned int i = 0; i < Size(); i ++){ + delete GetEntry(i); + } + } + } +}; + +template +class DrManagedArrList: public DrInternalPtrList{ +public: + DrManagedArrList(unsigned int numAlloc = 10) : DrInternalPtrList(numAlloc) {} +protected: + virtual void FreeElements(){ + if(m_p){ + for(unsigned int i = 0; i < Size(); i ++){ + delete[] GetEntry(i); + } + } + } +}; + +typedef DrManagedArrList DrManagedStrList; +typedef DrUnmanagedPtrList DrUnmanagedStrList; + +#endif diff --git a/DryadVertex/VertexHost/system/classlib/include/DrLogging.h b/DryadVertex/VertexHost/system/classlib/include/DrLogging.h new file mode 100644 index 0000000..5f97020 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrLogging.h @@ -0,0 +1,105 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include +#include +#include + +#define DrLogWithType(_x) DrLogHelper(_x,__FILE__,__FUNCTION__,__LINE__) + +#define DRMAKELOGTYPE(_type,_initial) \ + + +// +// Create logging calls at various levels and allow them to complete iff logging is enabled +// at the selected level +// +#define DrLogD if (DrLogging::Enabled(LogLevel_Debug)) DrLogWithType(LogLevel_Debug) +#define DrLogI if (DrLogging::Enabled(LogLevel_Info)) DrLogWithType(LogLevel_Info) +#define DrLogW if (DrLogging::Enabled(LogLevel_Warning)) DrLogWithType(LogLevel_Warning) +#define DrLogE if (DrLogging::Enabled(LogLevel_Error)) DrLogWithType(LogLevel_Error) +#define DrLogA if (DrLogging::Enabled(LogLevel_Assert)) DrLogWithType(LogLevel_Assert) + +// +// Define the logging levels +// +typedef enum +{ + LogLevel_Off = 0, + LogLevel_Assert = 1, + LogLevel_Error = 3, + LogLevel_Warning = 7, + LogLevel_Info = 15, + LogLevel_Debug = 31 +} LogLevel; + +// +// Expose functions to set the logging level, check the logging level, and flush the log +// +class DrLogging +{ +public: + static void SetLoggingLevel(LogLevel type); + static bool Enabled(LogLevel type); + static void FlushLog(); + static FILE* GetLogFile(); + +private: + static FILE* CreateLogFile(); + static FILE* m_logFile; +}; + +// +// Class that defines the logging context for a particular log call +// +class DrLogHelper +{ +public: + DrLogHelper(LogLevel type, const char* file, const char* function, int line) + { + m_type = type; + m_file = file; + m_function = function; + m_line = line; + } + + void operator()(const char* format, ...); + +private: + LogLevel m_type; + const char* m_file; + const char* m_function; + int m_line; +}; + +// +// Define a helper that logs at the assert level if any conditional fails +// +#define LogAssert(exp, ...) 
\ + do { \ + if (!(exp)) { \ + printf("Assert -- %s, %d: %s\n", __FILE__, __LINE__, #exp); \ + DrLogA(#exp, __VA_ARGS__ ); \ + } \ + } while (0) + +#define DebugLogAssert(exp, ...) LogAssert(#exp, __VA_ARGS__ ) diff --git a/DryadVertex/VertexHost/system/classlib/include/DrMemory.h b/DryadVertex/VertexHost/system/classlib/include/DrMemory.h new file mode 100644 index 0000000..f302ef3 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrMemory.h @@ -0,0 +1,362 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +class DrMemoryBuffer: public DrRefCounter +{ +private: + Size_t m_uAvailableSize; + +protected: + Size_t m_uAllocatedSize; + bool m_fIsGrowable; + + // these two are anded together to determine writability + bool m_fIsWritable; + bool m_fIsUserWritable; + + void Init() + { + m_uAllocatedSize = 0; + m_uAvailableSize = 0; + m_fIsGrowable = true; + m_fIsWritable = true; + m_fIsUserWritable = true; + }; + + DrMemoryBuffer() + { + Init(); + }; + + virtual ~DrMemoryBuffer() + { + }; + + virtual void InternalSetAvailableSize(Size_t uSize) + { + m_uAvailableSize = uSize; + } + + +/* --------------------- Functions that the actual implementation provides --------------------- */ + +public: + + // + // Retrieve pointer to the data stored in memory block at uOffset and max size available in this block + // + // Returns NULL if no data at this offset, valid pointer otherwise + // + virtual void *GetDataAddress( + Size_t uOffset, // starting offset + Size_t *puSize, // number of bytes available (0 in case of failure) + Size_t *puPriorSize // optional; size of contigious memory area prior to (*GetDataAddress()) + ) = 0; + + // + // Preallocate enough memory buffers to fix uMaxSize bytes of data. + // + // If the buffer is not growable, it is a fatal error to attempt to grow the allocated size + // beyond the current value. + virtual void IncreaseAllocatedSize( + Size_t uSize // preallocate memory blocks to fit at least uSize bytes of data + ) = 0; + + // + // Updates the available size. Grows the allocated size if necessary. 
+ // + // The default implementation just does the required updating, then updates m_uAvailableSize + // + virtual void SetAvailableSize(Size_t uSize) + { + if (uSize > m_uAllocatedSize) { + IncreaseAllocatedSize(uSize); + LogAssert(uSize <= m_uAllocatedSize); + } + InternalSetAvailableSize(uSize); + } + + +/* ------------------------------- The rest are predefined ----------------------------------------- */ + + +public: + + bool IsWritable() + { + return m_fIsWritable && m_fIsUserWritable; + } + + bool IsGrowable() + { + return m_fIsGrowable; + } + + // + // Get max size of data that may be stored in the buffer + // + Size_t GetAllocatedSize() + { + return m_uAllocatedSize; + }; + + // + // Get amount of data already stored in the buffer + // + Size_t GetAvailableSize() + { + return m_uAvailableSize; + }; + + // + // Get address and size of available memory chunk allocating more memory if necessary + // + // Note that this may return memory beyond the current available size; it is + // the caller's responsibility to SetAvailableSize() if necessary. + // + // Note also that the returned *puSize may be greater than uDataSize + // + // If the buffer is not growable, it is a fatal error to ask for data beyond the allocated size. + void *GetWriteAddress( + Size_t uOffset, // offset at which to return a write pointer + Size_t uDataSize, // minimum number of bytes to ensure is available (not necessarily contiguous) starting at the specified offset + /* out */ Size_t *puSize, // Number of contiguous bytes starting at the returned pointer + /* out */ Size_t *puPreceedingSize = NULL // Number of contiguous bytes that preceed the returned pointer + ); + + // + // This is the same as GetWriteAddress, except it is not a fatal error if the buffer is not growable grown to accomodate; in + // this case, NULL is returned. 
+ void *GetWriteAddressIfPossible( + Size_t uOffset, // offset at which to return a write pointer + Size_t uDataSize, // minimum number of bytes to ensure is available (not necessarily contiguous) starting at the specified offset + /* out */ Size_t *puSize, // Number of contiguous bytes starting at the returned pointer + /* out */ Size_t *puPreceedingSize = NULL // Number of contiguous bytes that preceed the returned pointer + ) + { + Size_t uEnd = uOffset + uDataSize; + if (!IsGrowable() && uEnd > m_uAllocatedSize) { + *puSize = 0; + if (puPreceedingSize != NULL) { + *puPreceedingSize = 0; + } + return NULL; + } + return GetWriteAddress(uOffset, uDataSize, puSize, puPreceedingSize); + } + + // + // Copy uDataSize bytes from pData array into the buffer at the + // specified offset. Grows the available data size if necessary to include the written data. + // + void Write( + Size_t uOffset, // starting offset + const void *pData, // data buffer + Size_t uDataSize // number of bytes to copy + ); + + void Append( + const void *pData, + Size_t uDataSize + ) + { + Write(GetAvailableSize(), pData, uDataSize); + } + + // + // Zero uDataSize bytes in the buffer at the + // specified offset. Grows the available data size if necessary to include the zeroed data. + // + void Zero( + Size_t uOffset, // starting offset + Size_t uDataSize // number of bytes to set to 0 + ); + + // + // Get the address and size of contiguous available readable memory area at offset uOffset + // It is a fatal error to request readable memory at or beyond the current available buffer size. + // + const void *GetReadAddress( + Size_t uOffset, + /* out */ Size_t *puSize, // Number of contiguous available bytes beginning at uOffset + /* out */ Size_t *puPreceedingSize = NULL // Number of contiguous readable bytes that preceed the returned pointer + ); + + // + // Read uDataSize bytes into pData into the buffer starting at uOffset. 
+ // + // It is a fatal error to attempt to read beyond the available size of the buffer + // + void Read( + Size_t uOffset, + void *pData, + Size_t uDataSize + ); + + // + // Compares uDataSize bytes from pData with buffer contents starting at uOffset. + // + // Return 0 on match, < 0 if contents of the buffer is less than contents of pData, > 0 otherwise + // + // It is a fatal error to attempt to read beyond the available size of the buffer + // + int Compare( + Size_t uOffset, + const void *pData, + Size_t uDataSize + ); + + // + // Copy data from one buffer to another + // + void CopyBuffer( + Size_t uDstOffset, // starting offset + DrMemoryBuffer *pSrcBuffer, // source buffer + Size_t uSrcOffset, // starting offset in source buffer + Size_t uDataSize // number of bytes to copy + ); + +}; + + +// A simple implementation of DrMemoryBuffer that is built on a single malloc'd block of memory +class DrSimpleHeapBuffer : public DrMemoryBuffer +{ +private: + BYTE *m_pData; + +public: + DrSimpleHeapBuffer(); + DrSimpleHeapBuffer(Size_t allocedSize); // pregrows allocated size + virtual ~DrSimpleHeapBuffer(); + + // Returns the underlying heap object (if any). This buffer remains + // the owner of the heap object. + // + // returns NULL if there is no underlying heap object (allocedSize = 0) + // + void *GetHeapItem() + { + return m_pData; + } + + // Detaches the underlying heap object (if any) and returns it to the caller, who + // must call free() on the memory when done with it. + // + // returns NULL if there is no underlying heap object (allocedSize = 0) + // + // After this call, the buffer is a new buffer with no data in it. + // + void *DetachHeapItem(); + + // Attaches an external heap item to the buffer. Any previous heap item is + // freed. The buffer becomes the owner of the heap item. 
+ // + void AttachHeapItem(void *pHeapItem, Size_t allocedSize, Size_t dataSize); + + // DrMemoryBuffer implementation: + +public: + + // + // Retrieve pointer to the data stored in memory block at uOffset and max size available in this block + // + // Returns NULL if no data at this offset, valid pointer otherwise + // + virtual void *GetDataAddress( + Size_t uOffset, // starting offset + Size_t *puSize, // number of bytes available (0 in case of failure) + Size_t *puPriorSize // optional; size of contiguous memory area prior to (*GetDataAddress()) + ); + + // + // Preallocate enough memory buffers to fit uSize bytes of data. + // + virtual void IncreaseAllocatedSize( + Size_t uSize // preallocate memory blocks to fit at least uSize bytes of data + ); +}; + +// A buffer that wraps a fixed contiguous block of memory. +// +// The buffer allocated size is not growable. No memory is freed when the buffer is destroyed +// +class DrFixedMemoryBuffer : public DrMemoryBuffer +{ +private: + const BYTE *m_pData; + +public: + // + // Create an empty, non-growable, non-writable buffer; call Init() to attach memory. + // + DrFixedMemoryBuffer() + { + m_pData = NULL; + m_fIsGrowable = false; + m_fIsWritable = false; + } + + virtual ~DrFixedMemoryBuffer() + { + } + + // + // Initialize fixed length buffer with byte array and sizes. + // Asserts that availableSize <= allocatedSize. Marks the buffer writable and not growable. + // NOTE(review): pData is const-qualified yet the buffer is marked writable -- confirm callers only pass memory that may be written. + // + void Init(const BYTE *pData, Size_t allocatedSize, Size_t availableSize = 0) + { + LogAssert(availableSize <= allocatedSize); + m_pData = pData; + m_uAllocatedSize = allocatedSize; + InternalSetAvailableSize(availableSize); + // A fixed buffer must never report itself as growable. 
The default constructor clears this flag, but the + // (pData, sizes) constructor only calls Init, so clear it here rather than rely on the base-class default. + m_fIsGrowable = false; + m_fIsWritable = true; + } + + // + // Create new fixed length buffer with byte array and sizes + // + DrFixedMemoryBuffer(const BYTE *pData, Size_t allocatedSize, Size_t availableSize = 0) + { + Init(pData, allocatedSize, availableSize); + } + + // DrMemoryBuffer implementation: + +public: + + // + // Retrieve pointer to the data stored in memory block at uOffset and max size available in this block + // + // Returns NULL if no data at this offset, valid pointer otherwise + // + virtual void *GetDataAddress( + 
Size_t uOffset, // starting offset + Size_t *puSize, // number of bytes available (0 in case of failure) + Size_t *puPriorSize // optional; size of contigious memory area prior to (*GetDataAddress()) + ); + + // + // Preallocate enough memory buffers to fix uMaxSize bytes of data. + // + virtual void IncreaseAllocatedSize( + Size_t uSize // preallocate memory blocks to fit at least uSize bytes of data + ); +}; + diff --git a/DryadVertex/VertexHost/system/classlib/include/DrMemoryStream.h b/DryadVertex/VertexHost/system/classlib/include/DrMemoryStream.h new file mode 100644 index 0000000..08b0e01 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrMemoryStream.h @@ -0,0 +1,2302 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include +#include + + +// Abstract class used if you use the ReadAggregate function on DrMemoryReader to parse properties +class DrMemoryReader; + +class DrPropertyParser +{ +public: + virtual DrError OnParseProperty(DrMemoryReader *reader, UInt16 enumID, UInt32 dataLen, void *cookie) = NULL; +}; + +/* Classes for marshalling data into and out of non-contiguous memory buffers */ + +#pragma warning (push) + +// These are: +// 1) Linker complaining about a couple of functions that are never called (this is ok, function may be called in the future) +// 2) size_t issues. 
Sammck is going to fix these. +#pragma warning (disable:4201) +#pragma warning (disable:4365) + +class DrMemoryStream +{ + friend class DrMemoryWriter; + friend class DrMemoryReader; + +public: + // Base-class implemented methods: + + // Sets the current status, if not already set to a failing code. Once the status + // has been changed to something other than DrError_OK, it cannot be changed. + // Returns the new current status, which may not be the same as the status passed in. + __forceinline DrError SetStatus(DrError err) + { + if (status == DrError_OK) { + status = err; + } + return status; + } + + // Returns the current status. If this is not DrError_OK, the status is persistent and + // all future attempts to manipulate this stream will fail with this status code. */ + __forceinline DrError GetStatus() + { + return status; + } + + // Returns the current physical byte offset from the beginning of the entire stream. + // For readers, this is the last set physical stream position (or 0 if it has never been set) plus the number of bytes read/parsed since + // the physical position was last set. Does not include buffered or peeked data that has not yet been permanently read. + // For writers, this is the last set physical stream position (or 0 if it has never been set) plus the number of bytes written since + // the physical position was last set, including data that had been written but has not yet been flushed. + __forceinline UInt64 GetPhysicalStreamPosition() + { + // Note: both pData and pBlockBase may be NULL, which results in uBlockBasePhysicalStreamPosition + return uBlockBasePhysicalStreamPosition + (pData - pBlockBase); + } + +public: + // Deprecated methods: + + // BUGBUG: this method used to be void, but needs to return DrError + // BUGBUG: changed return type to intentionally break subclass virtual override, and deprecated to encourage callers to + // BUGBUG: check return code. 
+ // Deprecated --Replace with CloseMemoryReader() or CloseMemoryWriter, and check return code or writer status. + __declspec(deprecated) virtual DrError Close() sealed; + +protected: + // subclass-visible methods + + // Returns the number of bytes remaining within the current contiguous block. + // For readers, this is the number of contiguous bytes to be read before the read block + // must be advanced (or end of stream is reached). Does not include prefetched/peeked data. + // For writers, this is the number of contiguous unflushed bytes which can be written before the current + // block must be flushed and a new block allocated. + __forceinline Size_t NumContiguousBytesRemaining() + { + // Note: both pData and pBlockBase may be NULL, which results in blockLength (which must be 0 in that case) + return blockLength - (pData - pBlockBase); + } + +private: + // private friend-only methods (only usable by friend classes): + + // Clears all buffer context values to initial defaults (including settting the physical stream position to 0), but does not change the current status code. + void DiscardMemoryStreamContext(); + +private: + // private methods: + // Cannot be directly instantiated; must be subclassed by a friend class + DrMemoryStream(); + + // cannot be directly destructed, must be referenced through subclass + virtual ~DrMemoryStream(); + +// TODO: hide protected members, use accessors +protected: + // sub-class visible member data + + /* The status starts at DrError_OK. at the first stream error, it is set to an + error code. After this point, the current pointer is never advanced and the + error is never reset. 
+ When the DrMemoryStream is closed, memory resources are freed and the + pointers are reset, but the status remains the same as the pre-close status + */ + // TODO: prevent subclass from changing status after it is set to an error + DrError status; + + /* The base pointer to the current contiguous block */ + BYTE *pBlockBase; + + /* The length of the current contiguous block */ + Size_t blockLength; + + /* The current read or write pointer (inside the current contiguous block) */ + BYTE *pData; + + /* The physical position of the beginning of the current contiguous block within the + overall stream. This is advanced by blockLength when we move to a new block. It + is provided to that the caller can determine physical stream position independent + of buffering behavior. */ + UInt64 uBlockBasePhysicalStreamPosition; +private: +}; + +class DrMemoryReader; + +class DrMemoryWriter : public DrMemoryStream +{ +public: + // Public subclass-overridable methods + + // Flushes output to its destination. By default, this just returns the writer status. + // This method is called by CloseMemoryWriter if the stream is not yet closed. + virtual DrError FlushMemoryWriter() + { + return status; + } + + // Default implementation just frees any abandoned temporary + // buffered write buffers, calls FlushMemoryWriter, and returns the writer status. + // If this is the first time CloseMemoryWriter() has been called, then resources are freed regardless of + // the current status, allowing this method to be used to free resources even on failed streams. + // The current status is always returned. + // Subclasses that override this method should *always* delegate to their parent class + // after freeing their own resources, and should *always* return the current status at completion. + virtual DrError CloseMemoryWriter(); + + // Special Write methods that can be optimized by subclasses to copy buffers by reference. 
The default + // implementation simply reads contiguous blocks from the buffer and writes them to the output stream. + virtual DrError WriteBytesFromBuffer(DrMemoryBuffer *pBuffer, bool fAllowCopyBufferByReference=false); + +public: + // public base class methods: + + // Returns true if CloseMemoryWriter has been called and has run all the way down to the base class + __forceinline bool MemoryWriterIsClosed() + { + return m_fMemoryWriterIsClosed; + } + + // Verifies that a certain number of bytes can be written without error (DrError_OK is not a guarantee of + // success, but anything else is a guarantee of failure and becomes a persistent failure). This method + // handles the contiguous case inline, and delegates the split case to a real function. + // + // returns DrError_EndOfStream if the output stream cannot be extended by the required length. + // + virtual DrError EnsureCanBeWritten(Size_t length) + { + if (status != DrError_OK) { + return status; + } else if (length > NumContiguousBytesRemaining() && !CrossBlockCanBeWritten(length)) { + return SetStatus(DrError_EndOfStream); + } else { + return DrError_OK; + } + } + + // This is the quick, and primary, method for stuffing bytes into the output stream. It handles + // the contiguous case inline, and delegates the cross-block case to a real function. + // Returns DrError_EndOfStream if the stream becomes full before all data can be written (partial data + // may still be written). + virtual DrError WriteBytes(const void *pBytes, Size_t length) + { + if (status == DrError_OK) { + LogAssert(pBytes != NULL || length == 0); + if (length <= NumContiguousBytesRemaining()) { + memcpy(pData, pBytes, length); + pData += length; + } else { + CrossBlockWriteBytes((const BYTE *)pBytes, length); + } + } + return status; + } + + DrError WriteBytesFromReader(DrMemoryReader *pReader, Size_t length); + + /* Append values of different types. 
Note that once an error occurs on the stream, these + methods always fail without having any effect. This allows the caller to call a sequence + of these methods and only check the return code of the last one. */ + + inline DrError WriteByte(BYTE val) + { + return WriteBytes(&val, sizeof(val)); + } + + inline DrError WriteChar(char val) + { + return WriteBytes(&val, sizeof(val)); + } + + inline DrError WriteInt8(Int8 val) + { + return WriteBytes((const BYTE *)(&val), sizeof(val)); + } + + inline DrError WriteInt16(Int16 val) + { + // WARNING: Assumes little endian + return WriteBytes(&val, sizeof(val)); + } + + inline DrError WriteInt32(Int32 val) + { + // WARNING: Assumes little endian + return WriteBytes(&val, sizeof(val)); + } + + inline DrError WriteInt64(Int64 val) + { + // WARNING: Assumes little endian + return WriteBytes(&val, sizeof(val)); + } + + // compat with netlib + inline DrError WriteUInt8(UInt8 val) + { + return WriteBytes((const BYTE *)(&val), sizeof(val)); + } + + inline DrError WriteUInt16(UInt16 val) + { + // WARNING: Assumes little endian + return WriteBytes(&val, sizeof(val)); + } + + inline DrError WriteUInt32(UInt32 val) + { + // WARNING: Assumes little endian + return WriteBytes(&val, sizeof(val)); + } + + inline DrError WriteUInt64(UInt64 val) + { + // WARNING: Assumes little endian + return WriteBytes(&val, sizeof(val)); + } + + // This will cause an assertion failure if the size is too big to fit in UInt32 + inline DrError WriteSize_tAsUInt32(Size_t val) + { + // WARNING: Assumes little endian + LogAssert((Size_t)(UInt32)val == val); + return WriteUInt32((UInt32)val); + } + + inline DrError WriteSize_tAsUInt64(Size_t val) + { + // WARNING: Assumes little endian + return WriteUInt64((UInt64)val); + } + + inline DrError WriteFloat(float val) + { + // WARNING: Assumes little endian + return WriteBytes(&val, sizeof(val)); + } + + inline DrError WriteDouble(double val) + { + // WARNING: Assumes little endian + return WriteBytes(&val, 
sizeof(val)); + } + + /* copy length bytes verbatim into the buffer. This call does not + write length into the buffer explicitly; if you want to precede + the data blob with a length, call e.g. WriteUInt32() first. */ + inline DrError WriteData(UInt32 length, const void *data) + { + return WriteBytes(data, (Size_t)length); + } + + inline DrError WriteGuid(const GUID& guid) + { + return WriteBytes(&guid, sizeof(guid)); + } + + inline DrError WriteDrError(DrError val) + { + // WARNING: Assumes little endian + return WriteBytes(&val, sizeof(val)); + } + + inline DrError WriteBool(bool val) + { + UInt8 v = val ? 1 : 0; + return WriteBytes(&v, sizeof(v)); + } + + inline DrError WriteTimeStamp(DrTimeStamp ts) + { + return WriteBytes(&ts, sizeof(ts)); + } + + inline DrError WriteTimeInterval(DrTimeInterval ti) + { + return WriteBytes(&ti, sizeof(ti)); + } + + /* Array versions of writes. Write an array of data elements in to the + buffer */ + + inline DrError WriteByteArray(UInt32 count, const BYTE *vals) + { + return WriteBytes(vals, (Size_t)count); + } + + inline DrError WriteCharArray(UInt32 count, const char *vals) + { + return WriteBytes((const BYTE *)(void *)vals, (Size_t)count); + } + + inline DrError WriteInt8Array(UInt32 count, const Int8 *vals) + { + return WriteBytes((const BYTE *)(void *)vals, (Size_t)count); + } + + inline DrError WriteInt16Array(UInt32 count, const Int16 *vals) + { + // Assumes Little Endian + return WriteBytes((const BYTE *)(void *)vals, (Size_t)count * sizeof(vals[0])); + } + + inline DrError WriteInt32Array(UInt32 count, const Int32 *vals) + { + // Assumes Little Endian + return WriteBytes((const BYTE *)(void *)vals, (Size_t)count * sizeof(vals[0])); + } + + inline DrError WriteInt64Array(UInt32 count, const Int64 *vals) + { + // Assumes Little Endian + return WriteBytes((const BYTE *)(void *)vals, (Size_t)count * sizeof(vals[0])); + } + + inline DrError WriteUInt8Array(UInt32 count, const UInt8 *vals) + { + return WriteBytes((const BYTE 
*)(void *)vals, (Size_t)count); + } + + inline DrError WriteUInt16Array(UInt32 count, const UInt16 *vals) + { + // Assumes Little Endian + return WriteBytes((const BYTE *)(void *)vals, (Size_t)count * sizeof(vals[0])); + } + + inline DrError WriteUInt32Array(UInt32 count, const UInt32 *vals) + { + // Assumes Little Endian + return WriteBytes((const BYTE *)(void *)vals, (Size_t)count * sizeof(vals[0])); + } + + inline DrError WriteUInt64Array(UInt32 count, const UInt64 *vals) + { + // Assumes Little Endian + return WriteBytes((const BYTE *)(void *)vals, (Size_t)count * sizeof(vals[0])); + } + + inline DrError WriteFloatArray(UInt32 count, const float *vals) + { + // Assumes Little Endian + return WriteBytes((const BYTE *)(void *)vals, (Size_t)count * sizeof(vals[0])); + } + + inline DrError WriteDoubleArray(UInt32 count, const double *vals) + { + // Assumes Little Endian + return WriteBytes((const BYTE *)(void *)vals, (Size_t)count * sizeof(vals[0])); + } + + inline DrError WriteGuidArray(UInt32 count, const GUID *vals) + { + // Assumes little Endian + return WriteBytes((const BYTE *)(void *)vals, (Size_t)count * sizeof(vals[0])); + } + + // WritePropertyTagXXX: These functions write the part of the + // property not incuding "data". After calling WritePropertyTagXXX, you *must* + // write exactly dataLen bytes of data. + // WritePropertyTagShort must only be used for SHORTATOM's, + // and WritePropertyTagLong must only be used for LONGATOM's. + // WritePropertyTagAnySize calls either WritePropertyTagShort or WritePropertyTagLong + // depending on the atom type. + + + // WritePropertyTagShort--write a SHORTATOM property ID and length value. + // Will cause an assertion failure if the property id is not a SHORTATOM. 
+ inline DrError WritePropertyTagShort(UInt16 enumId, UInt8 dataLen) + { + LogAssert((enumId & PropLengthMask) == PropLength_Short); + WriteUInt16(enumId); + return WriteUInt8(dataLen); + } + + // WritePropertyTagShort--write a SHORTATOM property ID and length value. + // This version takes an arbitrary Size_t dataLen + // Will cause an assertion failure if the property id is not a SHORTATOM, or if dataLen > 255. + inline DrError WritePropertyTagShort(UInt16 enumId, Size_t dataLen) + { + LogAssert((enumId & PropLengthMask) == PropLength_Short && dataLen <= (Size_t)_UI8_MAX); + WriteUInt16(enumId); + return WriteUInt8((UInt8)dataLen); + } + + // WritePropertyTagLong--write a LONGATOM property ID and length value. + // Will cause an assertion failure if the property id is not a LONGATOM, or if dataLen > _UI32_MAX. + inline DrError WritePropertyTagLong(UInt16 enumId, Size_t dataLen) + { + LogAssert((enumId & PropLengthMask) == PropLength_Long && dataLen <= (Size_t)_UI32_MAX); + WriteUInt16(enumId); + return WriteUInt32((UInt32)dataLen); + } + + // WritePropertyTagAnySize--write a SHORTATOM or LONGATOM property ID and length value. + // Will cause an assertion failure if: + // a) The property id is a SHORTATOM and dataLen > 255 + // b) The property id is a LONGATOM and dataLen > _UI32_MAX + inline DrError WritePropertyTagAnySize(UInt16 enumId, Size_t dataLen) + { + if ((enumId & PropLengthMask) == PropLength_Short) + { + WritePropertyTagShort(enumId, dataLen); + } + else + { + WritePropertyTagLong(enumId, dataLen); + } + return status; + } + + /* the following methods may only be called with enumId values that represent SHORTATOM's */ + + // WriteProperty: these functions write a property tag followed by + // the value of the implicitly specified type. + + // This method writes a LONGATOM or SHORTATOM property as an arbitrary sequence of bytes. + // + // Causes an assertion failure if: + // a) "pvData" is NULL but "dataLen" is not 0. 
+ // b) "dataLen" > 255 for SHORTATOMs + // c) "dataLen" > _UI32_MAX for LONGATOMs + inline DrError WriteAnySizeBlobProperty(UInt16 enumId, Size_t dataLen, const void *pvData) + { + WritePropertyTagAnySize(enumId, dataLen); + return WriteBytes(pvData, dataLen); + } + + // This method writes a LONGATOM property as an arbitrary sequence of bytes. + // + // Causes an assertion failure if: + // a) "pvData" is NULL but "dataLen" is not 0. + // b) enumId is not a LONGATOM + // c) "dataLen" > _UI32_MAX + inline DrError WriteLongBlobProperty(UInt16 enumId, Size_t dataLen, const void *pvData) + { + WritePropertyTagLong(enumId, dataLen); + return WriteBytes(pvData, dataLen); + } + + // This method writes a SHORTATOM property as an arbitrary sequence of bytes. + // + // Causes an assertion failure if: + // a) "pvData" is NULL but "dataLen" is not 0. + // b) enumId is not a SHORTATOM + // c) "dataLen" > 255 + inline DrError WriteShortBlobProperty(UInt16 enumId, Size_t dataLen, const void *pvData) + { + WritePropertyTagShort(enumId, dataLen); + return WriteBytes(pvData, dataLen); + } + + // Note: formerly, this only allowed short property atoms. Now it allows either long or short properties. + // For optimal efficiency, use WriteEmptyShortProperty or WriteEmptyLongProperty if known at design time. 
+ inline DrError WriteEmptyProperty(UInt16 enumId) + { + return WritePropertyTagAnySize(enumId, 0); + } + + inline DrError WriteEmptyShortProperty(UInt16 enumId) + { + return WritePropertyTagShort(enumId, (UInt8)0); + } + + inline DrError WriteEmptyLongProperty(UInt16 enumId) + { + return WritePropertyTagLong(enumId, 0); + } + + inline DrError WriteInt8Property(UInt16 enumId, Int8 value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteInt8(value); + } + + inline DrError WriteInt16Property(UInt16 enumId, Int16 value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteInt16(value); + } + + inline DrError WriteInt32Property(UInt16 enumId, Int32 value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteInt32(value); + } + + inline DrError WriteInt64Property(UInt16 enumId, Int64 value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteInt64(value); + } + + inline DrError WriteUInt8Property(UInt16 enumId, UInt8 value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteUInt8(value); + } + + inline DrError WriteUInt16Property(UInt16 enumId, UInt16 value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteUInt16(value); + } + + inline DrError WriteUInt32Property(UInt16 enumId, UInt32 value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteUInt32(value); + } + + inline DrError WriteUInt64Property(UInt16 enumId, UInt64 value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteUInt64(value); + } + + inline DrError WriteSize_tAsUInt32Property(UInt16 enumId, Size_t value) + { + WritePropertyTagShort(enumId, sizeof(UInt32)); + return WriteSize_tAsUInt32(value); + } + + inline DrError WriteSize_tAsUint64Property(UInt16 enumId, Size_t value) + { + WritePropertyTagShort(enumId, sizeof(UInt64)); + return WriteSize_tAsUInt64(value); + } + + inline DrError WriteFloatProperty(UInt16 enumId, float value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return 
WriteFloat(value); + } + + inline DrError WriteDoubleProperty(UInt16 enumId, double value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteDouble(value); + } + + inline DrError WriteGuidProperty(UInt16 enumId, const GUID& value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteGuid(value); + } + + inline DrError WriteBoolProperty(UInt16 enumId, bool value) + { + WritePropertyTagShort(enumId, sizeof(UInt8)); + return WriteBool(value); + } + + inline DrError WriteDrErrorProperty(UInt16 enumId, DrError value) + { + WritePropertyTagShort(enumId, sizeof(DrError)); + return WriteDrError(value); + } + + inline DrError WriteTimeStampProperty(UInt16 enumId, DrTimeStamp value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteTimeStamp(value); + } + + inline DrError WriteTimeIntervalProperty(UInt16 enumId, DrTimeInterval value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteTimeInterval(value); + } + + // Write a narrow (char[]) string "pstr" as a property, with an explicit provided string length. pstr may be NULL, which is + // properly encoded as a distinct value from a zero-length string. Typically, a non-NULL pstr points to a + // UTF-8 string (not-necessarily '\0'-terminated) that is "length" bytes long, without any embedded '\0' bytes; + // however, this function may be used to encode arbitrary byte blocks, including blocks that contain embedded '\0' bytes, + // so it can be used, e.g., to encode a concatenated list of '\0'-terminated strings. + // The primary distinctions between this function and WriteBlob are: + // a) NULL is a distinct value from a zero-length string. + // b) The property length is 0 for NULL values, and Llength"+1 for non-NULL values. + // b) For non-NULL values, a '\0' is appended to the string bytes as the last byte of the property data. 
This does not change + // the value of the string, but allows a reader to use the serialized data as a null terminated string without copying it, and + // allows a reader to distinguish between a NULL string and an empty string. + // + // The provided length does not include the terminating '\0' byte; however, a non-NULL pstr + // is always '\0'-terminated in the output stream, and a NULL pstr is written as an empty property. + // The actual length of the property data will be "length" + 1 if pstr is not NULL. + // + // These semantics are consistent with round-trip encoding of a DrStr value. + // + // This version only works with LONGATOM properties. This is asserted, to help detect bugs where strings that may + // occasionally exceed 254 bytes are accidentally tagged as SHORTATOM. If you have a string property that you know + // *must* always be less than 255 bytes long, you can use WriteShortStringPropertyWithLength. + // + // Causes an assertion failure if: + // a) enumId is a SHORTATOM + // b) length >= _UI32_MAX + // c) pstr == NULL && length != 0 + DrError WriteLongStringPropertyWithLength(UInt16 enumId, const char *pstr, Size_t length); + + // This version writes a SHORTATOM string value. + // + // This version only works with SHORTATOM properties. If you have a string property that you know + // *must* always be less than 255 bytes long, you can use a SHORTATOM property ID and this method. + // Otherwise, you should use WriteLongStringPropertyWithLength(). 
+ // + // Causes an assertion failure if: + // a) enumId is a LONGATOM + // b) length >= 255 + // c) pstr == NULL && length != 0 + DrError WriteShortStringPropertyWithLength(UInt16 enumId, const char *pstr, Size_t length); + + // Writes a '\0'-terminated string as a LONGATOM property + // + // Causes an assertion failure if: + // a) enumId is a SHORTATOM + // b) strlen(pstr) >= _UI32_MAX + // + // See WriteLongStringPropertyWithLength for more information. 
+ // + DrError WriteLongStringProperty(UInt16 enumId, const char *pstr) + { + size_t len = 0; + if (pstr != NULL) + { + len = strlen(pstr); + } + return WriteLongStringPropertyWithLength(enumId, pstr, len); + } + + // Writes a '\0'-terminated string as a SHORTATOM property. A NULL pstr is passed through with length 0. + // + // Causes an assertion failure if: + // a) enumId is a LONGATOM + // b) strlen(pstr) >= 255 + // + // See WriteShortStringPropertyWithLength and WriteLongStringPropertyWithLength for more information. + // + DrError WriteShortStringProperty(UInt16 enumId, const char *pstr) + { + size_t len = 0; + if (pstr != NULL) + { + len = strlen(pstr); + } + return WriteShortStringPropertyWithLength(enumId, pstr, len); + } + + // Writes a LONGATOM string property value from a DrStr. + // + // Any DrStr value, including NULL and strings with embedded '\0' bytes can be properly round-tripped. + // + // Causes an assertion failure if: + // a) enumId is a SHORTATOM + // b) length >= _UI32_MAX + // c) pstr == NULL && length != 0 + // + // See WriteLongStringPropertyWithLength for more information. + // + DrError WriteLongDrStrProperty(UInt16 enumId, const DrStr& str) + { + return WriteLongStringPropertyWithLength(enumId, str.GetString(), str.GetLength()); + } + + // Writes a SHORTATOM string property value from a DrStr. + // + // Any DrStr value, including NULL and strings with embedded '\0' bytes can be properly round-tripped. + // + // Causes an assertion failure if: + // a) enumId is a LONGATOM + // b) length >= 255 + // c) pstr == NULL && length != 0 + // + // See WriteLongStringPropertyWithLength for more information. 
+ // + DrError WriteShortDrStrProperty(UInt16 enumId, const DrStr& str) + { + return WriteShortStringPropertyWithLength(enumId, str.GetString(), str.GetLength()); + } + +#define DRDEPRECATED_UNTYPED __declspec(deprecated) + +public: + // Deprecated methods: + + // BUGBUG: this method used to be void, but needs to return DrError + // BUGBUG: changed return type to intentionally break subclass virtual override, and deprecated to encourage callers to + // BUGBUG: check return code. + // Deprecated --Replace with FlushMemoryWriter(), and check return code or writer status. + __declspec(deprecated) virtual DrError Flush() sealed + { + SetStatus(FlushMemoryWriter()); + return status; + } + + // deprecated, use WriteBlobProperty + inline DRDEPRECATED_UNTYPED DrError WriteProperty(UInt16 enumId, Size_t dataLen, const void *data) + { + WritePropertyTagAnySize(enumId, dataLen); + return WriteBytes(data, dataLen); + } + + // deprecated, use WriteEmptyShortProperty, WriteEmptyLongProperty, or WriteEmptyProperty + inline DRDEPRECATED_UNTYPED DrError WriteProperty(UInt16 enumId) + { + return WritePropertyTagAnySize(enumId, 0); + } + + // deprecated, use WriteInt8Property + inline DRDEPRECATED_UNTYPED DrError WriteProperty(UInt16 enumId, Int8 value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteInt8(value); + } + + // deprecated, use WriteInt16Property + inline DRDEPRECATED_UNTYPED DrError WriteProperty(UInt16 enumId, Int16 value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteInt16(value); + } + + // deprecated, use WriteInt32Property + inline DRDEPRECATED_UNTYPED DrError WriteProperty(UInt16 enumId, Int32 value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteInt32(value); + } + + // deprecated, use WriteInt64Property + inline DRDEPRECATED_UNTYPED DrError WriteProperty(UInt16 enumId, Int64 value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteInt64(value); + } + + // deprecated, use WriteUInt8Property + 
inline DRDEPRECATED_UNTYPED DrError WriteProperty(UInt16 enumId, UInt8 value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteUInt8(value); + } + + // deprecated, use WriteUInt16Property + inline DRDEPRECATED_UNTYPED DrError WriteProperty(UInt16 enumId, UInt16 value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteUInt16(value); + } + + // deprecated, use WriteUInt32Property + inline DRDEPRECATED_UNTYPED DrError WriteProperty(UInt16 enumId, UInt32 value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteUInt32(value); + } + + // deprecated, use WriteUInt64Property + inline DRDEPRECATED_UNTYPED DrError WriteProperty(UInt16 enumId, UInt64 value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteUInt64(value); + } + + // deprecated, use WriteFloatProperty + inline DRDEPRECATED_UNTYPED DrError WriteProperty(UInt16 enumId, float value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteFloat(value); + } + + // deprecated, use WriteDoubleProperty + inline DRDEPRECATED_UNTYPED DrError WriteProperty(UInt16 enumId, double value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteDouble(value); + } + + // deprecated, use WriteGuidProperty + inline DRDEPRECATED_UNTYPED DrError WriteProperty(UInt16 enumId, const GUID& value) + { + WritePropertyTagShort(enumId, sizeof(value)); + return WriteGuid(value); + } + +public: + // virtual destructor + + // Note: it is the most derived subclass's responsibility to call MemoryWriterDestructorClose() + // at destruct time if desired and necessary (since the virtual destructor cleanup order + // makes it too late to call the virtual CloseMemoryWriter() method. However, since an error cannot be returned + // at destruct time, this is only possible if CloseMemoryWriter() cannot fail. 
So MemoryWriterDestructorClose() + // will assert on success of CloseMemoryWriter() unless status is already failing prior to MemoryWriterDestructorClose() time, + // or MemoryWriterDestructorClose has already been called. + // If the subclass wants to inhibit a base class from attempting to close at destruct time, it should just call + // SetIgnoreMemoryWriterCloseFailureInDestructor(). + virtual ~DrMemoryWriter(); + +protected: + // Protected subclass-overridable methods + + // Clears all buffer context values to initial defaults (including settting the physical stream position to 0 + // and discarding abandandoned temporary buffered write buffers), but does not change the current status code + // or "closed" status. + // may be overridden by a subclass if it needs to free resources controlled by the + // context pointers before delegating to this method implementation. + virtual void DiscardMemoryWriterContext(); + + // This method should be overridden by memory writers that know how to allocate a new block and keep writing. + // The implementation should update pData, pBlockBase, blockLength, and uBlockBasePhysicalStreamPosition to point to the new block. + // The current block will be completely filled before calling this method, since there is no way to back up. + // After this call, the old block can be disposed of in any way the underlying implementation chooses (e.g., flushing). + // Returns SetStatus(DrError_EndOfStream) if a new block can't or shouldn't be allocated. + // The default implementation always returns DrError_EndOfStream, which is appropriate for single-block writers. + // If an error is returned, status has been set. + virtual DrError AdvanceToNextBlock(); + + // This method checks whether an attempt to write beyond the end of the current block + // can possibly succeed. The implementation should return true if the buffer is indefinitely growable. 
+ // The default implementation returns false, which is appropriate for ungrowable single-block writers. + // Note that this method must not set status, and will only be called when status is DrError_OK. + virtual bool FutureBlocksCanBeWritten(Size_t length); + +protected: + // protected base class methods + + // Note: it is the most derived subclass's responsibility to call MemoryWriterDestructorClose() + // at destruct time if desired and necessary (since the virtual destructor cleanup order + // makes it too late to call the virtual FlushMemoryWriter() method from the base class. Since an error cannot be returned + // at destruct time, this is only safe if CloseMemoryWriter() cannot fail. So MemoryWriterDestructorClose() + // will assert on success of CloseMemoryWriter() unless status is already failing prior to MemoryWriterDestructorClose() time, + // or CloseMemoryWriter() has already been called. + // If the subclass wants to inhibit the base class from crashing on close failure at destruct time, it can just call + // SetIgnoreMemoryWriterCloseFailureInDestructor(). + void MemoryWriterDestructorClose(); + +private: + // Private methods: + + // This frees all temporary buffers (e.g., incomplete/abandoned buffered writes). The base class stae remains unchanged. + void InternalFree(); + + // This method determines if a number of bytes can successfully be written, handling + // the case where the data will cross blocks. + // This method is only called with status == DrError_OK + inline bool CrossBlockCanBeWritten(Size_t length) + { + if (length <= NumContiguousBytesRemaining()) { + return true; + } else { + return FutureBlocksCanBeWritten(length - NumContiguousBytesRemaining()); + } + } + + // This method writes bytes into the buffer, handling the case where the data + // will cross blocks. 
+ DrError CrossBlockWriteBytes(const BYTE *pBytes, Size_t length); + +protected: + // cannot be directly instantiated, must be subclassed + DrMemoryWriter(); + +private: + // Private member variables + + DrRef m_pPendingBufferedWriteBuffer; // If not NULL, a buffered write is in progress, and this is the buffer + Size_t m_pendingBufferredWriteOldAvailableSize; // If a buffered write is in progress, the available size of the buffer at the beginning of the operation + + // true if CloseMemoryWriter has already been called or is currently running + bool m_fMemoryWriterIsClosed; + + // true if the subclass wants us to ignore failures in MemoryWriterDestructorClose. + bool m_fIgnoreMemoryWriterCloseFailureInDestructor; +}; + +class DrMemoryReader : public DrMemoryStream +{ +public: + // public subclass-overridable methods + + // Default implementation just frees any buffers and returns the reader status. + // If this is the first time CloseMemoryReader() has been called, then resources are freed regardless of + // the current status, allowing this method to be used to free resources even on failed streams. + // The current status is always returned. + // Unless MemoryStreamIsClosed(), subclasses that override this method should *always* delegate to their parent class + // after freeing their own resources, and should *always* return the current status at completion. + virtual DrError CloseMemoryReader(); + + // This method frees any temporary buffers that have been allocated to store returned results (such as strings) from + // selected auto-allocating ReadXXX methods (required when the result crossed non-contiguous buffer boundaries). + // After this call is made, any result pointers from these ReadXXX methods become invalid. + // The stream itself is still valid, and reading may continue. 
+    virtual void DiscardTemporaryResults();
+
+public:
+    // public base-class methods
+
+    // Returns true if CloseMemoryReader has been called and has run all the way down to the base class
+    __forceinline bool MemoryReaderIsClosed()
+    {
+        return m_fMemoryReaderIsClosed;
+    }
+
+    // Number of bytes between the start of the current block and the read pointer.
+    inline Size_t NumContiguousBytesRead()
+    {
+        return static_cast<Size_t>(pData - pBlockBase);
+    }
+
+    // Verifies that `length` bytes are available. Returns the existing status if the reader
+    // is already failing; otherwise succeeds when the request fits in the current block or
+    // CrossBlockCanBeRead() says later blocks can supply it, and sets DrError_EndOfStream if not.
+    inline DrError EnsureCanBeRead(Size_t length)
+    {
+        if (status != DrError_OK) {
+            return status;
+        }
+        // CrossBlockCanBeRead is consulted only when the current block is too short.
+        if (length <= NumContiguousBytesRemaining() || CrossBlockCanBeRead(length)) {
+            return DrError_OK;
+        }
+        return SetStatus(DrError_EndOfStream);
+    }
+
+    // Copies `length` bytes into the caller's buffer without advancing the read pointer.
+    // The in-block case is handled inline; anything longer is delegated to CrossBlockPeekBytes.
+    // Returns DrError_EndOfStream (and sets status) if the stream reaches the end before all data can be read (partial data
+    // is still written into the byte array).
+    inline DrError PeekBytes(/* out */BYTE *pBytes, Size_t length)
+    {
+        if (status != DrError_OK) {
+            return status;
+        }
+        if (length > NumContiguousBytesRemaining()) {
+            return CrossBlockPeekBytes(pBytes, length);
+        }
+        memcpy(pBytes, pData, length);
+        return DrError_OK;
+    }
+
+    // Copies `length` bytes into the caller's buffer and advances the read pointer past them.
+    // The in-block case is handled inline; anything longer is delegated to CrossBlockReadBytes.
+    // Returns DrError_EndOfStream (and sets status) if the stream reaches the end before all data can be read (partial data
+    // is still written into the byte array).
+    inline DrError ReadBytes(/* out */BYTE *pBytes, Size_t length)
+    {
+        if (status != DrError_OK) {
+            return status;
+        }
+        if (length > NumContiguousBytesRemaining()) {
+            return CrossBlockReadBytes(pBytes, length);
+        }
+        memcpy(pBytes, pData, length);
+        pData += length;
+        return DrError_OK;
+    }
+
+    // Reads bytes from this stream, and writes them into the destination stream.
+    // Note that an error reading does not result in an error status on the writer.
+    // Also, if DrError_EndOfStream is encountered, then all of the remaining data in this stream
+    // has been written to the writer (may be less than the requested number of bytes).
+    DrError ReadBytesIntoWriter(DrMemoryWriter *pWriter, Size_t length);
+
+    // Skips `length` bytes, advancing the current read pointer.
+    // The in-block case is handled inline; anything longer is delegated to CrossBlockSkipBytes.
+    // Returns DrError_EndOfStream (and sets status) if the stream reaches the end before the specified length can be skipped (partial data
+    // is still skipped).
+    inline DrError SkipBytes(Size_t length)
+    {
+        if (status != DrError_OK) {
+            return status;
+        }
+        if (length > NumContiguousBytesRemaining()) {
+            return CrossBlockSkipBytes(length);
+        }
+        pData += length;
+        return DrError_OK;
+    }
+
+    // Calls AdvanceToNextPeekBlock() until the current block has unread bytes or the
+    // reader status stops being DrError_OK. Returns the resulting status.
+    inline DrError AdvanceToNonemptyPeekBlock()
+    {
+        for (;;) {
+            if (status != DrError_OK || NumContiguousBytesRemaining() != 0) {
+                break;
+            }
+            AdvanceToNextPeekBlock();
+        }
+        return status;
+    }
+
+    // Peeks `length` bytes without advancing the current read pointer, handing back a pointer
+    // instead of copying into a caller buffer. Contract: if the peek is cross-block, the data is
+    // copied into a temporary memory space and the copy is returned; otherwise the pointer refers
+    // directly into the source block. The returned pointer is valid until
+    // this DrMemoryReader is destroyed.
+    // Handles simple case inline. Delegates cross-block cases to CrossBlockPeekBytes.
+    // On failure *ppBytes is NULL and the failing status is returned.
+ inline DrError PeekBytes(Size_t length, /* out */ const BYTE **ppBytes) + { + *ppBytes = NULL; + if (length > 0) { + AdvanceToNonemptyPeekBlock(); + } + if (status != DrError_OK) { + return status; + } else { + *ppBytes = pData; + return DrError_OK; + } + } + + // Reads data from memory, advancing the current read pointer. If the read is + // cross-block, copies the data into a temporary memory space and returns the copy; otherwise, + // returns a pointer directly into the source block. The returned pointer is valid until + // this DrMemoryReader is destroyed. + // Handles simple case inline. Delegates cross-block cases to CrossBlockReadBytes. + // Returns NULL if the stream reaches the end before all data can be read (partial data + // is still written into the byte array). + inline DrError ReadBytes(Size_t length, /* out */ const BYTE **ppBytes) + { + *ppBytes = NULL; + if (length > 0) { + AdvanceToNonemptyPeekBlock(); + } + if (status != DrError_OK) { + return status; + } else { + *ppBytes = pData; + pData += length; + return DrError_OK; + } + } + + /* read values of different types. Note that once an error is returned, the pointer does not advance + and all subsequent reads return an error. 
*/
+
+    // Each fixed-size reader below copies sizeof(value) raw bytes from the stream into *pVal
+    // through the copying ReadBytes overload. Multi-byte values are copied as stored
+    // (the original notes assume a little-endian layout).
+
+    inline DrError ReadByte(/* out */ BYTE *pVal)
+    {
+        return ReadBytes(pVal, sizeof(BYTE));
+    }
+
+    inline DrError ReadChar(/* out */ char *pVal)
+    {
+        return ReadBytes(reinterpret_cast<BYTE *>(pVal), sizeof(char));
+    }
+
+    inline DrError ReadInt8(/* out */ Int8 *pVal)
+    {
+        return ReadBytes(reinterpret_cast<BYTE *>(pVal), sizeof(Int8));
+    }
+
+    inline DrError ReadInt16(/* out */ Int16 *pVal)
+    {
+        return ReadBytes(reinterpret_cast<BYTE *>(pVal), sizeof(Int16));
+    }
+
+    inline DrError ReadInt32(/* out */ Int32 *pVal)
+    {
+        return ReadBytes(reinterpret_cast<BYTE *>(pVal), sizeof(Int32));
+    }
+
+    inline DrError ReadInt64(/* out */ Int64 *pVal)
+    {
+        return ReadBytes(reinterpret_cast<BYTE *>(pVal), sizeof(Int64));
+    }
+
+    inline DrError ReadUInt8(/* out */ UInt8 *pVal)
+    {
+        return ReadBytes(reinterpret_cast<BYTE *>(pVal), sizeof(UInt8));
+    }
+
+    inline DrError ReadUInt16(/* out */ UInt16 *pVal)
+    {
+        return ReadBytes(reinterpret_cast<BYTE *>(pVal), sizeof(UInt16));
+    }
+
+    inline DrError ReadUInt32(/* out */ UInt32 *pVal)
+    {
+        return ReadBytes(reinterpret_cast<BYTE *>(pVal), sizeof(UInt32));
+    }
+
+    inline DrError ReadUInt64(/* out */ UInt64 *pVal)
+    {
+        return ReadBytes(reinterpret_cast<BYTE *>(pVal), sizeof(UInt64));
+    }
+
+    // Reads a UInt32 and widens it into a Size_t.
+    // *pVal is always assigned: it receives the value of the local, which starts at 0.
+    inline DrError ReadUInt32ToSize_t(/* out */Size_t *pVal)
+    {
+        UInt32 narrow = 0;
+        DrError err = ReadUInt32(&narrow);
+        *pVal = static_cast<Size_t>(narrow);
+        return err;
+    }
+
+    // Reads a UInt64 into a Size_t. *pVal is zeroed first and only overwritten on success.
+    // This will return DrError_InvalidProperty if the encoded UInt64 will not fit in a Size_t.
+    inline DrError ReadUInt64ToSize_t( /* out */ Size_t *pVal)
+    {
+        *pVal = 0;
+        UInt64 wide;
+        DrError err = ReadUInt64(&wide);
+        if (err != DrError_OK) {
+            return err;
+        }
+        if (static_cast<UInt64>(static_cast<Size_t>(wide)) != wide) {
+            // too big for Size_t
+            return SetStatus(DrError_InvalidProperty);
+        }
+        *pVal = static_cast<Size_t>(wide);
+        return err;
+    }
+
+    inline DrError ReadFloat(/* out */ float *pVal)
+    {
+        return ReadBytes(reinterpret_cast<BYTE *>(pVal), sizeof(float));
+    }
+
+    inline DrError ReadDouble(/* out */ double *pVal)
+    {
+        return ReadBytes(reinterpret_cast<BYTE *>(pVal), sizeof(double));
+    }
+
+    inline DrError ReadGuid(/* out */ GUID *pVal)
+    {
+        return ReadBytes(reinterpret_cast<BYTE *>(pVal), sizeof(GUID));
+    }
+
+    // Reads one byte and stores whether it is nonzero. *pVal is written only when the read
+    // succeeds; the reader status is returned either way.
+    inline DrError ReadBool(/* out */ bool *pVal)
+    {
+        UInt8 raw;
+
+        if (ReadBytes(reinterpret_cast<BYTE *>(&raw), sizeof(raw)) == DrError_OK) {
+            *pVal = (raw != 0);
+        }
+        return status;
+    }
+
+    inline DrError ReadDrError(/* out */ DrError *pVal)
+    {
+        return ReadBytes(reinterpret_cast<BYTE *>(pVal), sizeof(DrError));
+    }
+
+    inline DrError ReadTimeStamp(/* out */ DrTimeStamp *pVal)
+    {
+        return ReadBytes(reinterpret_cast<BYTE *>(pVal), sizeof(DrTimeStamp));
+    }
+
+    inline DrError ReadTimeInterval(/* out */ DrTimeInterval *pVal)
+    {
+        return ReadBytes(reinterpret_cast<BYTE *>(pVal), sizeof(DrTimeInterval));
+    }
+
+    // Reads data into a preallocated caller-owned buffer
+    // just here for netlib naming convention
+    inline DrError ReadData(UInt32 length, /* out */ void *data)
+    {
+        return ReadBytes(static_cast<BYTE *>(data), static_cast<Size_t>(length));
+    }
+
+    // Returns a pointer into the buffer, or into a temporary copy.
+    inline DrError ReadData(UInt32 length, /* out */ const void **ppData)
+    {
+        return ReadBytes(static_cast<Size_t>(length), reinterpret_cast<const BYTE **>(ppData));
+    }
+
+    // Each fixed-size peek below copies sizeof(value) raw bytes into *pVal through the copying
+    // PeekBytes overload, leaving the read pointer where it was. Multi-byte values are copied
+    // as stored (the original notes assume a little-endian layout).
+
+    inline DrError PeekByte(/* out */ BYTE *pVal)
+    {
+        return PeekBytes(pVal, sizeof(BYTE));
+    }
+
+    inline DrError PeekInt8(/* out */ Int8 *pVal)
+    {
+        return PeekBytes(reinterpret_cast<BYTE *>(pVal), sizeof(Int8));
+    }
+
+    inline DrError PeekInt16(/* out */ Int16 *pVal)
+    {
+        return PeekBytes(reinterpret_cast<BYTE *>(pVal), sizeof(Int16));
+    }
+
+    inline DrError PeekInt32(/* out */ Int32 *pVal)
+    {
+        return PeekBytes(reinterpret_cast<BYTE *>(pVal), sizeof(Int32));
+    }
+
+    inline DrError PeekInt64(/* out */ Int64 *pVal)
+    {
+        return PeekBytes(reinterpret_cast<BYTE *>(pVal), sizeof(Int64));
+    }
+
+    inline DrError PeekUInt8(/* out */ UInt8 *pVal)
+    {
+        return PeekBytes(reinterpret_cast<BYTE *>(pVal), sizeof(UInt8));
+    }
+
+    inline DrError PeekUInt16(/* out */ UInt16 *pVal)
+    {
+        return PeekBytes(reinterpret_cast<BYTE *>(pVal), sizeof(UInt16));
+    }
+
+    inline DrError PeekUInt32(/* out */ UInt32 *pVal)
+    {
+        return PeekBytes(reinterpret_cast<BYTE *>(pVal), sizeof(UInt32));
+    }
+
+    inline DrError PeekUInt64(/* out */ UInt64 *pVal)
+    {
+        return PeekBytes(reinterpret_cast<BYTE *>(pVal), sizeof(UInt64));
+    }
+
+    inline DrError PeekFloat(/* out */ float *pVal)
+    {
+        return PeekBytes(reinterpret_cast<BYTE *>(pVal), sizeof(float));
+    }
+
+    inline DrError PeekDouble(/* out */ double *pVal)
+    {
+        return PeekBytes(reinterpret_cast<BYTE *>(pVal), sizeof(double));
+    }
+
+    inline DrError PeekGuid(/* out */ GUID *pVal)
+    {
+        return PeekBytes(reinterpret_cast<BYTE *>(pVal), sizeof(GUID));
+    }
+
+    // Peeks one byte and stores whether it is nonzero. *pVal is written only when the peek
+    // succeeds; the reader status is returned either way.
+    inline DrError PeekBool(/* out */ bool *pVal)
+    {
+        UInt8 raw;
+
+        if (PeekBytes(reinterpret_cast<BYTE *>(&raw), sizeof(raw)) == DrError_OK) {
+            *pVal = (raw != 0);
+        }
+        return status;
+    }
+
+    inline DrError PeekDrError(/* out */ DrError *pVal)
+    {
+        return PeekBytes(reinterpret_cast<BYTE *>(pVal), sizeof(DrError));
+    }
+
+    inline DrError PeekTimeStamp(/* out */ DrTimeStamp *pVal)
+    {
+        return PeekBytes(reinterpret_cast<BYTE *>(pVal), sizeof(DrTimeStamp));
+    }
+
+    inline DrError PeekTimeInterval(/* out */ DrTimeInterval *pVal)
+    {
+        return PeekBytes(reinterpret_cast<BYTE *>(pVal), sizeof(DrTimeInterval));
+    }
+
+    // Peeks data into a preallocated caller-owned buffer
+    // just here for netlib naming convention
+    inline DrError PeekData(UInt32 length, /* out */ void *data)
+    {
+        return PeekBytes(static_cast<BYTE *>(data), static_cast<Size_t>(length));
+    }
+
+    // Returns a pointer into the buffer, or into a temporary copy.
+    inline DrError PeekData(UInt32 length, /* out */ const void **ppData)
+    {
+        return PeekBytes(static_cast<Size_t>(length), reinterpret_cast<const BYTE **>(ppData));
+    }
+
+    // Each fixed-size skip below advances past sizeof(type) bytes via SkipBytes.
+
+    inline DrError SkipByte()
+    {
+        return SkipBytes(sizeof(BYTE));
+    }
+
+    inline DrError SkipInt8()
+    {
+        return SkipBytes(sizeof(Int8));
+    }
+
+    inline DrError SkipInt16()
+    {
+        return SkipBytes(sizeof(Int16));
+    }
+
+    inline DrError SkipInt32()
+    {
+        return SkipBytes(sizeof(Int32));
+    }
+
+    inline DrError SkipInt64()
+    {
+        return SkipBytes(sizeof(Int64));
+    }
+
+    inline DrError SkipUInt8()
+    {
+        return SkipBytes(sizeof(UInt8));
+    }
+
+    inline DrError SkipUInt16()
+    {
+        return SkipBytes(sizeof(UInt16));
+    }
+
+    inline DrError SkipUInt32()
+    {
+        return SkipBytes(sizeof(UInt32));
+    }
+
+    inline DrError SkipUInt64()
+    {
+        return SkipBytes(sizeof(UInt64));
+    }
+
+    inline DrError SkipFloat()
+    {
+        return SkipBytes(sizeof(float));
+    }
+
+    inline DrError SkipDouble()
+    {
+        return SkipBytes(sizeof(double));
+    }
+
+    inline DrError SkipGuid()
+    {
+        return SkipBytes(sizeof(GUID));
+    }
+
+    // A bool occupies a single UInt8 in the stream (see ReadBool / PeekBool).
+    inline DrError SkipBool()
+    {
+        return SkipBytes(sizeof(UInt8));
+    }
+
+    inline DrError SkipDrError()
+    {
+        return SkipBytes(sizeof(DrError));
+    }
+
+    inline DrError SkipTimeStamp()
+    {
+        return SkipBytes(sizeof(DrTimeStamp));
+    }
+
+    inline DrError SkipTimeInterval()
+    {
+        return SkipBytes(sizeof(DrTimeInterval));
+    }
+
+    // Skips `length` bytes of raw data.
+    inline DrError SkipData(UInt32 length)
+    {
+        return SkipBytes(static_cast<Size_t>(length));
+    }
+
+    /*
+        Property marshalling methods
+    */
+
+    // ReadNextPropertyTag. Reads the next property from the bag, along
+    // with its length (either 1- or 4-byte depending on the length
+    // bit in the property name). Returns DrError_EndOfStream if there is not enough
+    // data in the bag to read out the property name and length, but
+    // does not check that there are *pDataLen more bytes remaining.
+    DrError ReadNextPropertyTag(
+        /* out */ UInt16 *pEnumId,
+        /* out */ UInt32 *pDataLen);
+
+    // PeekNextPropertyTag. Peeks at the next property from the bag, along
+    // with its length (either 1- or 4-byte depending on the length
+    // bit in the property name). Returns DrError_EndOfStream if there is not enough
+    // data in the bag to read out the property name and length, but
+    // does not check that there are *pDataLen more bytes remaining.
+    DrError PeekNextPropertyTag(
+        /* out */ UInt16 *pEnumId,
+        /* out */ UInt32 *pDataLen);
+
+    // ReadNextProperty: Reads the next property in the bag and fills
+    // in its name and length to pEnumId and pDataLen respectively,
+    // placing the read pointer after the property value.
+    // Returns a pointer to the contiguous property value (either in
+    // the buffer or copied to make it contiguous). If ReadNextProperty returns an error,
+    // the position of the read pointer and values of *pEnumId,
+    // *pDataLen and *data are undefined.
+    inline DrError ReadNextProperty(
+        /* out */ UInt16 *pEnumId,
+        /* out */ UInt32 *pDataLen,
+        /* out */ const void **data)
+    {
+        // The value is fetched only once the tag has been read successfully.
+        if (ReadNextPropertyTag(pEnumId, pDataLen) != DrError_OK) {
+            return status;
+        }
+        ReadData(*pDataLen, data);
+        return status;
+    }
+
+    // PeekNextProperty: Peeks at the next property in the bag and fills
+    // in its name and length to pEnumId and pDataLen respectively,
+    // not advancing the read pointer.
+    // Returns a pointer to the contiguous property value (either in
+    // the buffer or copied to make it contiguous). If PeekNextProperty returns an error,
+    // the values of *pEnumId,
+    // *pDataLen and *data are undefined.
+    DrError PeekNextProperty(
+        /* out */ UInt16 *pEnumId,
+        /* out */ UInt32 *pDataLen,
+        /* out */ const void **data);
+
+    // SkipNextProperty: Skips the next property by reading its tag and then skipping the
+    // number of value bytes the tag announced. Returns the reader status.
+    inline DrError SkipNextProperty()
+    {
+        UInt16 skippedId;
+        UInt32 skippedLen;
+
+        if (ReadNextPropertyTag(&skippedId, &skippedLen) != DrError_OK) {
+            return status;
+        }
+        SkipData(skippedLen);
+        return status;
+    }
+
+    // ReadNextKnownProperty: Reads the next property which is of
+    // known ID and length into a preallocated buffer, placing the read pointer
+    // after the property value. If ReadNextKnownProperty returns error,
+    // the position of the read pointer is undefined.
+    // Returns DrError_InvalidProperty if the enum id or length don't match.
+    DrError ReadNextKnownProperty(
+        UInt16 enumId,
+        UInt32 dataLen,
+        void *pDest);
+
+    // PeekNextKnownProperty: Peeks at the next property which is of
+    // known ID and length into a preallocated buffer.
+    // Returns DrError_InvalidProperty if the enum id or length don't match.
+    DrError PeekNextKnownProperty(
+        UInt16 enumId,
+        UInt32 dataLen,
+        void *pDest);
+
+
+    // ReadNextProperty: Returns DrError_OK if the next property in the bag
+    // is targetId and is well-formed. If so, pDataLen is filled in with
+    // its length, *data points to the contiguous property value data,
+    // and the read pointer is advanced to just after the value data. If
+    // ReadNextProperty returns an error, the position of the read
+    // pointer and values of *pDataLen and *data are
+    // undefined.
+    // Returns DrError_EndOfStream if there is not a full property remaining
+    // Returns DrError_InvalidProperty if the tag doesn't match
+    inline DrError ReadNextProperty(
+        UInt16 targetId,
+        /* out */ UInt32 *pDataLen,
+        /* out */ const void **data)
+    {
+        UInt16 actualId;
+
+        const bool readOk = (ReadNextProperty(&actualId, pDataLen, data) == DrError_OK);
+        if (readOk && actualId != targetId) {
+            // The property was consumed but is not the one the caller asked for.
+            SetStatus(DrError_InvalidProperty);
+        }
+
+        return status;
+    }
+
+    // The following methods each return DrError_OK if the next
+    // property in the bag is enumId, is of the length of the relevant
+    // type, and is well-formed, in which case the property value is
+    // filled in to *pValue and the read pointer is left after the
+    // property value data. If ReadNextProperty returns an error, the
+    // position of the read pointer is undefined but *pValue is
+    // guaranteed to be unmodified.
+
+
+    // Each typed reader below forwards to ReadNextKnownProperty with the expected id, the size
+    // of the target type, and the destination pointer. Multi-byte values are copied as stored
+    // (the original notes assume a little-endian layout).
+
+    inline DrError ReadNextEmptyProperty(UInt16 enumId)
+    {
+        return ReadNextKnownProperty(enumId, 0, NULL);
+    }
+
+    inline DrError ReadNextInt8Property(UInt16 enumId, /* out */ Int8 *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError ReadNextInt16Property(UInt16 enumId, /* out */ Int16 *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError ReadNextInt32Property(UInt16 enumId, /* out */ Int32 *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError ReadNextInt64Property(UInt16 enumId, /* out */ Int64 *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError ReadNextUInt8Property(UInt16 enumId, /* out */ UInt8 *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError ReadNextUInt16Property(UInt16 enumId, /* out */ UInt16 *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError ReadNextUInt32Property(UInt16 enumId, /* out */ UInt32 *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError ReadNextUInt64Property(UInt16 enumId, /* out */ UInt64 *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    // Reads a UInt32 property and widens it into a Size_t.
+    // Note: *pValue is always assigned (from a local that starts at 0), even when the read fails.
+    inline DrError ReadNextUInt32PropertyToSize_t(UInt16 enumId, /* out */ Size_t *pValue)
+    {
+        UInt32 narrow = 0;
+        DrError err = ReadNextUInt32Property(enumId, &narrow);
+        *pValue = static_cast<Size_t>(narrow);
+        return err;
+    }
+
+    // Reads a UInt64 property into a Size_t. *pValue is zeroed first and only overwritten on success.
+    // This will return DrError_InvalidProperty if the encoded UInt64 will not fit in a Size_t.
+    inline DrError ReadNextUInt64PropertyToSize_t(UInt16 enumId, /* out */ Size_t *pValue)
+    {
+        *pValue = 0;
+        UInt64 wide;
+        DrError err = ReadNextUInt64Property(enumId, &wide);
+        if (err != DrError_OK) {
+            return err;
+        }
+        if (static_cast<UInt64>(static_cast<Size_t>(wide)) != wide) {
+            // too big for Size_t
+            return SetStatus(DrError_InvalidProperty);
+        }
+        *pValue = static_cast<Size_t>(wide);
+
+        return err;
+    }
+
+    inline DrError ReadNextFloatProperty(UInt16 enumId, /* out */ float *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError ReadNextDoubleProperty(UInt16 enumId, /* out */ double *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError ReadNextGuidProperty(UInt16 enumId, /* out */ GUID *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    // Reads a one-byte property and stores whether it is nonzero. *pValue is written only
+    // when the read succeeds; the reader status is returned either way.
+    inline DrError ReadNextBoolProperty(UInt16 enumId, /* out */ bool *pValue)
+    {
+        UInt8 raw;
+        if (ReadNextKnownProperty(enumId, sizeof(raw), &raw) == DrError_OK) {
+            *pValue = (raw != 0);
+        }
+        return status;
+    }
+
+    inline DrError ReadNextDrErrorProperty(UInt16 enumId, /* out */ DrError *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError ReadNextTimeStampProperty(UInt16 enumId, /* out */ DrTimeStamp *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError ReadNextTimeIntervalProperty(UInt16 enumId, /* out */ DrTimeInterval *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    // Reads a string property that has been encoded with WriteStringProperty.
+    // If the string in the stream is longer than maxLength (not including null), DrError_StringTooLong is returned
+    DrError ReadNextStringProperty(UInt16 enumId, /* out */ const char **ppStr, Size_t maxLength = Max_Size_t);
+
+    /* Read a string property from the buffer into a preallocated buffer.
+       If the embedded string is NULL, an empty string is returned.
+       If the string in the stream is longer than buffLength (including null), DrError_StringTooLong is returned
+    */
+    DrError ReadNextStringProperty(UInt16 enumId, char *pStr, Size_t buffLength);
+
+    // Reads a string property that has been encoded with WriteStringProperty.
+    // If the string in the stream is longer than maxLength (not including null), DrError_StringTooLong is returned
+    // NOTE(review): fAppend presumably selects appending to strOut instead of replacing its
+    // contents -- confirm against the implementation.
+    DrError ReadOrAppendNextStringProperty(bool fAppend, UInt16 enumId, /* out */ DrStr& strOut, Size_t maxLength = Max_Size_t);
+
+    // The following methods each return DrError_OK if the next
+    // property in the bag is enumId, is of the length of the relevant
+    // type, and is well-formed, in which case the property value is
+    // filled in to *pValue. If PeekNextProperty returns an error, *pValue is
+    // guaranteed to be unmodified.
+    inline DrError PeekNextEmptyProperty(UInt16 enumId)
+    {
+        return PeekNextKnownProperty(enumId, 0, NULL);
+    }
+
+    // Each typed peek below forwards to PeekNextKnownProperty with the expected id, the size
+    // of the target type, and the destination pointer. Multi-byte values are copied as stored
+    // (the original notes assume a little-endian layout).
+
+    inline DrError PeekNextInt8Property(UInt16 enumId, /* out */ Int8 *pValue)
+    {
+        return PeekNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError PeekNextInt16Property(UInt16 enumId, /* out */ Int16 *pValue)
+    {
+        return PeekNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError PeekNextInt32Property(UInt16 enumId, /* out */ Int32 *pValue)
+    {
+        return PeekNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError PeekNextInt64Property(UInt16 enumId, /* out */ Int64 *pValue)
+    {
+        return PeekNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError PeekNextUInt8Property(UInt16 enumId, /* out */ UInt8 *pValue)
+    {
+        return PeekNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError PeekNextUInt16Property(UInt16 enumId, /* out */ UInt16 *pValue)
+    {
+        return PeekNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError PeekNextUInt32Property(UInt16 enumId, /* out */ UInt32 *pValue)
+    {
+        return PeekNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError PeekNextUInt64Property(UInt16 enumId, /* out */ UInt64 *pValue)
+    {
+        return PeekNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError PeekNextFloatProperty(UInt16 enumId, /* out */ float *pValue)
+    {
+        return PeekNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError PeekNextDoubleProperty(UInt16 enumId, /* out */ double *pValue)
+    {
+        return PeekNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError PeekNextGuidProperty(UInt16 enumId, /* out */ GUID *pValue)
+    {
+        return PeekNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    // Peeks a one-byte property and stores whether it is nonzero. *pValue is written only
+    // when the peek succeeds; the reader status is returned either way.
+    inline DrError PeekNextBoolProperty(UInt16 enumId, /* out */ bool *pValue)
+    {
+        UInt8 raw;
+        if (PeekNextKnownProperty(enumId, sizeof(raw), &raw) == DrError_OK) {
+            *pValue = (raw != 0);
+        }
+        return status;
+    }
+
+    inline DrError PeekNextDrErrorProperty(UInt16 enumId, /* out */ DrError *pValue)
+    {
+        return PeekNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError PeekNextTimeStampProperty(UInt16 enumId, /* out */ DrTimeStamp *pValue)
+    {
+        return PeekNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DrError PeekNextTimeIntervalProperty(UInt16 enumId, /* out */ DrTimeInterval *pValue)
+    {
+        return PeekNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    // Consumes the (BeginTag, desiredTagType) property and closing (EndTag, desiredTagType) property, and calls you back
+    // on parser->OnParseProperty() for each decoded property. Each property it calls you back on has only been peeked,
+    // so you will need to read or skip over it. If another BeginTag appears, you will be called back with that.
+    DrError ReadAggregate(UInt16 desiredTagType, DrPropertyParser *parser, void *cookie);
+
+    // If the next property is not a BeginTag, it simply skips it.
+    // If the next property is a BeginTag, then it skips everything through and including the EndTag,
+    // and handles recursion
+    DrError SkipNextPropertyOrAggregate();
+
+public:
+    // Deprecated methods
+    //
+    // The untyped ReadNextProperty overloads below forward to ReadNextKnownProperty with the
+    // size of the pointed-to type; prefer the explicitly typed ReadNextXxxProperty methods.
+
+    // deprecated, use ReadNextEmptyProperty
+    inline DRDEPRECATED_UNTYPED DrError ReadNextProperty(UInt16 enumId)
+    {
+        return ReadNextKnownProperty(enumId, 0, NULL);
+    }
+
+    inline DRDEPRECATED_UNTYPED DrError ReadNextProperty(UInt16 enumId, /* out */ Int8 *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DRDEPRECATED_UNTYPED DrError ReadNextProperty(UInt16 enumId, /* out */ Int16 *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DRDEPRECATED_UNTYPED DrError ReadNextProperty(UInt16 enumId, /* out */ Int32 *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DRDEPRECATED_UNTYPED DrError ReadNextProperty(UInt16 enumId, /* out */ Int64 *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DRDEPRECATED_UNTYPED DrError ReadNextProperty(UInt16 enumId, /* out */ UInt8 *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DRDEPRECATED_UNTYPED DrError ReadNextProperty(UInt16 enumId, /* out */ UInt16 *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DRDEPRECATED_UNTYPED DrError ReadNextProperty(UInt16 enumId, /* out */ UInt32 *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DRDEPRECATED_UNTYPED DrError ReadNextProperty(UInt16 enumId, /* out */ UInt64 *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DRDEPRECATED_UNTYPED DrError ReadNextProperty(UInt16 enumId, /* out */ float *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DRDEPRECATED_UNTYPED DrError ReadNextProperty(UInt16 enumId, /* out */ double *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    inline DRDEPRECATED_UNTYPED DrError ReadNextProperty(UInt16 enumId, /* out */ GUID *pValue)
+    {
+        return ReadNextKnownProperty(enumId, sizeof(*pValue), pValue);
+    }
+
+    // Reads a string property that has been encoded with WriteStringProperty, replacing the
+    // contents of strOut (forwards to ReadOrAppendNextStringProperty with fAppend == false).
+    // If the string in the stream is longer than maxLength (not including null), DrError_StringTooLong is returned
+    // NOTE: may be deprecated, please use ReadNextDrStrProperty to disambiguate
+    DrError ReadNextStringProperty(UInt16 enumId, /* out */ DrStr& strOut, Size_t maxLength = Max_Size_t)
+    {
+        const bool fAppend = false;
+        return ReadOrAppendNextStringProperty(fAppend, enumId, strOut, maxLength);
+    }
+
+protected:
+    // protected base-class methods
+
+    // Must be subclassed, cannot be constructed directly:
+    DrMemoryReader();
+
+    // It is the responsibility of the subclass to deal with the
+    // case where the reader may not have been closed at destruct time, and clean up without crashing.
+    // Calling MemoryReaderDestructorClose is usually a good way to do that
+    void MemoryReaderDestructorClose();
+
+    // Pass true if the underlying data is immutable and refcounted, so that a DrMemoryBuffer can be wrapped around a portion of it.
+    // Stores the flag in fAllowCopyBufferByReference.
+    __forceinline void SetMemoryReaderAllowsCopyByReference(bool fMemoryReaderAllowsCopyByReference)
+    {
+        fAllowCopyBufferByReference = fMemoryReaderAllowsCopyByReference;
+    }
+
+public:
+    // virtual destructor:
+
+    // It is the responsibility of the subclass to deal with the
+    // case where the reader may not have been closed at destruct time, and clean up without crashing.
+    // Calling MemoryReaderDestructorClose is usually a good way to do that
+    virtual ~DrMemoryReader();
+
+protected:
+    // Protected subclass-overridable methods
+
+    // Clears all buffer context values to initial defaults (including setting the physical stream position to 0
+    // and discarding peekahead buffers), but does not discard temporary results or change the current status code
+    // or "closed" status.
+    // May be overridden by a subclass if it needs to free resources controlled by the
+    // context pointers before delegating to this method implementation.
+    virtual void DiscardMemoryReaderContext();
+
+    // This method clears all buffer context values to initial defaults (including setting the physical stream position to 0,
+    // discarding peekahead buffers and temporary results, clearing the "closed" state, and setting the status to DrError_OK).
+    // A subclass may override this to reset its own state along with forwarding the request to this class.
+    // Calls through the virtual DiscardMemoryReaderContext() before clearing the status and the closed flag.
+    virtual void ResetMemoryReader();
+
+    // This method should be overridden by memory readers that know how to advance to a new block.
+    // Returns DrError_EndOfStream if there are no more blocks to be read.
+    // The default implementation always returns DrError_EndOfStream, which is appropriate for
+    // single-block readers.
+    virtual DrError ReadNextBlock(/* out */ const BYTE **ppBytes, /* out */ Size_t *pLength);
+
+    // Reads data from blocks starting *after* the current block, without advancing the current read pointer.
+    // Returns DrError_EndOfStream if the stream reaches the end before all data can be read (partial data
+    // may still be written into the byte array).
+    // Returns DrError_PeekTooFar if the underlying implementation cannot peek that far forward (in this
+    // case, we will directly read the data and append it to a cached peek list).
+ virtual DrError FutureBlockPeekBytes(/* out */ void *pBytes, Size_t length); + + // This method should be overridden by memory readers that know how to advance to a new block. + // The implementation should return true if there are at least "length" readable bytes + // following the current block. + // + // The default implementation always returns false, which is appropriate for single-block readers + virtual bool FutureBlocksCanBeRead(Size_t length); + +private: + // private methods + + // called by close + destructor. Discards temporary results and peek blocks. + void InternalFree(); + + // Discards any lookahead peek blocks that have been allocated. + // This may invalidate the current block, so the current block should always be + // reinitialized after calling this. + void DiscardPeekBlocks(); + + // This method determines if a number of bytes can successfully be read, handling + // the case where the data will cross blocks. + inline bool CrossBlockCanBeRead(Size_t length) + { + Size_t nr = NumContiguousBytesRemaining(); + if (length <= nr) { + return true; + } else { + return FutureBlocksCanBeRead(length - nr); + } + } + + void AllocTempMemBlock(Size_t minLength); + + // Reserves a block of temporary memory that will be valid until this DrMemoryReader + // is destroyed. + BYTE *ReserveTempMemory(Size_t length); + + // Reads data from memory without advancing the current read pointer. + // Handles cross-block cases. + // Returns DrError_EndOfStream if the stream reaches the end before all data can be read (partial data + // is still wriiten into the byte array). + DrError CrossBlockPeekBytes(/* out */ BYTE *pBytes, Size_t length); + + // Reads data from memory, advancing the current read pointer. + // Handles cross-block cases. + // Returns DrError_EndOfStream if the stream reaches the end before all data can be read (partial data + // is still wriiten into the byte array). 
+ DrError CrossBlockReadBytes(/* out */ BYTE *pBytes, Size_t length); + + // Skips data in memory, advancing the current read pointer. + // Handles cross-block cases. + // Returns DrError_EndOfStream if the stream reaches the end before all data can be skipped (partial data + // is still skipped). + DrError CrossBlockSkipBytes(Size_t length); + + // Appends the next block from the underlying stream to the list of peekable data blocks. + DrError AppendNextBlock(); + + // Advances the current block to the next available peek block, reading a new block if necessary + DrError AdvanceToNextPeekBlock(); + +private: + // Private type definitions + + // TempMemHeader allows us to maintain a stack of temporarily allocated + // return data (e.g., strings) that is cleaned up when the DrMemoryReader + // is destroyed. + class TempMemHeader + { + private: + TempMemHeader *pNext; + BYTE *pData; + Size_t length; + + // We override "new" to allocate the header and the content + // in a single allocation. + inline void *operator new(Size_t headersize, Size_t blocksize) + { + LogAssert(headersize == sizeof(TempMemHeader)); + LogAssert(headersize + blocksize >= headersize); // keep prefast happy + void *p = malloc(headersize + blocksize); + LogAssert(p != NULL); + return p; + } + + inline TempMemHeader(Size_t blocksize, TempMemHeader *pOldHead) + { + pData = ((BYTE *)(void *)this) + sizeof(*this); + length = blocksize; + pNext = pOldHead; + } + + public: + // We have to provide a matching delete... 
+ inline void operator delete(void *pMem, Size_t blocksize) + { + (void)blocksize; + free(pMem); + } + + inline static TempMemHeader *Alloc(Size_t blocksize, TempMemHeader *pOldHead) + { + TempMemHeader *p = new(blocksize) TempMemHeader(blocksize, pOldHead); + return p; + } + + inline TempMemHeader *Detach() + { + TempMemHeader *p = pNext; + pNext = NULL; + return p; + } + + inline ~TempMemHeader() + { + LogAssert(pNext == NULL); + } + + inline BYTE *GetData() + { + return pData; + } + + inline Size_t GetLength() + { + return length; + } + + inline BYTE *ReserveData(Size_t len) + { + LogAssert(len <= length); + BYTE *pd = pData; + pData += len; + length -= len; + return pd; + } + + }; + + // PeekMemHeader allows us to maintain a queue of lookahead data + // for streams that don't support peek + class PeekMemHeader + { + private: + PeekMemHeader *pNext; + BYTE *pData; + Size_t length; + bool fOwnMemory; // True if the block at pData is owned by this object and should be deleted + + public: + // If pBytes is not null, we don't own the memory block. If pBytes is NULL, a block is allocated that we own. 
+ inline PeekMemHeader(Size_t blocksize, PeekMemHeader *pOldTail, const void *pBytes = NULL) + { + fOwnMemory = (pBytes == NULL); + if (fOwnMemory) { + #pragma prefast(disable:419, "Don't need to check blocksize") + pData = (BYTE *)malloc(blocksize); + #pragma prefast(enable:419, "End prefast suppression") + LogAssert(pData != NULL); + } else { + pData = (BYTE *)(void *)pBytes; + } + length = blocksize; + pNext = NULL; + if (pOldTail != NULL) { + pOldTail->pNext = this; + } + } + + inline void EnsureIsOwned() + { + if (!fOwnMemory) { + BYTE *pNew = (BYTE *)malloc(length); + LogAssert(pNew != NULL); + memcpy(pNew, pData, length); + pData = pNew; + fOwnMemory = true; + } + } + + inline PeekMemHeader *Detach() + { + PeekMemHeader *p = pNext; + pNext = NULL; + return p; + } + + inline ~PeekMemHeader() + { + LogAssert(pNext == NULL); + if (fOwnMemory) { + free(pData); + } + } + + inline BYTE *GetData() + { + return pData; + } + + inline Size_t GetLength() + { + return length; + } + + inline PeekMemHeader *GetNext() + { + return pNext; + } + + }; + + // End private type definitions + +private: + // private member data + static const Size_t DEFAULT_TEMP_MEM_ALLOC_SIZE = 16384; + + TempMemHeader *pFirstTempMemHeader; + PeekMemHeader *pFirstPeekMemHeader; + PeekMemHeader *pLastPeekMemHeader; + + // true if CloseMemoryReader has already been called or is currently running + bool m_fMemoryReaderIsClosed; + + // true if the underlying data is immutable and refcounted, so that a DrMemoryBuffer can be wrapped around a portion of it. + bool fAllowCopyBufferByReference; +}; + + +// A simple writer that can only write into a fixed-size, ungrowable contiguous block of memory. +class DrSingleBlockWriter : public DrMemoryWriter +{ +public: + inline DrSingleBlockWriter(void *pBytes, Size_t length) + { + pData = pBlockBase = (BYTE *)pBytes; + blockLength = length; + } +}; + +// A simple reader that can only read from a single contiguous block of memory. 
+class DrSingleBlockReader : public DrMemoryReader +{ +public: + inline DrSingleBlockReader(const void *pBytes, Size_t length) + { + pData = pBlockBase = (BYTE *)(void *)pBytes; + blockLength = length; + } +}; + +// A simple writer that can write into a potentially growable DrMemoryBuffer +class DrMemoryBufferWriter : public DrMemoryWriter +{ +private: + DrRef m_pBuffer; // The discontiguous memory buffer we are writing to + bool m_fTruncateOnFlush; // True if the memory buffer we are writing to should be truncated to our current write position when we flush + +public: + inline DrMemoryBufferWriter(DrMemoryBuffer *pMemoryBuffer, bool fTruncateOnFlush = true) + { + m_pBuffer = pMemoryBuffer; + m_fTruncateOnFlush = fTruncateOnFlush; + } + + inline DrMemoryBufferWriter() + { + m_fTruncateOnFlush = true; + } + + // Set the current writing position in the buffer. Increases the current available size in the buffer if necessary. + // The physical stream position is adjusted to be consistent relative to the current "stream origin", if possible. + // Since the default physical stream position at the beginning of the buffer is 0, this means that the physical stream + // position by default becomes equal to the "offset" parameter. + // If the change would cause the current physical stream position to become negative, + // then the current physical stream position is set to 0. This could happen if, e.g., you wer at + // buffer offset 50, then you explicitly set the physical stream position to 0, then you set the buffer offset back to 0. + DrError SetBufferOffset(Size_t offset); + + Size_t GetBufferOffset() + { + return (Size_t)GetPhysicalStreamPosition(); + } + + virtual DrError FlushMemoryWriter(); + + // NOTE: after you call this method, the internal refcount to the buffer will be released. 
+ virtual DrError CloseMemoryWriter(); + + virtual ~DrMemoryBufferWriter(); + +protected: + virtual DrError AdvanceToNextBlock(); + + virtual bool FutureBlocksCanBeWritten(Size_t length); + +private: + void InternalFree(); +}; + +// a simple reader that can read from a refcounted DrMemoryBuffer +class DrMemoryBufferReader : public DrMemoryReader +{ +private: + DrRef m_pBuffer; // The discontiguous memory buffer we are reading from (refcounted) + Size_t nextReadOffset; // The offset within the DrMemoryBuffer that should be used for the next call to ReadNextBlock + +public: + inline DrMemoryBufferReader(DrMemoryBuffer *pMemoryBuffer, bool fAllowCopyByReference = false) + { + SetMemoryReaderAllowsCopyByReference(fAllowCopyByReference); + nextReadOffset = 0; + m_pBuffer = pMemoryBuffer; + } + + // If true, references to the provided buffer may be handled out to the reading app rather + // than making copies. + void SetAllowCopyByReference(bool fAllow=true) + { + SetMemoryReaderAllowsCopyByReference(fAllow); + } + + // Set the current reading position in the buffer. Has no effect if there is an error status + DrError SetBufferOffset(Size_t offset); + + Size_t GetBufferOffset() + { + return (Size_t)uBlockBasePhysicalStreamPosition + NumContiguousBytesRead(); + } + + // returns the total number of bytes remaining to be read in the buffer + Size_t GetTotalAvailableBufferRemaining() + { + if (status == DrError_OK) { + Size_t available = m_pBuffer->GetAvailableSize(); + Size_t used = GetBufferOffset(); + LogAssert(available >= used); + return available - used; + } else { + return 0; + } + } + + // Default implementation just frees any temporary buffers and returns the reader status. + // If this is the first time CloseMemoryReader() has been called, then resources are freed regardless of + // the current status, allowing this method to be used to free resources even on failed streams. + // The current status is always returned. 
+ // Unless MemoryStreamIsClosed(), subclasses that override this method should *always* delegate to their parent class + // after freeing their own resources, and should *always* return the current status at completion. + // NOTE: this method frees the reference to the buffer that is being read from + virtual DrError CloseMemoryReader(); + + virtual ~DrMemoryBufferReader(); + +protected: + virtual DrError ReadNextBlock(/* out */ const BYTE **ppBytes, /* out */ Size_t *pLength); + + virtual DrError FutureBlockPeekBytes(/* out */ void *pBytes, Size_t length); + virtual bool FutureBlocksCanBeRead(Size_t length); + +private: + void InternalFree(); + +}; + + + +#pragma warning (pop) diff --git a/DryadVertex/VertexHost/system/classlib/include/DrNodeAddress.h b/DryadVertex/VertexHost/system/classlib/include/DrNodeAddress.h new file mode 100644 index 0000000..6ba6092 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrNodeAddress.h @@ -0,0 +1,897 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#ifndef __DRYADNODEADDRESS_H__ +#define __DRYADNODEADDRESS_H__ + +#pragma once + +#pragma prefast(push) +#pragma prefast(disable:24002, "struct sockaddr not ipv6 compatible") + +class DrMemoryReader; +class DrMemoryWriter; +class DrLastAccessTable; + +extern DrLastAccessTable *g_pDrLastAccessTable; + +void DrInitLastAccessTable(); + +/* + * Dryad Node Address + * + * Used to represent the address of a cosmos node (EN, DRM, etc.) on the network + */ + + +static const int k_MaxHostNameLength = 255; +static const int k_MaxIpAddressesPerHostName = 16; + +class DrNodeAddress +{ +public: + + //Default c'tor. Nulls out both the internet address and port + inline DrNodeAddress(); + + inline DrNodeAddress( const char *pszName, DrPortNumber defaultPort ) + { + Set( pszName, defaultPort ); + } + + //Copy c'tor. + inline DrNodeAddress(const DrNodeAddress& addr); + + inline DrNodeAddress& operator=(const DrNodeAddress& addr); + + //Create address with specified internet address and port + inline DrNodeAddress(const IN_ADDR& ina, DrPortNumber wPort); + + // sets using an IP address in network byte order + inline void Set(const IN_ADDR& ina, DrPortNumber wPort); + + // returns false if not a valid ip/node address + bool Set(const struct sockaddr *pSockAddr, Size_t addrLen) + { + return Set((const struct sockaddr_in *)(const void *)pSockAddr, addrLen); + } + + // returns false if not a valid ip/node address + bool Set(const struct sockaddr_in *pSockAddr, Size_t addrLen = sizeof(struct sockaddr_in)); + + // sets using an IP address in host byte order + inline void Set(DrIpAddress ipAddr, DrPortNumber wPort); + + inline void Clear() + { + Set(DrInvalidIpAddress, DrInvalidPortNumber); + } + + inline bool operator==(const DrNodeAddress& other) const + { + return (m_ina.S_un.S_addr == other.m_ina.S_un.S_addr && m_wPort == other.m_wPort); + } + + inline bool operator!=(const DrNodeAddress& other) const + { + return (m_ina.S_un.S_addr != other.m_ina.S_un.S_addr || m_wPort != 
other.m_wPort); + } + + inline void SetLocal(DrPortNumber wPort) + { + Set(DrLocalIpAddress, wPort); + } + + void GetSockAddr(struct sockaddr_in *pAddr, Size_t len = sizeof(struct sockaddr_in)) const + { + LogAssert (len >= sizeof(struct sockaddr_in)); + memset(pAddr, 0, len); + pAddr->sin_family = AF_INET; + pAddr->sin_port = htons(m_wPort); + pAddr->sin_addr = m_ina; + } + + + // Gets the IP address in host byte order + inline DrIpAddress GetIpAddress() const + { + return ntohl(m_ina.S_un.S_addr); + } + + // Gets the IP address in network byte order + inline UInt32 GetIpAddressNetworkOrder() const + { + return m_ina.S_un.S_addr; + } + + // Gets the IP address in network byte order + const IN_ADDR& GetInAddr() const + { + return m_ina; + } + + // Gets the port number in host byte order + inline DrPortNumber GetPort() const + { + return m_wPort; + } + + // Sets the port number from a host byte ordered value + inline void SetPort(DrPortNumber port) + { + m_wPort = port; + } + + inline bool IsNull() const + { + return m_ina.S_un.S_addr == 0 && m_wPort == 0; + } + + inline void SetToNull() + { + m_ina.S_un.S_addr=0; + m_wPort=0; + } + + // Looks up a host name using DNS + // Note that this is a blocking request + // Returns up to addressBuffLen entries. If there are more entries than this, the list is truncated without error. + // Returns DrError_HostNotFound if no hosts match the name. + static DrError LookupHostName( + const char *pszHostName, + /* out */ DrIpAddress *pAddressBuff, + UInt32 addressBuffLen, + /* out */ UInt32 *pNumReturnedAddresses); + + // Parses a name in the form "#.#.#.#:port" or "dns-name:port" and splits out the host name and port. + // If ":port" is missing, uses the default port. + // Returns DrError_InvalidParameter if the string is malformed. 
+ static DrError ParseHostPortName(/* out */ char *pHostNameBuffer, Size_t buffLen, /*out */ DrPortNumber *pPort, const char *pszName, DrPortNumber defaultPort = DrAnyPortNumber, UInt32 *pInstanceNumOut = NULL); + + // Parses a name in the form "#.#.#.#:port" or "dns-name:port" and splits out the host name and port. + // If ":port" is missing, uses the default port. + // strOut is replaced with the parsed host name + // Returns DrError_InvalidParameter if the string is malformed. + static DrError ParseHostPortName(DrStr& strOut, /*out */ DrPortNumber *pPort, const char *pszName, DrPortNumber defaultPort = DrAnyPortNumber, UInt32 *pInstanceNumOut = NULL); + + // Parses a stringified IP address in the form "#.#.#.#" into a host-order IP address. + // Returns DrError_InvalidParameter if the string is malformed. + static DrError ParseIpAddress(const char *pszIpAddress, /* out */ DrIpAddress *pIpAddr); + + // Parses a name in the form "#.#.#.#:port" or "dns-name:port" and resolves it to an address. + // If ":port" is missing, uses the default port. + // If there is more than one address associated with a DNS name, uses the first one. + // Returns DrError_InvalidParameter if the string is malformed. + // Note that this method may block for DNS resolution if a DNS name is used. + DrError Set(const char *pszName, DrPortNumber defaultPort=DrAnyPortNumber); + + // Converts the contained IP/port address to a string of the form "#.#.#.#:port". If the contained port number matches defaultPort, the + // port number is not included in the string. + // buffSize must be at least 22 or DrError_StringTooLong is returned. 
+ DrError ToAddressPortString(char *pBuffer, Size_t buffSize, DrPortNumber defaultPort = DrAnyPortNumber) const; + DrError ToAddressPortString(WCHAR *pBuffer, Size_t buffSize, DrPortNumber defaultPort = DrAnyPortNumber) const; + DrStr& AppendToString(DrStr& strOut, DrPortNumber defaultPort = DrAnyPortNumber) const; + //JC DrWStr& AppendToString(DrWStr& strOut, DrPortNumber defaultPort = DrAnyPortNumber) const; + + // Generates a 32-bit hash of the node address + inline UInt32 Hash() const + { + return (UInt32)GetIpAddress() + (UInt32)GetPort(); + } + +private: + IN_ADDR m_ina; // IP address in network byte order + DrPortNumber m_wPort; // Port number in host byte order + +}; + +class DrNodeAddressString : public DrStr32 +{ +public: + DrNodeAddressString(const DrNodeAddress& addr) + { + addr.AppendToString(*this, DrAnyPortNumber); + } +}; + +// This macro can be used to obtain a temporary "const char *" string equivalent for a node address. It can be used +// as the parameter to a method call; the pointer will become invalid after the function returns +#define DRNODEADDRESSSTRING(addr) (DrNodeAddressString(addr).GetString()) + + +// DrNodeAddressList is a list of node addresses, as you might get from resolving a name. It is more efficient when there is only one name. +class DrNodeAddressList +{ +public: + DrNodeAddressList() + { + m_pMultipleAddresses = &m_singleAddress; + m_numAllocated = 1; + m_numEntries = 0; + } + + ~DrNodeAddressList() + { + Clear(); + } + + // discards all the entries, but doesn't free memory + void Discard() + { + m_numEntries = 0; + } + + // Discards all entries and frees allocated memory. 
+ void Clear() + { + if (m_pMultipleAddresses != &m_singleAddress) { + delete[] m_pMultipleAddresses; + m_pMultipleAddresses = &m_singleAddress; + m_numAllocated = 1; + } + m_numEntries = 0; + } + + void GrowTo(UInt32 numAllocated) + { + if (numAllocated > m_numAllocated) { + UInt32 nNew = 2 * m_numAllocated; + if (nNew < numAllocated) { + nNew = numAllocated; + } + if (nNew < 8) { + nNew = 8; + } + DrNodeAddress *pNew = new DrNodeAddress[nNew]; + if (m_numEntries > 0) { + memcpy(pNew, m_pMultipleAddresses, m_numEntries * sizeof(DrNodeAddress)); + } + if (m_pMultipleAddresses != &m_singleAddress) { + delete[] m_pMultipleAddresses; + } + m_pMultipleAddresses = pNew; + m_numAllocated = nNew; + } + } + + DrNodeAddress *AddEntry(const DrNodeAddress *pOther = NULL) + { + GrowTo(m_numEntries+1); + DrNodeAddress *pEntry = m_pMultipleAddresses + m_numEntries; + m_numEntries++; + + if (pOther != NULL) { + (*pEntry) = (*pOther); + } else { + pEntry->Clear(); + } + + return pEntry; + } + + DrNodeAddress& operator[](UInt32 index) + { + LogAssert(index < m_numEntries); + return m_pMultipleAddresses[index]; + } + + const DrNodeAddress& operator[](UInt32 index) const + { + LogAssert(index < m_numEntries); + return m_pMultipleAddresses[index]; + } + + UInt32 GetLength() const + { + return m_numEntries; + } + + // This call may block for DNS + // It resolves the specified host name (with optional ":port") to a list of IP addresses and *appends* those to this DrNodeAddressList, filling in + // the port number for each. + // Note that since this request appends to the existing list, you must Clear() or Discard() the list before you make this + // call if you want the results to replace the existing set. + // Returns DrError_HostNotFound if no hosts match the name. 
+ DrError ResolveHostName(const char *pszHostName, DrPortNumber defaultPort=DrInvalidPortNumber); + +private: + DrNodeAddress m_singleAddress; + DrNodeAddress *m_pMultipleAddresses; // Points to m_singleAddress if there is 1 entry; otherwise, a heap array. + UInt32 m_numAllocated; + UInt32 m_numEntries; +}; + + +// After the first failure to send to a host, don't send to it again until Now + this interval +const DrTimeInterval k_initialDelayedSendTimeInterval = DrTimeInterval_Second * 10; + +// Don't let the delayed send interval grow beyond this +const DrTimeInterval k_maxDelayedSendInterval = DrTimeInterval_Second * 60; + +// XStream-specific conversions between fabric fault domain and autopilot pod names +//JC DrError DrPodNameToFaultDomain(__in PCSTR pszPodName, __out XsFaultDomain *pFaultDomainOut); + +// The return value is an internalized string of the for "pod%u"; +//JC PCSTR DrFaultDomainToPodName( __in XsFaultDomain faultDomain); + + + +// Manages a host name and a port number +// Also keeps track of a fault domain and an upgrade domain (for use in load balancing/load optimization) +class DrHostAndPort +{ +public: + DrHostAndPort() + { + m_pszHostName = NULL; + m_portNumber = DrInvalidPortNumber; + m_pszPodName = NULL; +//JC m_upgradeDomain = 0; +//JC m_faultDomain = 0; +//JC m_fValidFaultDomain = true; + m_fValidPod = true; + } + + ~DrHostAndPort() + { + } + + DrHostAndPort& Set(const DrHostAndPort& other) + { + m_pszHostName = other.m_pszHostName; + m_portNumber = other.m_portNumber; + m_pszPodName = other.m_pszPodName; +//JC m_upgradeDomain = other.m_upgradeDomain; +//JC m_faultDomain = other.m_faultDomain; +//JC m_fValidFaultDomain = other.m_fValidFaultDomain; + m_fValidPod = other.m_fValidPod; + return *this; + } + + DrHostAndPort(const DrHostAndPort& other) + { + Set(other); + } + + DrHostAndPort& operator=(const DrHostAndPort& other) + { + return Set(other); + } + + // Note, doesn't compare pod and upgrade domain + bool operator==(const 
DrHostAndPort &other) const{ + // only need to compare hostname addresses since they are internalized + return (m_portNumber == other.m_portNumber) && + (m_pszHostName == other.m_pszHostName); + } + +/*JC +// Note: this replaces the fault domain with the one encoded in the pod name + void Set(const char *pszHostName, DrPortNumber port, const char *pszPodName, XsUpgradeDomain upgradeDomain) + { + m_pszHostName = g_DrInternalizedStrings.InternalizeStringLowerCase(pszHostName); + m_portNumber = port; + m_pszPodName = g_DrInternalizedStrings.InternalizeStringLowerCase(pszPodName); + m_fValidPod = true; +//JC ReplaceFaultDomainFromPod(); +//JC m_upgradeDomain = upgradeDomain; + } + + // Note: this replaces the fault domain with the one encoded in the pod name + void SetPodName(const char *pszPodName) + { + m_pszPodName = g_DrInternalizedStrings.InternalizeStringLowerCase(pszPodName); + m_fValidPod = true; + ReplaceFaultDomainFromPod(); + } +*/ + // Sets the pod name without explicitly changing the fault domain. + void SetPodNameNoFaultDomainUpdate(const char *pszPodName) + { + m_pszPodName = g_DrInternalizedStrings.InternalizeStringLowerCase(pszPodName); + m_fValidPod = true; + } +/*JC + // Note: this replaces the pod with "pod%u" + void SetFaultDomain(XsFaultDomain faultDomain) + { + m_faultDomain = faultDomain; + m_fValidFaultDomain = true; + ReplacePodFromFaultDomain(); + } + + // Sets the fault domain without explictly changing the pod name. 
+ void SetFaultDomainNoPodUpdate(XsFaultDomain faultDomain) + { + m_faultDomain = faultDomain; + m_fValidFaultDomain = true; + } + + void SetUpgradeDomain(XsUpgradeDomain upgradeDomain) + { + m_upgradeDomain = upgradeDomain; + } + + + // Note: this replaces the fault domain with the one encoded in the pod name + DrError SetWithDefaultPort(const char *pszHostName, DrPortNumber defaultPort, const char *pszPodName, XsUpgradeDomain upgradeDomain) + { + DrStr64 strHost; + DrError err = DrNodeAddress::ParseHostPortName(strHost, &m_portNumber, pszHostName, defaultPort); + if (err == DrError_OK) { + m_pszHostName = g_DrInternalizedStrings.InternalizeStringLowerCase(strHost.GetString()); + m_pszPodName = g_DrInternalizedStrings.InternalizeStringLowerCase(pszPodName); + m_fValidPod = true; + ReplaceFaultDomainFromPod(); + m_upgradeDomain = upgradeDomain; + } + return err; + } +*/ + + void SetHostName(const char *pszHostName) + { + m_pszHostName = g_DrInternalizedStrings.InternalizeStringLowerCase(pszHostName); + } + + void SetPort(DrPortNumber port) + { + m_portNumber = port; + } + + const char *GetHostName() const + { + return m_pszHostName; + } + + DrPortNumber GetPort() const + { + return m_portNumber; + } + + // In XStream, the POD name is the fault domain in "pod%u" form + const char *GetPodName() const + { + return m_pszPodName; + } + +/* JC + // returns false if a POD was set that was not of the form "pod%u". + bool IsValidFaultDomain() const + { + return m_fValidFaultDomain; + } + + XsFaultDomain GetFaultDomain() const + { + return m_faultDomain; + } + + XsUpgradeDomain GetUpgradeDomain() const + { + return m_upgradeDomain; + } +*/ + + void Clear() + { + m_pszHostName = NULL; + m_portNumber = DrInvalidPortNumber; + // TODO: should pod, faultdomain, upgrade domain, etc. be updated? 
+ } + + bool IsValid() const + { + return m_pszHostName != NULL; + } + + bool IsInvalid() const + { + return m_pszHostName == NULL; + } + + // This call may block for DNS + // It resolves the host name to a list of IP addresses and *appends* those to the specified DrNodeAddressList, filling in + // the port number for each. + // Note that since this request appends to the existing list, you must Clear() or Discard() the list before you make this + // call if you want the results to replace the existing set. + // Returns DrError_HostNotFound if no hosts match the name. + DrError ResolveToAddresses(DrNodeAddressList *pAddresses); + + DrError Unserialize(DrMemoryReader *pReader); + DrError Serialize(DrMemoryWriter *pWriter) const; + + DrStr& AppendToString(DrStr& strOut) const + { + strOut.AppendF("%s:%u", m_pszHostName, m_portNumber); + return strOut; + } + + DrStr& ToString(DrStr& strOut) const + { + strOut.SetToEmptyString(); + return AppendToString(strOut); + } + +/* JC +private: + // Updates the fault domain from the pod name. Note that if the pod name is not of the form "pod%u", the fault domain + // will be set to 0, m_fValidFaultDomain will be set to false, and and DRError_InvalidParameter will be returned. + DrError ReplaceFaultDomainFromPod(); + + // Sets the pod name to "pod%u" from the fault domain. + void ReplacePodFromFaultDomain() + { + m_pszPodName = DrFaultDomainToPodName(m_faultDomain); + m_fValidPod = true; + } +*/ + +private: + const char * m_pszHostName; // Internalized + DrPortNumber m_portNumber; + const char *m_pszPodName; // Internalized. String form of fault domain as "pod%u" + //JC XsUpgradeDomain m_upgradeDomain; // default = 0 + //JC XsFaultDomain m_faultDomain; // Fault domain (default = 0). NUmeric form of pod; + //JC bool m_fValidFaultDomain; // Set to false if fault domain was not encountered in Unserialize(). After Unserialize, Always true unless pod is not "pod%u". + bool m_fValidPod; // Used only during serialize/Unserialize. 
Set to false if pod was not encountered in Unserialize(). Always true after Unserialize: +}; + +class DrHostAndPortString : public DrStr32 +{ +public: + DrHostAndPortString(const DrHostAndPort& host) + { + host.AppendToString(*this); + } +}; + +// For each (ip address, port), keeps track of the last time it was accessed +class DrLastAccessEntry +{ +public: + DrLastAccessEntry() + { + m_nextHash = NULL; + m_nextAttemptAllowed = DrTimeStamp_LongAgo; + m_delayTime = 0; + m_lastError = DrError_OK; + } + + UInt32 Hash() + { + return m_nodeAddress.Hash(); + } + +public: + DrLastAccessEntry* m_nextHash; + DrNodeAddress m_nodeAddress; + + // If this is not DrTimeStamp_LongAgo, don't send requests until this time + DrTimeStamp m_nextAttemptAllowed; + + // Amount of time to delay next protocol request + DrTimeInterval m_delayTime; + + // Last error code reported + DrError m_lastError; +}; + + +// This table keeps track of when we last accessed a particular IP/port +// It is used for throttling when we are determining the primary DRM +// Currently only DRM service descriptors go into this table, which is why it is small +const UInt32 k_numLastAccessTableBuckets = 100; + +class DrLastAccessTable +{ +public: + DrLastAccessTable(); + + // Successful send to this node address, reset throttling + void UpdateSuccess(const DrNodeAddress& nodeAddress); + + // Send failure. + // Returns true if we were already at the maximum allowed delay value + bool UpdateFailure(const DrNodeAddress& nodeAddress, DrError error); + + // When can we next send to this node address? 
+ // If when is DrTimeStamp_LongAgo, it means there is no delay at all + DrError GetDelay(const DrNodeAddress& nodeAddress, DrTimeStamp* when); + +private: + void Lock() + { + m_lock.Enter(); + } + + void Unlock() + { + m_lock.Leave(); + } + + // You must have the lock to call these functions + DrLastAccessEntry* FindOrCreate(const DrNodeAddress& nodeAddress); + DrLastAccessEntry* Find(const DrNodeAddress& nodeAddress); + +public: + DrCriticalSection m_lock; + DrLastAccessEntry* m_head[k_numLastAccessTableBuckets]; +}; + + +#define DRHOSTANDPORTSTRING(host) (DrHostAndPortString(host).GetString()) + +// Manages a list of host names (with optional port #) strings +class DrHostNameList : public DrPropertyParser{ +private: + const static UInt32 INVALID_PRIMARY_HOST = 0xFFFFFFFF; +public: + DrHostNameList() + { + m_pMultipleHosts = &m_singleHost; + m_numAllocated = 1; + m_numEntries = 0; + m_primary = INVALID_PRIMARY_HOST; + } + + ~DrHostNameList() + { + Clear(); + } + + // Set does a diff with what is already there, so this must be initialized + DrHostNameList& Set(const DrHostNameList& other, bool forceReordering = false); + + DrHostNameList(const DrHostNameList& other) + { + m_pMultipleHosts = &m_singleHost; + m_numAllocated = 1; + m_numEntries = 0; + m_primary = INVALID_PRIMARY_HOST; + + Set(other); + } + + DrHostNameList& operator=(const DrHostNameList& other) + { + // Free previously allocated memory when and only when new name list will not fit + // into already allocated region + if ((other.m_numEntries > m_numAllocated) && (m_pMultipleHosts != &m_singleHost)) + Clear(); + + return Set(other, true); + } + + void Clear() + { + if (m_pMultipleHosts != &m_singleHost) { + delete[] m_pMultipleHosts; + m_pMultipleHosts = &m_singleHost; + m_numAllocated = 1; + } + m_numEntries = 0; + m_primary = INVALID_PRIMARY_HOST; + } + + void GrowTo(UInt32 numAllocated) + { + if (numAllocated > m_numAllocated) { + UInt32 nNew = 2 * m_numAllocated; + if (nNew < numAllocated) { + 
nNew = numAllocated; + } + if (nNew < 8) { + nNew = 8; + } + DrHostAndPort *pNew = new DrHostAndPort[nNew]; + for (UInt32 i = 0; i < m_numEntries; i++) { + pNew[i] = m_pMultipleHosts[i]; + } + if (m_pMultipleHosts != &m_singleHost) { + delete[] m_pMultipleHosts; + } + m_pMultipleHosts = pNew; + m_numAllocated = nNew; + } + } + + DrHostAndPort *AddEntry() + { + GrowTo(m_numEntries+1); + DrHostAndPort *pEntry = m_pMultipleHosts + m_numEntries; + m_numEntries++; + return pEntry; + } + + const DrHostAndPort& operator[](UInt32 index) const + { + LogAssert(index < m_numEntries); + return m_pMultipleHosts[index]; + } + + DrHostAndPort& operator[](UInt32 index) + { + LogAssert(index < m_numEntries); + return m_pMultipleHosts[index]; + } + + UInt32 GetLength() const + { + return m_numEntries; + } + + // Demote host to bottom of the list + // If the host was the primary, then invalidate the primary + void DemoteHost(DrHostAndPort &host){ + DrHostAndPort saveHost; + saveHost.Set(host); + + LogAssert(m_numEntries > 0); + + for(UInt32 i = 0; i < m_numEntries; i ++) + { + if(host == m_pMultipleHosts[i]) + { + if (m_primary == i) + m_primary = INVALID_PRIMARY_HOST; // this host is deemed dead + + if (i < m_numEntries - 1) + { + for(; i < m_numEntries - 1; i++) + { + m_pMultipleHosts[i].Set(m_pMultipleHosts[i + 1]); + } + + m_pMultipleHosts[m_numEntries - 1].Set(saveHost); + } + + break; + } + } + } + bool IsPrimaryValid(void) const + { + return m_primary != INVALID_PRIMARY_HOST; + } + + bool IsPrimaryInvalid(void) const + { + return m_primary == INVALID_PRIMARY_HOST; + } + + void SetPrimaryInvalid(void) + { + m_primary = INVALID_PRIMARY_HOST; + } + + void SetPrimary(UInt32 primary) + { + m_primary = primary; + } + + UInt32 GetPrimary(void) const{ + return m_primary; + } + + DrError Serialize(DrMemoryWriter *pWriter) const; + virtual DrError OnParseProperty(DrMemoryReader *reader, UInt16 property, UInt32 dataLen, void *cookie); + + // This call may block for DNS + // It resolves 
the list of host names to a list of IP addresses and *appends* those to the specified DrNodeAddressList, filling in + // the port number for each. + // Note that since this request appends to the existing list, you must Clear() or Discard() the list before you make this + // call if you want the results to replace the existing set. + // Returns DrError_HostNotFound if no hosts match the name. + DrError ResolveToAddresses(DrNodeAddressList *pAddresses); + + // return one IP/port pair in *pAddresses* and return one host name in *host* + DrError ResolveOneHostToAddresses(DrNodeAddressList *pAddressses, bool wantPrimary, DrHostAndPort &host); + + DrStr& AppendToString(DrStr& strOut) const + { + if (m_numEntries == 0) { + strOut.Append(""); + + } else { + m_pMultipleHosts[0].AppendToString(strOut); + for (UInt32 i = 1; i < m_numEntries; i++) { + strOut.Append(';'); + m_pMultipleHosts[i].AppendToString(strOut); + } + } + return strOut; + } + + DrStr& ToString(DrStr& strOut) const + { + strOut.SetToEmptyString(); + return AppendToString(strOut); + } + + +protected: + + void SelectOneHost(DrHostAndPort &host, bool wantPrimary = true); + +private: + DrHostAndPort m_singleHost; + DrHostAndPort *m_pMultipleHosts; // Points to m_singleHost if there is 1 entry; otherwise, a heap array. If there is a primary it is always the first entry. 
+ UInt32 m_numAllocated; + UInt32 m_numEntries; + UInt32 m_primary; // This is 0 if there is a primary, and INVALID_PRIMARY_HOST otherwise +}; + +class DrHostNameListString : public DrStr64 +{ +public: + DrHostNameListString(const DrHostNameList& hostList) + { + hostList.AppendToString(*this); + } +}; + +#define DRHOSTNAMELISTSTRING(hostList) (DrHostNameListString(hostList).GetString()) + +/* + * Inline methods for DrNodeAddress + */ + +inline DrNodeAddress::DrNodeAddress() +{ + m_ina.S_un.S_addr=0; + m_wPort=0; +} + +inline DrNodeAddress::DrNodeAddress(const DrNodeAddress& addr) +{ + m_ina=addr.m_ina; + m_wPort=addr.m_wPort; +} + +inline DrNodeAddress& DrNodeAddress::operator=(const DrNodeAddress& addr) +{ + m_ina=addr.m_ina; + m_wPort=addr.m_wPort; + return *this; +} + +inline DrNodeAddress::DrNodeAddress(const IN_ADDR& ina, DrPortNumber wPort) +{ + m_ina=ina; + m_wPort=wPort; +} + +inline void DrNodeAddress::Set(const IN_ADDR& ina, DrPortNumber wPort) +{ + m_ina=ina; + m_wPort=wPort; +} + +inline void DrNodeAddress::Set(DrIpAddress ipAddr, DrPortNumber wPort) +{ + m_ina.s_addr=htonl(ipAddr); + m_wPort=wPort; +} + +#pragma prefast(pop) + +#endif + diff --git a/DryadVertex/VertexHost/system/classlib/include/DrProperties.h b/DryadVertex/VertexHost/system/classlib/include/DrProperties.h new file mode 100644 index 0000000..0fb238d --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrProperties.h @@ -0,0 +1,79 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
//JC Check this file for redundant information.

// Dryad properties (see propertybag.h)
//
// This file is an X-macro list: it is #included with DEFINE_DRPROPERTY defined by the
// including file, so every entry below is expanded by whatever definition is in effect there.
// Each entry has the form
//     DEFINE_DRPROPERTY(constant name, property id, value type, printable name)
// and is written without a trailing semicolon.

// General purpose parameters

// (BeginTag, UInt16 tag) marks the beginning of a set of data
// (EndTag, UInt16 tag) marks the end of a set of data
// See cstags.h for a description of known tag values

// This file must consist only of DEFINE_DRPROPERTY statements

// Current Allowed types are:
//
//     UInt32
//     UInt64
//     String
//     Guid
//     TimeStamp
//     TimeInterval
//     BeginTag
//     EndTag
//     DrError
//     DrExitCode
//     Blob
//     EnvironmentBlock
//
// NOTE(review): Prop_Dryad_Port below is declared with type UInt16, which is missing from
// this list -- confirm against DrPropertyType.h and update whichever is out of date.

// ******************************************************************************************************
//
// DO NOT RESERVE RANGES IN THIS FILE, USE NEXT AVAILABLE ID AND LEAVE NO HOLES
//
// ******************************************************************************************************

// This is a BeginTag
DEFINE_DRPROPERTY(Prop_Dryad_BeginTag, PROP_SHORTATOM(0x1200), BeginTag, "BeginTag")

// This is an EndTag
DEFINE_DRPROPERTY(Prop_Dryad_EndTag, PROP_SHORTATOM(0x1201), EndTag, "EndTag")

// This is a Blob value - a win32 environment variable block
DEFINE_DRPROPERTY(Prop_Dryad_EnvironmentBlock, PROP_LONGATOM(0x129A), EnvironmentBlock, "EnvironmentBlock")

// This is a UInt16
DEFINE_DRPROPERTY(Prop_Dryad_Port, PROP_SHORTATOM(0x1294), UInt16, "Port")

// These are Strings
DEFINE_DRPROPERTY(Prop_Dryad_ShortHostName, PROP_SHORTATOM(0x1293), String, "ShortHostName")
DEFINE_DRPROPERTY(Prop_Dryad_LongHostName, PROP_LONGATOM(0x1037), String, "LongHostName")

// This is a String
DEFINE_DRPROPERTY(Prop_Dryad_PodName, PROP_LONGATOM(0x12C6), String, "PodName")

// This is a UInt32 - gives number of entries that follow
DEFINE_DRPROPERTY(Prop_Dryad_NumEntries, PROP_SHORTATOM(0x124C), UInt32, "NumEntries")

// This is a UInt32 value - identifies the primary host in a host list
// (presumably an index into the list rather than a memory pointer -- TODO confirm)
DEFINE_DRPROPERTY(Prop_Dryad_PrimaryHost, PROP_SHORTATOM(0x12A8), UInt32, "PrimaryHost")

// This is a UInt32 value - identifies the next host to return in a host list
// (presumably an index into the list rather than a memory pointer -- TODO confirm)
DEFINE_DRPROPERTY(Prop_Dryad_NextHost, PROP_SHORTATOM(0x12A9), UInt32, "NextHost")
#pragma once

// Expands the two X-macro lists (DrPropertyType.h and DrProperties.h) into declarations,
// and defines a few sentinel values for stream/extent offsets and lengths.

class DrPropertyDumper;

// --- Property types -------------------------------------------------------------------
// For every DECLARE_DRPROPERTYTYPE(type) entry expanded while DrPropertyType.h is included,
// declare the matching text converter DrPropertyToText_<type>.
// (Presumably DrPropertyType.h consists only of such entries -- it is not shown here.)
// Any earlier definition of the macro is dropped first so that this expansion is the one used.

#ifdef DECLARE_DRPROPERTYTYPE
#undef DECLARE_DRPROPERTYTYPE
#endif

#define DECLARE_DRPROPERTYTYPE(type) \
    extern DrError DrPropertyToText_##type(DrPropertyDumper *pDumper, UInt16 enumId, const char *propertyName);

#include "DrPropertyType.h"

#undef DECLARE_DRPROPERTYTYPE


// --- Property ids ---------------------------------------------------------------------
// Turn every DEFINE_DRPROPERTY(var, value, type, propertyName) entry of DrProperties.h
// into a UInt16 constant named 'var' holding 'value'. The 'type' and 'propertyName'
// arguments are ignored by this expansion. The terminating semicolon is part of the
// macro, which is why the entries in DrProperties.h carry none.

#ifdef DEFINE_DRPROPERTY
#undef DEFINE_DRPROPERTY
#endif

#define DEFINE_DRPROPERTY(var, value, type, propertyName) \
    static const UInt16 var = value;

#include "DrProperties.h"

#undef DEFINE_DRPROPERTY


// This is a special value for an offset meaning unknown (all bits set)
const UInt64 DrStreamOffset_Unknown = 0xFFFFFFFFFFFFFFFF;

// This is a special value for an extent offset meaning unknown
// Note that -1 is used to mean an invalid offset, so "unknown" is -2 here; this differs
// from DrStreamOffset_Unknown above, where "unknown" is the all-bits-set value.
const UInt64 DrExtentOffset_Unknown = (UInt64 ) -2; //$TODO(DanielD) - for consistency (and logging, etc - see SamMck) swap values to have DrExtentOffset_Unknown==-1
const UInt64 DrExtentOffset_Invalid = (UInt64 ) -1;

// Special value for an extent length meaning invalid (all bits set)
const UInt64 DrExtentLength_Invalid = (UInt64) -1;

//Flags for Prop_Dryad_PublishedCrc64
// NOTE(review): no Prop_Dryad_PublishedCrc64 entry is visible in the DrProperties.h of this
// patch -- confirm whether the property (and this constant) is still in use.
const UInt64 DrExtentCrc64_Suspect = (UInt64)-1;
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include "DrList.h" + +#pragma warning(push) +#pragma warning(disable:4995) + +// A structure that defines a mapping between a particular bit range within a UInt32 and a readable name +typedef struct { + const char *fieldName; // The name of the bit field + UInt32 bitMask; // The bits that relate to this mask + const char **prgValueNames; // An array of enum names for the field (of length 2**n-1, where n is the number of bits including + // the least significant and most significant 1 bit in bitMask. If an entry is NULL, + // the bitmask value will not be displayed. If prgValueNames is NULL, all values will + // be displayed as a UInt. +} DrBitFieldMap; + +// A structure that defines a mapping between a UInt32 and a set of bit fields +typedef struct { + int numFields; // The number of fields + const DrBitFieldMap *prgFields; // The array of bit field mappiings +} DrBitMaskMap; + +/* JC +class DrPropertyDumper +{ +private: + typedef enum { + // We are at the beginning of a property id + DrPropertyDumper_StartOfTag = 0, + + // we have read the first byte of the property ID and are waiting for the second byte + DrPropertyDumper_ReadingPropertyId, + + // We have read the property ID, but are waiting on 1-4 remaining length bytes (1 byte for a shortatom, 4 bytes for a longatom) + DrPropertyDumper_ReadingLength, + + // We have read the property ID and the length, , but are waiting on the remaining data + DrPropertyDumper_ReadingData + } DrPropertyDumperState; + +public: + +public: + // The default maximum size that will be dumped for blob properties, unknown properties, and partial/malformed properties. Blob properties + // that exceed this length will be truncated in the dump with an "...(n bytes remaining)" annotation. 
+ static const size_t k_nbDefaultMaxBlobSize = 512; + + // The default maximum size that will be dumped for + static const size_t k_nbDefaultMaxPayloadSize = 64; + static const int k_nDefaultIndentSpacesPerLevel = 2; + +private: + void Construct(size_t maxBlobSize, bool fIncludeComments, int nIndentSpacesPerLevel) + { + m_state = DrPropertyDumper_StartOfTag; + m_partialPropertyId = 0; + m_partialPropertyTotalLength = 0; + m_pReader = NULL; + m_pWriter = NULL; + m_fDeleteReader = false; + m_fDeleteWriter = false; + m_indent = 0; + m_fAtBol = true; + m_fOutputCrlf = true; + m_nbMaxBlobSize = maxBlobSize; + m_nbMaxPayloadSize = k_nbDefaultMaxPayloadSize; + if (m_nbMaxPayloadSize > m_nbMaxBlobSize) { + m_nbMaxPayloadSize = m_nbMaxBlobSize; + } + m_nIndentSpacesPerLevel = nIndentSpacesPerLevel; + m_fIncludeComments = fIncludeComments; + } + + +public: + DrPropertyDumper() + { + Construct(k_nbDefaultMaxBlobSize, true, k_nDefaultIndentSpacesPerLevel); + } + + DrPropertyDumper(bool fFullContent, bool fIncludeComments = true, int nIndentSpacesPerLevel = k_nDefaultIndentSpacesPerLevel) + { + Construct(fFullContent ? MAX_SIZE_T : k_nbDefaultMaxBlobSize, fIncludeComments, nIndentSpacesPerLevel); + } + + DrPropertyDumper(size_t maxBlobSize, bool fIncludeComments = true, int nIndentSpacesPerLevel = k_nDefaultIndentSpacesPerLevel) + { + Construct(maxBlobSize, fIncludeComments, nIndentSpacesPerLevel); + } + + __declspec(deprecated) void SetReader(DrMemoryReader *pReader, bool fDelete = false) + { + if (m_pReader != NULL && m_fDeleteReader) { + delete m_pReader; + } + m_pReader = pReader; + m_fDeleteReader = fDelete; + } + + // If true, then the full content of the property set (including large blobs) will be encoded rather than summarized + // If false, then blobs will be limited to the first k_nbDefaultMaxBlobSize bytes. + void SetFullContent(bool fFullContent=true) + { + m_nbMaxBlobSize = fFullContent ? 
MAX_SIZE_T : k_nbDefaultMaxBlobSize; + } + + // If parameter is true, causes bare '\n' to terminate lines, rather than "\r\n" + void SetSuppressOutputCr(bool fSuppressOutputCr = true) + { + m_fOutputCrlf = !fSuppressOutputCr; + } + + // Will lower, but not raise the max payload size + void SetMaxBlobSize(size_t nbMaxBlobSize) + { + m_nbMaxBlobSize = nbMaxBlobSize; + if (m_nbMaxPayloadSize > m_nbMaxBlobSize) { + m_nbMaxPayloadSize = m_nbMaxBlobSize; + } + } + + // Will not raise the max payload size + void SetUnlimitedBlobSize() + { + SetMaxBlobSize(MAX_SIZE_T); + } + + size_t GetMaxBlobSize() + { + return m_nbMaxBlobSize; + } + + void SetIncludeComments(bool fIncludeComments = true) + { + m_fIncludeComments = fIncludeComments; + } + + bool ShouldIncludeComments() + { + return m_fIncludeComments; + } + + // Does not affect the limits for non-payload blobs + void SetMaxPayloadSize(size_t nbMaxPayloadSize) + { + m_nbMaxPayloadSize = nbMaxPayloadSize; + } + + // Does not affect the limits for non-payload blobs + void SetUnlimitedPayloadSize() + { + SetMaxPayloadSize(MAX_SIZE_T); + } + + size_t GetMaxPayloadSize() + { + return m_nbMaxPayloadSize; + } + + DrMemoryReader *GetReader() + { + return m_pReader; + } + + void SetWriter(DrMemoryWriter *pWriter, bool fDelete = false) + { + if (m_pWriter != NULL && m_fDeleteWriter) { + delete m_pWriter; + } + m_pWriter = pWriter; + m_fDeleteWriter = fDelete; + } + + DrMemoryWriter *GetWriter() + { + return m_pWriter; + } + + void SetNumIndentSpacesPerLevel(int n) + { + m_nIndentSpacesPerLevel = n; + } + + int GetNumIndentSpacesPerLevel() + { + return m_nIndentSpacesPerLevel; + } + + void SetIndent(int n) + { + m_indent = n; + } + + int GetIndent() + { + return m_indent; + } + + void Indent() + { + m_indent += m_nIndentSpacesPerLevel; + } + + void Unindent() + { + m_indent -= m_nIndentSpacesPerLevel; + if (m_indent < 0) { + m_indent = 0; + } + } + + DrError WriteF(const char *pszFormat, ...); + DrError VWriteF(const char 
*pszFormat, va_list args); + + // Writes an XML start tag. Does NOT affect the DrTag Nesting level + DrError WriteStartTag(const char *pszTagName, Size_t length, bool fIncludeLength, bool fNewlineAfter) + { + DrError err; + if (fIncludeLength) { + err = WriteF("<%s length=\"%Iu\"%s", pszTagName, length, (fNewlineAfter ? ">\n" : ">")); + } else { + err = WriteF("<%s%s", pszTagName, (fNewlineAfter ? ">\n" : ">")); + m_fAtBol = fNewlineAfter; + } + Indent(); + return err; + } + + DrError WriteSimpleStartTag(const char *pszTagName, bool fNewlineAfter) + { + DrError err; + err = WriteF("<%s%s", pszTagName, (fNewlineAfter ? ">\n" : ">")); + m_fAtBol = fNewlineAfter; + Indent(); + return err; + } + + DrError FlushFileWriter() + { + } + + DrError WriteXmlFileHeader() + { + return PutStr("\n"); + } + + // Writes an XML start tag. Does NOT affect the DrTag Nesting level + // pszAttributes is an optional XML attributes string in the form: + // attrib1="value" attrib2="value" + // + DrError WriteStartTagWithAttributes(__in const char *pszTagName, __in_opt const char *pszAttributes, bool fNewlineAfter) + { + DrError err; + if (pszAttributes == NULL) { + pszAttributes = ""; + } else { + while(pszAttributes[0] == ' ') { + pszAttributes++; + } + } + err = WriteF("<%s%s%s%s", pszTagName, (pszAttributes[0] == '\0') ? "" : " ", pszAttributes, (fNewlineAfter ? ">\n" : ">")); + m_fAtBol = fNewlineAfter; + Indent(); + return err; + } + + DrError WriteStartTagWithLengthAndAttributes(__in const char *pszTagName, Size_t length, __in_opt const char *pszAttributes, bool fNewlineAfter) + { + DrError err; + if (pszAttributes == NULL) { + pszAttributes = ""; + } else { + while(pszAttributes[0] == ' ') { + pszAttributes++; + } + } + err = WriteF("<%s length=\"%Iu\"%s%s%s", pszTagName, length, (pszAttributes[0] == '\0') ? "" : " ", pszAttributes, (fNewlineAfter ? ">\n" : ">")); + m_fAtBol = fNewlineAfter; + Indent(); + return err; + } + + // Writes an XML end tag. 
DOes NOT affect the DrTag Nesting level + DrError WriteEndTag(const char *pszTagName, bool fNewlineBefore) + { + Unindent(); + DrError err = WriteF("%s/%s>\n", (fNewlineBefore ? "\n<" : "<"), pszTagName); + return err; + } + + // The string generated by pszFormat must be XML-encoded or contain no XML delimiters + DrError VWriteSimpleTagValue( + const char *pszTagName, + UInt32 length, + bool fIncludeLength, + bool fSeperateLine, + const char *pszFormat, + va_list args) + { + DrError err; + + WriteStartTag(pszTagName, length, fIncludeLength, fSeperateLine); + VWriteF(pszFormat, args); + err = WriteEndTag(pszTagName, fSeperateLine); + return err; + } + + // The string generated by pszFormat must be XML-encoded or contain no XML delimiters + DrError WriteSimpleTagValue( + const char *pszTagName, + UInt32 length, + bool fIncludeLength, + bool fSeperateLine, + const char *pszFormat, + ...) + { + va_list args; + va_start(args, pszFormat); + return VWriteSimpleTagValue(pszTagName, length, fIncludeLength, fSeperateLine, pszFormat, args); + } + + // The string generated by pszFormat must be XML-encoded or contain no XML delimiters + DrError WriteSimpleTagValue( + const char *pszTagName, + const char *pszFormat, + ...) + { + va_list args; + va_start(args, pszFormat); + return VWriteSimpleTagValue(pszTagName, 0, false, false, pszFormat, args); + } + + // The string generated by pszFormat must be XML-encoded or contain no XML delimiters + DrError WriteSimpleTagValue( + const char *pszTagName, + UInt32 length, + const char *pszFormat, + ...) 
+ { + va_list args; + va_start(args, pszFormat); + return VWriteSimpleTagValue(pszTagName, length, true, false, pszFormat, args); + } + + // Puts a string to output, indenting lines as appropriate + // Must obey XML conventions + DrError PutStr(const char *psz, int len = -1); + + DrError PutBolLeader() + { + if (m_indent > 0) { + m_fAtBol = false; + for (int i = 0; i < m_indent; i++) { + m_pWriter->WriteChar(' '); + } + } + return m_pWriter->GetStatus(); + } + + DrError PutUInt64TagValue(const char *pszTagName, UInt64 val, bool fHex = false) + { + const char *pszFormat = (fHex ? "0x%016I64x" : "%I64u"); + return WriteSimpleTagValue(pszTagName, pszFormat, val); + } + + DrError PutUInt32TagValue(const char *pszTagName, UInt32 val, bool fHex = false) + { + const char *pszFormat = (fHex ? "0x%08x" : "%u"); + return WriteSimpleTagValue(pszTagName, pszFormat, val); + } + + DrError PutDoubleTagValue(const char *pszTagName, double val) + { + const char *pszFormat = "%lf"; + return WriteSimpleTagValue(pszTagName, pszFormat, val); + } + + DrError PutUInt32BitMaskValue(const char *pszTagName, UInt32 val, const DrBitMaskMap *pMap); + + DrError PutUInt16TagValue(const char *pszTagName, UInt16 val, bool fHex = false) + { + const char *pszFormat = (fHex ? "0x%04x" : "%u"); + return WriteSimpleTagValue(pszTagName, pszFormat, (UInt32)val); + } + + DrError PutTagIdTagValue(const char *pszTagName, UInt16 val) + { + const char *pszTagIdName = GetTagName(val); + if (pszTagIdName == NULL) { + return WriteSimpleTagValue(pszTagName, "0x%04x", (UInt32)val); + } else { + return WriteSimpleTagValue(pszTagName, "0x%04x ", (UInt32)val, pszTagIdName); + } + } + + DrError PutInt64TagValue(const char *pszTagName, Int64 val, bool fHex = false) + { + const char *pszFormat = (fHex ? "0x%016I64x" : "%I64d"); + return WriteSimpleTagValue(pszTagName, pszFormat, val); + } + + DrError PutInt32TagValue(const char *pszTagName, Int32 val, bool fHex = false) + { + const char *pszFormat = (fHex ? 
"0x%08x" : "%d"); + return WriteSimpleTagValue(pszTagName, pszFormat, val); + } + + DrError PutInt16TagValue(const char *pszTagName, Int16 val, bool fHex = false) + { + const char *pszFormat = (fHex ? "0x%04x" : "%d"); + return WriteSimpleTagValue(pszTagName, pszFormat, (Int32)val); + } + + DrError PutBooleanTagValue(const char *pszTagName, bool val) + { + return WriteSimpleTagValue(pszTagName, "%s", (val ? "true" : "false")); + } + + // safely XML-encodes the string + DrError PutStringTagValue(const char *pszTagName, const char *val, int len = -1) + { + if (val == NULL) { + return WriteSimpleTagValue(pszTagName, "%s", ""); + } else { + if (len < 0) { + len = (int)strlen(val); + } + DrStr512 strXml; + strXml.AppendXmlEncodedString(val, (Size_t)len, true); + return WriteSimpleTagValue(pszTagName, "%s", strXml.GetString()); + } + } + + DrError PutGuidTagValue(const char *pszTagName, const DrGuid& val) + { + char buff[40]; + return WriteSimpleTagValue(pszTagName, "%s", val.ToString(buff, false)); + } + + DrError PutTimeStampTagValue(const char *pszTagName, DrTimeStamp val) + { + char buff[64]; + DrError err; + if (val == DrTimeStamp_Never) { + err = WriteSimpleTagValue(pszTagName, "never", (UInt64)val, buff); + } else if (val == DrTimeStamp_LongAgo) { + err = WriteSimpleTagValue(pszTagName, "long ago", (UInt64)val, buff); + } else { + err = DrTimeStampToString(val, buff, sizeof(buff), false); + if (err == DrError_OK) { + err = WriteSimpleTagValue(pszTagName, "%I64u", (UInt64)val, buff); + } else { + m_pWriter->SetStatus(err); + } + } + return err; + } + + DrError PutNodeAddressTagValue(const char *pszTagName, const DrNodeAddress& val) + { + char buff[64]; + DrError err = val.ToAddressPortString(buff, sizeof(buff)); + if (err == DrError_OK) { + err = WriteSimpleTagValue(pszTagName, "%s", buff); + } else { + m_pWriter->SetStatus(err); + } + return err; + } + + DrError PutTimeIntervalTagValue(const char *pszTagName, DrTimeInterval val) + { + DrError err; + + char 
buff[k_DrTimeIntervalStringBufferSize]; + err = DrTimeIntervalToString(val, buff, sizeof(buff)); + LogAssert(err == DrError_OK); + + err = WriteSimpleTagValue(pszTagName, "%I64d", (Int64)val, buff); + + return err; + } + + DrError PutDrErrorTagValue(const char *pszTagName, DrError val) + { + char buff[1024]; + const char *pszErr = DrGetErrorDescription(val, buff, sizeof(buff)); + // Assumes error code descriptions are XML-friendly + DrError err = WriteSimpleTagValue(pszTagName, "0x%08x", (UInt32)val, pszErr); + return err; + } + + DrError PutDrExitCodeTagValue(const char *pszTagName, DrExitCode val) + { + // Assumes exit code descriptions are XML-friendly + DrError err = WriteSimpleTagValue(pszTagName, "0x%08x", (UInt32)val, DREXITCODESTRING(val)); + return err; + } + + DrError PutBlobTagValue(const char *pszTagName, DrMemoryReader *pReader, Size_t length, Size_t maxShowLength, const char *pszExtraAttributes = NULL); + DrError PutEnvironmentBlockTagValue(const char *pszTagName, const DrEnvironmentStrings& envBlock); + + static const char *GetPropertyName(UInt16 enumId); + static const char *GetTagName(UInt16 tagId); + + + // Returns the number of nested DrTag begintags that are currently in effect. 
0 means we are not inside a begin tag + UInt32 GetCurrentTagLevel() + { + return m_beginTagStack.NumEntries(); + } + + // returns DrTag_InvalidTag if we are at level 0 + UInt16 GetCurrentNestedTagId() + { + if (m_beginTagStack.NumEntries() == 0) { + return DrTag_InvalidTag; + } else { + return m_beginTagStack.TopOfStack(); + } + } + + // returns NULL if the current nested DrTag tag name is not known + const char *GetCurrentNestedTagName() + { + const char * pszTagName = GetTagName(GetCurrentNestedTagId()); + return pszTagName; + } + + // Returns the DrTag tag level below which we are not allowed to pop, due to a balanced block being in effect + UInt32 GetCurrentMinimumTagLevel() + { + if (m_balancedBlockStack.NumEntries() == 0) { + return 0; + } else { + return m_balancedBlockStack.TopOfStack(); + } + } + + // returns true if there is a partially written DrProperty waiting to be completed before it is dumped + bool PartialPropertyIsPending() + { + return (m_state != DrPropertyDumper_StartOfTag); + } + + + // Begins a section where begin/end DrTags must be balanced and there must not be be a partial property at the beginning or end + // On entry, any partially written property is written out as a specially tagged "incomplete" property + void BeginBalancedBlock(); + + // If there is a partially written property pending, it is written out as a specially tagged "incomplete" property. + // This resets the parser to expect a property ID next. + // Also, if there are outstanding open begin-DrTag beyond the currend balanced block level, writes end DrTags for them (with XML comments + // describing them as missing). + // After this call, the current DrTag level will be equal to the current balanced block level. + DrError SynchronizeToBalancedBlock(); + + // endss a section where begin/end DrTags must be balanced. + // If there is a partially written property pending, it is written out as a specially tagged "incomplete" property. 
+ // This resets the parser to expect a property ID next. + // Also, if there are outstanding open begin-DrTag beyond the currend balanced block level, writes end DrTags for them (with XML comments + // describing them as missing). + // After this call, the current DrTag level will be equal to the current balanced block level. + DrError EndBalancedBlock() + { + LogAssert(m_balancedBlockStack.NumEntries() != 0); + SynchronizeToBalancedBlock(); + LogAssert(GetCurrentTagLevel() == GetCurrentMinimumTagLevel()); + m_balancedBlockStack.Pop(); + return m_pWriter->GetStatus(); + } + + // writes an XML equivalent to a DrTag "begin tag", and pushes a level + DrError WriteAndPushDrBeginTag(UInt16 tagId); + + // writes an XML equivalent to a DrTag "end tag", and pops a level + DrError WriteAndPopDrEndTag(UInt16 tagId); + + // If there is a partially written property pending, it is written out as a specially tagged "incomplete" property. + // This resets the parser to expect a property ID next. + DrError WritePartialProperty(); + + // Final flush of the dumper. must be called from balanced block level 0 + DrError WriteIncompleteEndTagsAndFlushWriter() + { + LogAssert(m_balancedBlockStack.NumEntries() == 0); + SynchronizeToBalancedBlock(); + return m_pWriter->FlushMemoryWriter(); + } + // reads all remaining bytes from the reader and generates XML output as appropriate + // The content of the reader does not need to be balanced or complete, and it need not begin or end at a property boundary; you may call this multiple times to + // represent a coherent property stream + // If you want to enforce tag balancing accross a section, call BeginBalancedBlock before calling this, and EndBalancedBlock after calling this. 
+ // if maxLength is specified, the function will return after reading maxLength bytes even if end of stream is not reached + DrError ParseAndWriteFromReader(DrMemoryReader *pReader, Size_t maxLength=MAX_SIZE_T); + + DrError ParseAndWriteFromBuffer(DrMemoryBuffer *pBuffer, Size_t initialOffset=0, Size_t maxLength = MAX_SIZE_T) + { + DrMemoryBufferReader reader(pBuffer); + if (initialOffset != 0) { + reader.SetBufferOffset(initialOffset); + } + return ParseAndWriteFromReader(&reader, maxLength); + } + + DrError ParseAndWriteFromSingleBlock(const void *pData, Size_t length) + { + DrSingleBlockReader reader(pData, length); + return ParseAndWriteFromReader(&reader, length); + } + + + // reads the next single tag&value and outputs it. Does not reads aggregates--begintag and endtag are + // trated as independent values; however, it does increment indentation in begintag, and decrement it on endtag. + __declspec(deprecated) DrError PutNextPropertyTagValue() + { + return PutNextPropertyTagValue(m_pReader); + } + + + DrError PutNextPropertyTagValue(DrMemoryReader *pReader); + + // Reads the remainder of the input stream (Until DrError_EndOfStream is returned) as a sequential list + // of 0 or more properties and/or aggregates, and outputs them accordingly. Does not return until the entire stream has been read. + // Will always write a balanced set even if the input is unbalanced (will add missing end tags, etc., as necessary). 
+ // returns the writer status + __declspec(deprecated) DrError PutNestedPropertyList() + { + return PutNestedPropertyList(m_pReader); + } + DrError PutNestedPropertyList(DrMemoryReader *pReader) + { + BeginBalancedBlock(); + ParseAndWriteFromReader(pReader); + return EndBalancedBlock(); + } + + bool IsAtBeginningOfLine() + { + return m_fAtBol; + } + +private: + DrMemoryReader *m_pReader; + bool m_fDeleteReader; + DrMemoryWriter *m_pWriter; + bool m_fDeleteWriter; + int m_indent; + int m_nIndentSpacesPerLevel; + bool m_fAtBol; // Beginning-of-line flag + bool m_fOutputCrlf; // True if '\r' should precede each '\n' on output. + size_t m_nbMaxBlobSize; + size_t m_nbMaxPayloadSize; + bool m_fIncludeComments; + + DrPropertyDumperState m_state; + + // A stack of unmatched begintag tagids. The top of stack is the end tag we are currently looking for. + // The number of entries in this list is the current tag "level" + DrValList m_beginTagStack; + + // A stack of pushed "BeginBalancedBlock" "level" parameters. + // The top of the stack is the current balanced block level -- the parser will not allow end tags to pop the beginTagStack + // above this level. 
+ DrValList m_balancedBlockStack; + + // buffer to store an incomplete property we are parsing + // Includes property tag + DrRef m_pPartialPropertyBuffer; + + UInt16 m_partialPropertyId; + Size_t m_partialPropertyTotalLength; + +}; +*/ + +extern void DrInitPropertyTable(); +extern void DrInitTagTable(); + +/* JC +typedef DrError (*DrPropertyConverter)(DrPropertyDumper *pDumper, UInt16 enumId, const char *propertyName); + +// returns ERROR_ALREADY_ASSIGNED if the property has already been defined +extern DrError DrAddPropertyToDumper(UInt16 prop, const char *pszDescription, DrPropertyConverter pConverter); + +// returns ERROR_ALREADY_ASSIGNED if the tag has already been defined +extern DrError DrAddTagToDumper(UInt16 tag, const char *pszDescription); +*/ + +#pragma warning(pop) + diff --git a/DryadVertex/VertexHost/system/classlib/include/DrPropertyType.h b/DryadVertex/VertexHost/system/classlib/include/DrPropertyType.h new file mode 100644 index 0000000..e9de15f --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrPropertyType.h @@ -0,0 +1,63 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +// This file must consist only of DECLARE_DRPROPERTYTYPE statements + +DECLARE_DRPROPERTYTYPE(Unknown) +DECLARE_DRPROPERTYTYPE(Boolean) +DECLARE_DRPROPERTYTYPE(Int16) +DECLARE_DRPROPERTYTYPE(Int32) +DECLARE_DRPROPERTYTYPE(Int64) +DECLARE_DRPROPERTYTYPE(UInt16) +DECLARE_DRPROPERTYTYPE(UInt32) +DECLARE_DRPROPERTYTYPE(UInt64) +DECLARE_DRPROPERTYTYPE(Double) +DECLARE_DRPROPERTYTYPE(HexUInt16) +DECLARE_DRPROPERTYTYPE(HexUInt32) +DECLARE_DRPROPERTYTYPE(HexUInt64) +DECLARE_DRPROPERTYTYPE(String) +DECLARE_DRPROPERTYTYPE(Guid) +DECLARE_DRPROPERTYTYPE(TimeStamp) +DECLARE_DRPROPERTYTYPE(TimeInterval) +DECLARE_DRPROPERTYTYPE(BeginTag) +DECLARE_DRPROPERTYTYPE(EndTag) +DECLARE_DRPROPERTYTYPE(DrError) +DECLARE_DRPROPERTYTYPE(DrExitCode) +DECLARE_DRPROPERTYTYPE(Blob) +DECLARE_DRPROPERTYTYPE(Payload) +DECLARE_DRPROPERTYTYPE(EnvironmentBlock) +DECLARE_DRPROPERTYTYPE(PropertyList) +DECLARE_DRPROPERTYTYPE(TagIdValue) + +DECLARE_DRPROPERTYTYPE(AppendExtentOptions) +DECLARE_DRPROPERTYTYPE(AppendBlockOptions) +DECLARE_DRPROPERTYTYPE(SyncOptions) +DECLARE_DRPROPERTYTYPE(SyncDirectiveOptions) +DECLARE_DRPROPERTYTYPE(ReadExtentOptions) +DECLARE_DRPROPERTYTYPE(AppendStreamOptions) +DECLARE_DRPROPERTYTYPE(EnumDirectoryOptions) +DECLARE_DRPROPERTYTYPE(EnInfoBits) +DECLARE_DRPROPERTYTYPE(UpdateExtentMetadataOptions) +DECLARE_DRPROPERTYTYPE(StreamInfoBits) +DECLARE_DRPROPERTYTYPE(StreamCapabilityBits) +DECLARE_DRPROPERTYTYPE(ExtentInfoBits) +DECLARE_DRPROPERTYTYPE(ExtentInstanceInfoBits) +DECLARE_DRPROPERTYTYPE(FailureInjectionOptions) + diff --git a/DryadVertex/VertexHost/system/classlib/include/DrRefCounter.h b/DryadVertex/VertexHost/system/classlib/include/DrRefCounter.h new file mode 100644 index 0000000..7fb0899 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrRefCounter.h @@ -0,0 +1,839 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#ifndef __DRYADREFCOUNTER_H__ +#define __DRYADREFCOUNTER_H__ + +#pragma once + +template __forceinline Ty& DrRemoveConst(const Ty& x) +{ + return const_cast(x); +} + +template __forceinline Ty *DrRemovePtrConst(const Ty *pX) +{ + return const_cast(pX); +} + +UInt64 GetUniqueObjectID(); + +// a type for automatically assigning a unique object ID for static initializers +class DrObjID +{ +public: + DrObjID() { _val = GetUniqueObjectID();} + UInt64 Value() const { return _val; }; +private: + UInt64 _val; +}; + +/* + * DrRefCounter + * + * Implements simple reference count functionality. + * Derive any ref counted objects from this class. + */ + +class DrRefCounter; + +static const LONG k_lDecommissionedRefCount = (LONG)-1234; + +class DrRefCountMonitor +{ +public: + virtual void OnRefCountChanged(const DrRefCounter *pCounter, void *pContext, LONG newCount, const char *pszReason) = 0; +}; + +class IDrRefCounter +{ +public: + virtual LONG IncRef() = 0; + LONG IncRef() const + { + // IncRef/DecRef are allowed even for const instances + return DrRemovePtrConst(this)->IncRef(); + } + + // called by weak reference pointers (under the shared lock) to IncRef this object. 
+ // returns TRUE if the IncRef could be performed, false if the object is already unreferenceable + // The default implementation succeeds unless the current refcount is 0, on the assumption that the implementation + // will clear all existing weak references after the refcount becomes 0 and before the object is deleted. + virtual bool IncRefFromWeakReferenceLocked() = 0; + bool IncRefFromWeakReferenceLocked() const + { + // IncRef/DecRef are allowed even for const instances + return DrRemovePtrConst(this)->IncRefFromWeakReferenceLocked(); + } + + // called by weak reference Freememory implementations (under the shared lock) to get the refcount on this object. + virtual LONG GetRefCountLocked() const = 0; + + virtual LONG DecRef() = 0; + LONG DecRef() const + { + // IncRef/DecRef are allowed even for const instances + return DrRemovePtrConst(this)->DecRef(); + } + virtual void FreeMemory() = 0; + + void FreeMemory() const + { + // FreeMemory is allowed even for const instances + DrRemovePtrConst(this)->FreeMemory(); + } + + virtual ~IDrRefCounter() + { + } + virtual UInt64 GetOID() const = 0; // Get object-id, an (almost) unique identifier that can be used for logging (this way we can see related log lines for a given object) +}; + + +class DrOneInitializedVolatileLong +{ +public: + volatile LONG m_long; + + DrOneInitializedVolatileLong() + { + m_long = 1; + } +}; + +class DrOneInitializedZeroDestroyedVolatileLong : public DrOneInitializedVolatileLong +{ +public: + ~DrOneInitializedZeroDestroyedVolatileLong() + { + LogAssert(m_long == 0 || m_long == k_lDecommissionedRefCount); + } +}; + +class DrOneInitializedOneDestroyedVolatileLong : public DrOneInitializedVolatileLong +{ +public: + ~DrOneInitializedOneDestroyedVolatileLong() + { + LogAssert(m_long == 1 || m_long == k_lDecommissionedRefCount); + } +}; + + +#define DRREFCOUNTNOIMPL \ + public: \ + virtual LONG IncRef() override = 0; \ + virtual LONG DecRef() override = 0; \ + virtual UInt64 GetOID() const 
override = 0; \ + virtual bool IncRefFromWeakReferenceLocked() override = 0; \ + virtual LONG GetRefCountLocked() const override = 0; \ + virtual void FreeMemory() override = 0; \ + + +// NOTE: the following could use DrOneInitializedZeroDestroyedVolativeLong, but some people create static instances of this... +#define DRREFCOUNTIMPL_NOFREEMEMORY_BASE(InterfaceClass) \ + protected: \ + mutable DrOneInitializedVolatileLong m_iRefCount; \ + DrObjID m_oid; \ + public: \ + virtual LONG IncRef() override\ + { \ + LONG i; \ + i = InterlockedIncrement (&(m_iRefCount.m_long)); \ + LogAssert (i > 1); \ + return i; \ + } \ + virtual LONG DecRef() override\ + { \ + LONG i; \ + i = InterlockedDecrement (&(m_iRefCount.m_long)); \ + if (i <= 0) \ + { \ + LogAssert (i == 0); \ + FreeMemory (); \ + } \ + return i; \ + } \ + virtual UInt64 GetOID() const override{ return m_oid.Value(); } \ + virtual bool IncRefFromWeakReferenceLocked() override\ + { \ + if (m_iRefCount.m_long == 0) { \ + return false; \ + } \ + LONG i = InterlockedIncrement (&(m_iRefCount.m_long)); \ + LogAssert (i > 1); \ + return true; \ + } \ + virtual LONG GetRefCountLocked() const override\ + { \ + return m_iRefCount.m_long; \ + } \ + + +#define DRREFCOUNTIMPL_STATIC_BASE(InterfaceClass) \ + private: \ + static void * operator new( size_t) { LogAssert(false); return NULL; } \ + static void operator delete( void *) { LogAssert(false); } \ + protected: \ + mutable DrOneInitializedOneDestroyedVolatileLong m_iRefCount; \ + DrObjID m_oid; \ + public: \ + virtual LONG IncRef() override\ + { \ + LONG i; \ + i = InterlockedIncrement (&(m_iRefCount.m_long)); \ + LogAssert (i > 1); \ + return i; \ + } \ + virtual LONG DecRef() override\ + { \ + LONG i; \ + i = InterlockedDecrement (&(m_iRefCount.m_long)); \ + LogAssert(i > 0); \ + return i; \ + } \ + virtual UInt64 GetOID() const override{ return m_oid.Value(); } \ + virtual bool IncRefFromWeakReferenceLocked() override \ + { \ + LONG i = InterlockedIncrement 
(&(m_iRefCount.m_long)); \ + LogAssert (i > 1); \ + return true; \ + } \ + virtual LONG GetRefCountLocked() const override\ + { \ + return m_iRefCount.m_long; \ + } \ + + +#define DRREFCOUNTIMPL_NOFREEMEMORY DRREFCOUNTIMPL_NOFREEMEMORY_BASE(IDrRefCounter) + +#define DRREFCOUNTIMPL_BASE(InterfaceClass) \ + DRREFCOUNTIMPL_NOFREEMEMORY_BASE(InterfaceClass) \ + protected: \ + virtual void FreeMemory() override\ + { \ + delete this; \ + } + + +#define DRREFCOUNTIMPL DRREFCOUNTIMPL_BASE(IDrRefCounter) +#define DRREFCOUNTIMPL_STATIC DRREFCOUNTIMPL_STATIC_BASE(IDrRefCounter) + +class DrRefCounter : public IDrRefCounter +{ +protected: + mutable volatile LONG m_iRefCount; + DrRefCountMonitor *m_pMonitor; + void *m_pMonitorContext; + UInt64 m_oid; + + DrRefCounter() + { + m_iRefCount = 1; + m_pMonitor = NULL; + m_pMonitorContext = NULL; + m_oid = GetUniqueObjectID(); + } + + + virtual ~DrRefCounter() + { + LogAssert (m_iRefCount == 0 || m_iRefCount == k_lDecommissionedRefCount); + } + + // Called when the refcount becomes zero. + virtual void FreeMemory() override + { + delete this; + } + +public: + + // not threadsafe + void SetRefCountMonitor(DrRefCountMonitor *pMonitor, void *pContext = NULL) + { + m_pMonitor = pMonitor; + m_pMonitorContext = pContext; + if (m_pMonitor != NULL) { + m_pMonitor->OnRefCountChanged(this, m_pMonitorContext, m_iRefCount, "SetMonitor"); + } + } + + virtual LONG IncRef() override + { + LONG i; + i = InterlockedIncrement (&m_iRefCount); + LogAssert (i > 1); + if (m_pMonitor != NULL) { + m_pMonitor->OnRefCountChanged(this, m_pMonitorContext, i, "IncRef"); + } + return i; + } + + // Increments without checking that the previous refcount was > 0. 
Used for special circumstances where we + // may temporarily increment from a 0 refcount + LONG IncRefNoCheck() + { + LONG i; + i = InterlockedIncrement (&m_iRefCount); + if (m_pMonitor != NULL) { + m_pMonitor->OnRefCountChanged(this, m_pMonitorContext, i, "IncRefNoCheck"); + } + return i; + } + + // called by weak reference pointers (under the shared lock) to IncRef this object. + // returns TRUE if the IncRef could be performed, false if the object is already unreferenceable + // The default implementation succeeds unless the current refcount is 0, on the assumption that the implementation + // will clear all existing weak references after the refcount becomes 0 and before the object is deleted. + virtual bool IncRefFromWeakReferenceLocked() override + { + if (m_iRefCount == 0) { + return false; + } + LONG i = InterlockedIncrement (&m_iRefCount); + LogAssert (i > 1); \ + if (m_pMonitor != NULL) { + m_pMonitor->OnRefCountChanged(this, m_pMonitorContext, i, "IncRefFromWeakReferenceLocked"); + } + return true; + } + + virtual LONG GetRefCountLocked() const override + { + return m_iRefCount; + } + + virtual LONG DecRef() override + { + LONG i; + if (m_pMonitor != NULL) { + // HACK: the refcount may be wrong, but it will work + m_pMonitor->OnRefCountChanged(this, m_pMonitorContext, m_iRefCount-1L, "DecRef"); + } + i = InterlockedDecrement (&m_iRefCount); + if (i <= 0) + { + LogAssert (i == 0); + FreeMemory (); + } + return i; + } + + // Decrements without special handling when the refcount becomes zero. Used for special circumstances + // where the refcount has been temporarily incremented from zero and another context will be + // calling FreeMemory(). 
+ LONG DecRefNoFree() + { + LONG i; + if (m_pMonitor != NULL) { + // HACK: the refcount may be wrong, but it will work + m_pMonitor->OnRefCountChanged(this, m_pMonitorContext, m_iRefCount-1L, "DecRefNoFree"); + } + i = InterlockedDecrement (&m_iRefCount); + LogAssert(i >= 0); + return i; + } + + virtual UInt64 GetOID() const override + { + return m_oid; + } + + + void ResetRefCounter() + { + LONG i; + i = InterlockedExchange (&m_iRefCount, 1); + if ( i != 0 ) + { + DrLogE( "DecRef error - refCount=%ld, oid=%I64x", i, GetOID() ); + LogAssert (i == 0); + } + if (m_pMonitor != NULL) { + m_pMonitor->OnRefCountChanged(this, m_pMonitorContext, 1, "ResetRefCounter"); + } + } + + // Abandons the refcounter. Used when a subclass overrides the implementation. Prevents the refcounter from being used, and + // prevents an assertion failure when the object is destructed. + void DecommissionRefCounter() + { + LONG i; + // We set a magic number to indicate that it is decommissioned. Any call to IncRef or DecRef will assert, and destruction will succeed + i = InterlockedExchange (&m_iRefCount, k_lDecommissionedRefCount); + LogAssert(i == 1 || i == k_lDecommissionedRefCount); + } + + // The value returned by this method is not stable unless the caller provides another mechanism for guaranteeing that + // noone calls IncRef or DecRef. 
+ LONG GetRefCount() const + { + return m_iRefCount; + } +}; + +template class DrRef +{ + +public: + DrRef() + { + p = NULL; + } + + DrRef(T* lp) + { + p = lp; + + if (p != NULL) { + DrRemovePtrConst(p)->IncRef(); + } + } + + explicit DrRef(const DrRef& ref) + { + p = ref.p; + + if (p != NULL) { + DrRemovePtrConst(p)->IncRef(); + } + } + + ~DrRef() + { + if (p != NULL) { + DrRemovePtrConst(p)->DecRef(); + p = NULL; // Make sure we AV in case someone is using DrRef after DecRef + } + } + + operator T*() const + { + return p; + } + + T& operator*() const + { + return *p; + } + + T* operator->() const + { + return p; + } + + bool operator!() const + { + return (p == NULL); + } + + bool operator<(T* pT) const + { + return (p < pT); + } + + bool operator>(T* pT) const + { + return (p < pT); + } + + bool operator<=(T* pT) const + { + return (p <= pT); + } + + bool operator>=(T* pT) const + { + return (p < pT); + } + + bool operator==(T* pT) const + { + return (p == pT); + } + + bool operator!=(T* pT) const + { + return (p != pT); + } + + DrRef& Set(T* lp) + { + if (p != lp) { + if (lp != NULL) { + DrRemovePtrConst(lp)->IncRef(); + } + + if (p != NULL) { + DrRemovePtrConst(p)->DecRef(); + } + + p = lp; + } + + return *this; + } + + DrRef& operator=(T* lp) + { + return Set(lp); + } + + DrRef& operator=(const DrRef& ref) + { + return Set(ref.p); + } + + // Release the interface and set to NULL + void Release() + { + T* pTemp = p; + if (pTemp != NULL) { + p = NULL; + DrRemovePtrConst(pTemp)->DecRef(); + } + } + + // + // Attach to an existing interface (does not IncRef) + // + void Attach(T* p2) + { + // + // Remove reference to previous interface + // + if (p != NULL) + { + DrRemovePtrConst(p)->DecRef(); + } + + // + // Update current interface + // + p = p2; + } + + // Detach the interface (does not DecRef) + T* Detach() + { + T* pt = p; + p = NULL; + return pt; + } + + void TransferFrom( DrRef& source) + { + Attach(source.Detach()); + } + + template void TransferFrom( 
DrRef& source) + { + T2 *p2 = source.Detach(); + if (p != NULL) { + DrRemovePtrConst(p)->DecRef(); + } + if (p2== NULL) { + p = NULL; + } else { + p = dynamic_cast(p2); + LogAssert(p != NULL); + } + } + + T* Ptr() const + { + return p; + } + +private: + T* p; +}; + +// Growable vector of DrRef smart pointers to arbitrary refcounted typed items +// Insertions and deletions can be performed both at the head and at the tail of the list, making it suitable for queues. +template class DrRefList +{ +public: + DrRefList() + { + m_nEntries = 0; + m_nAllocated = 0; + m_prgEntries = NULL; + } + + ~DrRefList() + { + if (m_prgEntries != NULL) { + delete[] m_prgEntries; + m_prgEntries = NULL; + } + m_nEntries = 0; + m_nAllocated = 0; + } + + // forces the buffer to be reallocated with the given size, even if it + // is the same as the current size. + // the requested size must be big enough to hold the valid entries. + // On exit, the valid entries are always contiguous starting at offset 0 + // if n==0, frees the buffer + void ForceRealloc(::UInt32 n) + { + LogAssert(n >= m_nEntries); + DrRef *pnew = NULL; + if (n != 0) { + pnew = new DrRef[n]; + LogAssert(pnew != NULL); + ::UInt32 uFrom = m_uFirstEntry; + for (::UInt32 i = 0; i < m_nEntries; i++) { + pnew[i].TransferFrom(m_prgEntries[uFrom++]); + if (uFrom >= m_nAllocated) { + uFrom = 0; + } + } + } + if (m_prgEntries != NULL) { + delete[] m_prgEntries; + } + m_prgEntries = pnew; + m_nAllocated = n; + m_uFirstEntry = 0; + } + + // reallocates the buffer if there are not at least n elements allocated + void GrowTo(::UInt32 n) + { + if (n > m_nAllocated) { + if (n < 2 * m_nAllocated) { + n = 2 * m_nAllocated; + } + if (n < 20) { + n = 20; + } + ForceRealloc(n); + } + } + + // Converts a potentially wrapped list (if you have moved the head) into a + // contiguous list, and returns a pointer to the first item in the contiguous list + // If possible, nothing is moved. 
If the list is wrapped, a new buffer is allocated (easier than moving everything in a full list) + // This operation is always cheap if you never remove from or add to the head. + DrRef *MakeContiguous() + { + if (m_uFirstEntry + m_nEntries > m_nAllocated) { + ForceRealloc(m_nEntries); + } + + return m_prgEntries + m_uFirstEntry; + } + + UInt32 NumEntries() const + { + return m_nEntries; + } + + UInt32 NumAllocated() const + { + return m_nAllocated; + } + + DrRef& EntryAt(UInt32 index) + { + LogAssert(index < m_nEntries); + return m_prgEntries[NormalizeEntryIndex(index)]; + } + + const t *ConstEntryAt(UInt32 index) const + { + LogAssert(index < m_nEntries); + return m_prgEntries[NormalizeEntryIndex(index)]; + } + + const t *EntryAt(UInt32 index) const + { + ConstEntryAt(index); + } + + DrRef& operator[](UInt32 index) + { + return EntryAt(index); + } + + const t *operator[](UInt32 index) const + { + return ConstEntryAt(index); + } + + // returns NULL if list is empty + t *Head() + { + if (m_nEntries == 0) { + return NULL; + } + return m_prgEntries[m_uFirstEntry]; + } + + // returns NULL if list is empty + const t *Head() const + { + if (m_nEntries == 0) { + return NULL; + } + return m_prgEntries[m_uFirstEntry]; + } + + // returns NULL if list is empty + t *Tail() + { + if (m_nEntries == 0) { + return NULL; + } + return EntryAt(m_nEntries-1); + } + + // returns NULL if list is empty + const t *Tail() const + { + if (m_nEntries == 0) { + return NULL; + } + return EntryAt(m_nEntries-1); + } + + // Invalidates all entry references and pointers previously returned + DrRef& AddEntryToTail(t *pNewEntry = NULL) + { + GrowTo(m_nEntries+1); + DrRef& newEntry = m_prgEntries[NormalizeEntryIndex(m_nEntries++)]; + newEntry = pNewEntry; + return newEntry; + } + + // Invalidates all entry references and pointers previously returned + DrRef& AddEntryToHead(t *pNewEntry = NULL) + { + GrowTo(m_nEntries+1); + if (m_uFirstEntry == 0) { + m_uFirstEntry = m_nAllocated - 1; + } else { + 
m_uFirstEntry--; + } + m_nEntries++; + DrRef& newEntry = m_prgEntries[m_uFirstEntry]; + newEntry = pNewEntry; + return newEntry; + } + + // Invalidates all entry references and pointers previously returned + DrRef& AddEntry(t *pNewEntry = NULL) + { + return AddEntryToTail(pNewEntry); + } + + void RemoveEntryFromTail(DrRef *pValOut) + { + LogAssert(m_nEntries != 0); + pValOut->TransferFrom(m_prgEntries[NormalizeEntryIndex(--m_nEntries)]); + } + + void RemoveEntryFromHead(DrRef *pValOut) + { + LogAssert(m_nEntries != 0); + pValOut->TransferFrom(m_prgEntries[m_uFirstEntry++]); + if (m_uFirstEntry >= m_nAllocated) { + m_uFirstEntry = 0; + } + m_nEntries--; + } + + + void Clear() + { + ::UInt32 uIndex = m_uFirstEntry; + for (::UInt32 i = 0; i < m_nEntries; i++) { + m_prgEntries[uIndex++] = NULL; + if (uIndex >= m_nAllocated) { + uIndex = 0; + } + } + m_nEntries = 0; + m_uFirstEntry = 0; // might as well reset to the beginning of the array + } + + void AddNullEntriesToTail(UInt32 numNulls) + { + if (numNulls != 0) + { + // Depends on the fact that all unused entries are NULL + LogAssert(m_nEntries + numNulls > m_nEntries); // Check for overflow + GrowTo(m_nEntries + numNulls); + m_nEntries += numNulls; + } + } + + typedef int (__cdecl *PDRREF_COMPARE_FUNCTION)(void *context, t *p1, t *p2); + + typedef struct { + PDRREF_COMPARE_FUNCTION compare; + void *context; + } DrRefSortContext; + + static int __cdecl InternalDrRefCompare(void *context, const void *p1, const void *p2) + { + DrRefSortContext *pSortContext = (DrRefSortContext *)context; + return (*pSortContext->compare)(pSortContext->context, *(DrRef *)p1, *(DrRef *)p2); + } + + // performs a quicksort on the list, with a user-provided compare function + // Invalidates all entry references and pointers previously returned + void Sort(PDRREF_COMPARE_FUNCTION compare, void * context = NULL) + { + if (m_nEntries > 1) { + DrRef *pFirst = MakeContiguous(); + DrRefSortContext ctx; + ctx.compare = compare; + ctx.context = 
context; + qsort_s(pFirst, (size_t)m_nEntries, sizeof(*pFirst), InternalDrRefCompare, &ctx); + } + } + +protected: + UInt32 NormalizeEntryIndex(UInt32 index) const + { + LogAssert(index < m_nAllocated); + if (m_uFirstEntry != 0) { + index += m_uFirstEntry; + if (index >= m_nAllocated) { + index -= m_nAllocated; + } + } + return index; + } + +protected: + UInt32 m_uFirstEntry; // Normally 0, the index of the first entry (for circular buffers) + +private: + UInt32 m_nEntries; + UInt32 m_nAllocated; + DrRef *m_prgEntries; +}; + +#endif //end if not defined __DRYADREFCOUNTER_H__ diff --git a/DryadVertex/VertexHost/system/classlib/include/DrStringUtil.h b/DryadVertex/VertexHost/system/classlib/include/DrStringUtil.h new file mode 100644 index 0000000..c40d1ce --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrStringUtil.h @@ -0,0 +1,1011 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#define NULLSTR ((char *)NULL) + +// This macro can be used to obtain a temporary "const char *" error description for an error. It can be used +// as the parameter to a method call; the pointer will become invalid after the function returns +#define DRERRORSTRING(err) (DrErrorString(err).GetString()) + +// This macro can be used to obtain a temporary "const char *" UTF-8 string equivalent for a unicode string. 
It can be used +// as the parameter to a method call; the pointer will become invalid after the function returns +#define DRWSTRINGTOUTF8(s) (DrStr64(DrStr((const WCHAR *)(s))).GetString()) + +// This macro can be used to obtain a temporary "const char *" description for a status value. It can be used +// as the parameter to a method call; the pointer will become invalid after the function returns +// succeeded string: OK +// failed string: FAILED 80001234 'Error text' +#define DR_STATUS_STRING(status) (DrStatusString(status).GetString()) + +// This macro can be used to obtain a temporary "const char *" string equivalent for a GUID. It can be used +// as the parameter to a method call; the pointer will become invalid after the function returns +#define DRGUIDSTRING(guid) (DrGuidString(guid, false).GetString()) + +__inline bool ISSPACE(char c) +{ + return isspace((unsigned char)c) != 0; +} + +static const Size_t DrStr_InvalidIndex = Max_Size_t; + +// Simple growable string classes +class DrStr +{ +public: + DrStr() + { + InitZero(); + } + + DrStr(const DrStr& other) + { + InitZero(); + Set(other); + } + + DrStr(const char *psz) + { + InitZero(); + Set(psz); + } + + DrStr(const char *psz, size_t len) + { + InitZero(); + Set(psz, len); + } + + DrStr(const WCHAR *psz) + { + InitZero(); + Set(psz); + } + + DrStr(const WCHAR *psz, size_t len) + { + InitZero(); + Set(psz, len); + } +/*JC + DrStr(const WCHAR *psz, size_t len, UINT codePage) + { + InitZero(); + Set(psz, len, codePage); + } + + DrStr(const DrWStr& other); +*/ + + virtual ~DrStr() + { + DiscardString(); + } + + // may be NULL + const char *GetString() const + { + return m_pBuffer; + } + + size_t GetLength() const + { + return m_stringLen; + } + + size_t GetBufferLength() const + { + return m_nbBuffer; + } + + operator const char *() const + { + return m_pBuffer; + } + + DrStr& Set(const char *psz, size_t len) + { + if (psz == NULL) { + DiscardString(); + } else { + GrowTo(len); + memcpy(m_pBuffer, psz, len); 
+ m_pBuffer[len] = '\0'; + m_stringLen = len; + } + return *this; + } + + DrStr& Set(const char *psz) + { + size_t len = 0; + if (psz != NULL) { + len = strlen(psz); + } + return Set(psz, len); + } + + DrStr& Set(const DrStr& other) + { + return Set(other.m_pBuffer, other.m_stringLen); + } + + // Converts from unicode + DrStr& Set(const WCHAR *psz, size_t len, UINT codePage) + { + if (psz == NULL) { + DiscardString(); + } else { + SetToEmptyString(); + Append(psz, len, codePage); + } + return *this; + } + + // Converts from unicode to UTF8 + DrStr& Set(const WCHAR *psz, size_t len) + { + return Set(psz, len, CP_UTF8); + } + + DrStr& Set(const WCHAR *psz) + { + size_t len = 0; + if (psz != NULL) { + len = wcslen(psz); + } + return Set(psz, len); + } + + DrStr& VSetF(const char *pszFormat, va_list args) + { + SetToEmptyString(); + return VAppendF(pszFormat, args); + } + +#pragma warning (push) +#pragma warning (disable: 4793) // function compiled as native + DrStr& SetF(const char *pszFormat, ...) 
+ { + va_list args; + va_start(args, pszFormat); + + return VSetF(pszFormat, args); + } +#pragma warning (pop) + + // + // Set string value to environment variable value + // + DrError SetFromEnvironmentVariable(const char *pszVarName) + { + SetToEmptyString(); + return AppendFromEnvironmentVariable(pszVarName); + } + + DrStr& operator=(const char *psz) + { + return Set(psz); + } + + DrStr& operator=(const DrStr& other) + { + return Set(other); + } + + // Returns a writable buffer of >= 0 chars, but without growing if not necessary + // The max chars to write are returned in *pLenOut + char *GetAnyWritableBuffer(size_t *pLenOut, size_t offset = 0) + { + GrowTo(offset); + *pLenOut = m_nbBuffer - offset - 1; + return m_pBuffer + offset; + } + + DrStr& Append(const char *psz, size_t len) + { + if (psz != NULL) { + size_t oldLen = m_stringLen; + GrowTo(oldLen + len); + memcpy(m_pBuffer + oldLen, psz, len); + m_stringLen = oldLen + len; + m_pBuffer[m_stringLen] = '\0'; + } + return *this; + } + + DrStr& Append(const char *psz) + { + size_t len = 0; + if (psz != NULL) { + len = strlen(psz); + } + return Append(psz, len); + } + + DrStr& Append(const DrStr& other) + { + return Append(other.m_pBuffer, other.m_stringLen); + } + + DrStr& Append(char c) + { + GrowTo(m_stringLen+1); + m_pBuffer[m_stringLen++] = c; + m_pBuffer[m_stringLen] = '\0'; + return *this; + } + + DrStr& Append(const WCHAR *psz, size_t len, UINT codePage); + + DrStr& Append(const WCHAR *psz, size_t len) + { + return Append(psz, len, CP_UTF8); + } + + DrStr& Append(const WCHAR *psz) + { + size_t len = 0; + if (psz != NULL) { + len = wcslen(psz); + } + return Append(psz, len); + } + + DrStr& VAppendF(const char *pszFormat, va_list args); + +#pragma warning (push) +#pragma warning (disable: 4793) // function compiled as native + DrStr& AppendF(const char *pszFormat, ...) 
+    {
+        va_list args;
+        va_start(args, pszFormat);
+
+        return VAppendF(pszFormat, args);
+    }
+#pragma warning (pop)
+
+    // Appends the description of an error code (via DrAppendErrorDescription) and returns *this.
+    DrStr& AppendErrorString(DrError val)
+    {
+        return DrAppendErrorDescription(*this, val);
+    }
+
+    // Appends an XML-encoded copy of the first len bytes of pszUnencoded (defined out of line).
+    DrStr& AppendXmlEncodedString(const char *pszUnencoded, Size_t len, bool fEscapeNewlines=false);
+
+    // Convenience overload: encodes the full contents of another DrStr.
+    DrStr& AppendXmlEncodedString(const DrStr& strUnencoded, bool fEscapeNewlines=false)
+    {
+        return AppendXmlEncodedString(strUnencoded.GetString(), strUnencoded.GetLength(), fEscapeNewlines);
+    }
+
+    // Convenience overload for a null-terminated string; a NULL pointer is forwarded with length 0.
+    DrStr& AppendXmlEncodedString(const char *pszUnencoded, bool fEscapeNewlines=false)
+    {
+        return AppendXmlEncodedString(pszUnencoded, pszUnencoded == NULL ? (Size_t)0 : strlen(pszUnencoded), fEscapeNewlines);
+    }
+
+    // Replaces the current contents with the description of an error code.
+    DrStr& SetErrorString(DrError val)
+    {
+        SetToEmptyString();
+        return AppendErrorString(val);
+    }
+
+    DrStr& operator+=(const char *psz)
+    {
+        return Append(psz);
+    }
+
+    DrStr& operator+=(const DrStr& other)
+    {
+        return Append(other);
+    }
+
+    DrStr& operator+=(char c)
+    {
+        return Append(c);
+    }
+
+    // Asserts if the string is NULL or the index is greater than the string length.
+    // index == length is accepted and addresses the null terminator itself.
+    // The non-const overloads return an lvalue, so the character can be set as well as fetched;
+    // the const overloads return the character by value.
+    char& operator[](size_t index)
+    {
+        LogAssert(m_pBuffer != NULL && index <= m_stringLen);
+        return m_pBuffer[index];
+    }
+
+    char& operator[](int index)
+    {
+        LogAssert(m_pBuffer != NULL && index >= 0 && (size_t)index <= m_stringLen);
+        return m_pBuffer[index];
+    }
+
+    char operator[](size_t index) const
+    {
+        LogAssert(m_pBuffer != NULL && index <= m_stringLen);
+        return m_pBuffer[index];
+    }
+
+    char operator[](int index) const
+    {
+        LogAssert(m_pBuffer != NULL && index >= 0 && (size_t)index <= m_stringLen);
+        return m_pBuffer[index];
+    }
+
+    // Uses _strlwr (not multibyte aware)
+    // leaves NULL strings alone
+    DrStr& ToLowerCase();
+
+    // Uses _strupr (not multibyte aware)
+    // leaves NULL strings alone
+    DrStr& ToUpperCase();
+
+    // Byte-wise equality against a (pointer, length) pair.
+    // Lengths must match exactly; a NULL string only equals another NULL string
+    // (so a NULL string and an empty string compare unequal).
+    bool IsEqual(const char *pszOther, size_t len) const
+    {
+        if (len != m_stringLen) {
+            return false;
+        }
+        if (pszOther == NULL || m_pBuffer == NULL) {
+            return (m_pBuffer == pszOther);
+        }
+        if (len == 0) {
+            return true;
+        }
+        return (memcmp(m_pBuffer, pszOther, len) == 0);
+    }
+
+    bool IsEqual(const DrStr& other) const
+    {
+        return IsEqual(other.GetString(), other.GetLength());
+    }
+
+    // Null-terminated overload; a NULL pointer is compared with length 0.
+    bool IsEqual(const char *pszOther) const
+    {
+        return IsEqual(pszOther, (pszOther == NULL) ? (size_t)0 : strlen(pszOther));
+    }
+
+    // Returns true if the string is NULL or empty. With ignoreWhitespace set, leading
+    // characters for which isspace() is true are skipped first, so a whitespace-only
+    // string also counts as empty. Does not modify the string, hence const.
+    bool IsNullOrEmpty( bool ignoreWhitespace = false ) const
+    {
+        if ( m_pBuffer == NULL )
+            return true;
+        // isspace() requires a value representable as unsigned char, so scan through an unsigned view
+        const unsigned char* p = (const unsigned char *)m_pBuffer;
+        if ( ignoreWhitespace )
+            while ( isspace( *p ) )
+                p++;
+        return '\0' == *p;
+    }
+
+    bool operator==(const char *pszOther) const
+    {
+        return IsEqual(pszOther);
+    }
+
+    bool operator==(const DrStr& other) const
+    {
+        return IsEqual(other.m_pBuffer, other.m_stringLen);
+    }
+
+
+    bool operator!=(const char *pszOther) const
+    {
+        return !IsEqual(pszOther);
+    }
+
+    bool operator!=(const DrStr& other) const
+    {
+        return !IsEqual(other.m_pBuffer, other.m_stringLen);
+    }
+
+    // Compares matchLen bytes of pszMatch against this string starting at index (defined out of line).
+    bool SubstrIsEqual(size_t index, const char *pszMatch, size_t matchLen) const;
+
+    // Null-terminated overload. pszMatch must not be NULL: strlen is applied to it unconditionally.
+    bool SubstrIsEqual(size_t index, const char *pszMatch) const
+    {
+        return SubstrIsEqual(index, pszMatch, strlen(pszMatch));
+    }
+
+    bool StartsWith(const char *pszMatch, size_t len) const
+    {
+        return SubstrIsEqual(0, pszMatch, len);
+    }
+
+    // Null-terminated overload. pszMatch must not be NULL: strlen is applied to it unconditionally.
+    bool StartsWith(const char *pszMatch) const
+    {
+        return StartsWith(pszMatch, strlen(pszMatch));
+    }
+
+    // Not multibyte aware
+    bool SubstrIsEqualNoCase(size_t index, const char *pszMatch, size_t matchLen) const;
+
+    // Not multibyte aware
+    // pszMatch must not be NULL: strlen is applied to it unconditionally.
+    bool SubstrIsEqualNoCase(size_t index, const char *pszMatch) const
+    {
+        return SubstrIsEqualNoCase(index, pszMatch, strlen(pszMatch));
+    }
+
+    // Not multibyte aware
+    bool StartsWithNoCase(const char *pszMatch, size_t len) const
+    {
+        return SubstrIsEqualNoCase(0, pszMatch, len);
+    }
+
+    // Not multibyte aware
+    // pszMatch must not be NULL: strlen is applied to it unconditionally.
+    bool StartsWithNoCase(const char *pszMatch) const
+    {
+        return StartsWithNoCase(pszMatch, strlen(pszMatch));
+    }
+
+    // Returns DrStr_InvalidIndex if there is no match or the starting index is out of range
+    // Returns the string length if the null terminator is matched
+    // Uses strchr - not multibyte aware.
+    size_t IndexOfChar(char c, size_t startIndex = 0) const;
+
+    // Grows the buffer so that offset + maxStringLen characters (plus terminator) fit, then
+    // returns a pointer to the character position at offset. The string length is not updated here.
+    char *GetWritableBuffer(size_t maxStringLen, size_t offset = 0)
+    {
+        LogAssert(offset + maxStringLen >= offset); // overflow check
+        GrowTo(offset + maxStringLen);
+        return m_pBuffer + offset;
+    }
+
+    // Same as GetWritableBuffer, with the write position placed at the current end of the string.
+    char *GetWritableAppendBuffer(size_t maxStringLen)
+    {
+        return GetWritableBuffer(maxStringLen, m_stringLen);
+    }
+
+    DrError AppendFromEnvironmentVariable(const char *pszVarName);
+
+    // Releases any heap buffer and turns this into a NULL string (distinct from an empty string).
+    void SetToNull()
+    {
+        DiscardString();
+    }
+
+    // Makes this a non-NULL, zero-length string.
+    void SetToEmptyString()
+    {
+        GrowTo(0);
+        m_stringLen = 0;
+        m_pBuffer[0] = '\0';
+    }
+
+    // Converts a NULL string into an empty one; per the GrowTo contract below the length is untouched.
+    void EnsureNotNull()
+    {
+        GrowTo(0);
+    }
+
+    // Sets the string length and writes the terminator at that position.
+    // Asserts that stringLen fits inside the current buffer; for a NULL string only 0 is accepted.
+    // Presumably called after writing through a pointer from GetWritableBuffer - confirm against callers.
+    void UpdateLength(size_t stringLen)
+    {
+        if (m_pBuffer == NULL) {
+            LogAssert(stringLen == 0);
+            LogAssert(m_stringLen == 0);
+        } else {
+            LogAssert(stringLen < m_nbBuffer);
+            m_pBuffer[stringLen] = '\0';
+            m_stringLen = stringLen;
+        }
+    }
+
+protected:
+    // Ensures that there are at least maxStringLength+1 bytes available (including null terminator) for the string.
+    // Does not change the string length.
+    // If the string was NULL, it becomes an empty string.
+    void GrowTo(size_t maxStringLength);
+
+    // Frees the buffer if it is a heap block (never the subclass-owned static buffer)
+    // and resets the object to the NULL-string state.
+    void DiscardString()
+    {
+        if (m_pBuffer != NULL && m_pBuffer != m_pStaticBuffer) {
+            delete[] m_pBuffer;
+        }
+        m_pBuffer = NULL;
+        m_nbBuffer = 0;
+        m_stringLen = 0;
+    }
+
+protected:
+    // Constructor for subclasses that embed a fixed-size buffer. The buffer is only recorded here;
+    // the string starts out NULL. NOTE(review): isFastStr is never read in this body - it looks like
+    // it only disambiguates this overload; confirm.
+    DrStr(char *pStaticBuffer, size_t nbStatic, bool isFastStr)
+    {
+        m_pStaticBuffer = pStaticBuffer;
+        m_nbStatic = nbStatic;
+        m_pBuffer = NULL;
+        m_nbBuffer = 0;
+        m_stringLen = 0;
+    }
+
+    // Points the active buffer at the subclass-provided static buffer.
+    // NOTE(review): m_stringLen is set to m_nbStatic (the buffer size), which is larger than any
+    // string the buffer can hold with its terminator; callers presumably overwrite the length
+    // immediately afterwards - verify in GrowTo, since 0 looks like the intended value.
+    void SetToStaticBuffer()
+    {
+        m_pBuffer = m_pStaticBuffer;
+        m_nbBuffer = m_nbStatic;
+        m_stringLen = m_nbStatic;
+    }
+
+    // Allocation hook: returns a new char[newSize]. Virtual so subclasses can substitute another allocator.
+    virtual char* AllocateBiggerBuffer( size_t newSize ) {
+        return new char[newSize];
+    }
+
+private:
+    // Resets every member to the "no static buffer, NULL string" state.
+    void InitZero()
+    {
+        m_pStaticBuffer = NULL;
+        m_nbStatic = 0;
+        m_pBuffer = NULL;
+        m_nbBuffer = 0;
+        m_stringLen = 0;
+    }
+
+private:
+    char *m_pStaticBuffer; // stack-allocated buffer in subclass, or null
+    size_t m_nbStatic; // Number of bytes in the static buffer (including null terminator)
+    char *m_pBuffer; // if equal to m_pStaticBuffer, then we are using the static buffer. Otherwise a heap block.
+ size_t m_nbBuffer; // size of the buffer + size_t m_stringLen; // Number of characters in the string, not including null terminator + +}; + +// Variable size template for growable strings +template class DrFastStr : public DrStr +{ +public: + DrFastStr() : + DrStr(m_szStaticBuffer, buffSize, true) + { + } + + DrFastStr(const DrStr& other) : + DrStr(m_szStaticBuffer, buffSize, true) + { + Set(other); + } + + // explicitly defined copy constructor is very important; without it, C++ generates a default one that doesn't work + DrFastStr(const DrFastStr& other) : + DrStr(m_szStaticBuffer, buffSize, true) + { + Set(other); + } + + DrFastStr(const char *psz) : + DrStr(m_szStaticBuffer, buffSize, true) + { + Set(psz); + } + + DrFastStr(const char *psz, size_t len) : + DrStr(m_szStaticBuffer, buffSize, true) + { + Set(psz, len); + } + +/*JC + explicit DrFastStr(const WCHAR *psz) : + DrStr(m_szStaticBuffer, buffSize, true) + { + Set(psz); + } + + DrFastStr(const WCHAR *psz, size_t len) : + DrStr(m_szStaticBuffer, buffSize, true) + { + Set(psz, len); + } + + DrFastStr(const WCHAR *psz, size_t len, UINT codePage) : + DrStr(m_szStaticBuffer, buffSize, true) + { + Set(psz, len, codePage); + } + + DrFastStr(const DrWStr& other) : + DrStr(m_szStaticBuffer, buffSize, true) + { + Set(other); + } +*/ + + DrStr& operator=(const char *psz) + { + return Set(psz); + } + + DrStr& operator=(const DrStr& other) + { + return Set(other); + } + + // explicitly provided copy assignment operator to make sure C++ doesn't try to outsmart us + DrStr& operator=(const DrFastStr& other) + { + return Set(other); + } + + +/*JC DrStr& operator=(const DrWStr& other) + { + return Set(other); + } +*/ + +protected: + char m_szStaticBuffer[buffSize]; +}; + +typedef DrStr DrStr0; +typedef DrFastStr<4> DrStr4; +typedef DrFastStr<8> DrStr8; +typedef DrFastStr<16> DrStr16; +typedef DrFastStr<32> DrStr32; +typedef DrFastStr<64> DrStr64; +typedef DrFastStr<128> DrStr128; +typedef DrFastStr<256> DrStr256; +typedef 
DrFastStr<512> DrStr512; +typedef DrFastStr<1024> DrStr1024; +typedef DrFastStr<1024> DrStr1K; +typedef DrFastStr<2048> DrStr2K; +typedef DrFastStr<4096> DrStr4K; +typedef DrFastStr<8192> DrStr8K; +typedef DrFastStr<16384> DrStr16K; +typedef DrFastStr<32768> DrStr32K; +typedef DrFastStr<65536> DrStr64K; + +class DrErrorString : public DrStr128 +{ +public: + DrErrorString(DrError err) + { + SetErrorString(err); + } +}; + +class DrGuidString : public DrStr64 +{ +public: + DrGuidString( const DrGuid& guid, bool fBraces ) + { + guid.AppendToString( *this, fBraces); + } + + DrGuidString( const GUID& guid, bool fBraces ) + { + DrGuid g; + g.Set(guid); + g.AppendToString( *this, fBraces ); + } +}; + +class DrStatusString : public DrStr128 +{ +public: + DrStatusString(DrError status) + { + if ( SUCCEEDED( status ) ) + { + if ( status == DrError_OK ) + { + Append( "OK" ); + } + else + { + AppendF( "OK %08x", status ); + } + } + else + { + PCSTR s; + if (status == DrErrorFromWin32(ERROR_IO_PENDING)) + { + s = "IN_PROGRESS"; + } + else if (status == DrError_AlreadyCompleted) + { + s = "ASYNC_COMPLETION"; + } + else + { + s = "FAILED"; + } + + AppendF("%s %08x \'", s, status); + AppendErrorString( status ); + Append( '\'' ); + } + } +}; + +class DrTempStringPool +{ +private: + + // TempStringBlock allows us to maintain a stack of temporarily allocated + // return strings that is cleaned up when the stack is destroyed. + class TempStringBlock + { + private: + TempStringBlock *pNext; + BYTE *pData; + Size_t length; + + // We override "new" to allocate the header and the content + // in a single allocation. 
+ inline void *operator new(Size_t headersize, Size_t blocksize) + { + LogAssert(headersize == sizeof(TempStringBlock)); + LogAssert(headersize + blocksize >= headersize); // keep prefast happy + void *p = malloc(headersize + blocksize); + LogAssert(p != NULL); + return p; + } + + inline TempStringBlock(Size_t blocksize, TempStringBlock *pOldHead) + { + pData = ((BYTE *)(void *)this) + sizeof(*this); + length = blocksize; + pNext = pOldHead; + } + + public: + // We have to provide a matching delete... + inline void operator delete(void *pMem, Size_t blocksize) + { + (void)blocksize; + free(pMem); + } + + inline static TempStringBlock *Alloc(Size_t blocksize, TempStringBlock *pOldHead) + { + TempStringBlock *p = new(blocksize) TempStringBlock(blocksize, pOldHead); + return p; + } + + inline TempStringBlock *Detach() + { + TempStringBlock *p = pNext; + pNext = NULL; + return p; + } + + inline ~TempStringBlock() + { + LogAssert(pNext == NULL); + } + + inline BYTE *GetData() + { + return pData; + } + + inline Size_t GetLength() + { + return length; + } + + inline BYTE *ReserveData(Size_t len) + { + LogAssert(len <= length); + BYTE *pd = pData; + pData += len; + length -= len; + return pd; + } + + }; + +public: + DrTempStringPool(Size_t growSize = 8192) + { + m_growSize = growSize; + m_pHead = NULL; + } + + void discardAll() + { + while (m_pHead != NULL) { + TempStringBlock *p = m_pHead; + m_pHead = m_pHead->Detach(); + delete p; + } + } + + ~DrTempStringPool() + { + discardAll(); + } + + void setTempStringGrowSize(Size_t growSize) + { + m_growSize = growSize; + } + + const char *dupStr(const char *pszStr) + { + return dupStr( pszStr, pszStr ? 
strlen(pszStr) : 0 ); + } + + const char *dupStr(const char *pszStr, size_t len ) + { + if (pszStr == NULL) { + return NULL; + } + char *pszOut = (char *)allocMem(len+1); + memcpy(pszOut, pszStr, len+1); + return pszOut; + } + + // Not multibyte aware + const char *dupStrLowerCase(const char *pszStr, size_t* pLength = NULL ) + { + if (pszStr == NULL) { + return NULL; + } + Size_t len = strlen(pszStr); + if ( pLength != NULL ) + *pLength = len; + char *pszOut = (char *)allocMem(len+1); + memcpy(pszOut, pszStr, len+1); + _strlwr(pszOut); + return pszOut; + } + + const char *dupStrUpperCase(const char *pszStr, size_t* pLength = NULL) + { + if (pszStr == NULL) { + return NULL; + } + Size_t len = strlen(pszStr); + if ( pLength != NULL ) + *pLength = len; + char *pszOut = (char *)allocMem(len+1); + memcpy(pszOut, pszStr, len+1); + _strupr(pszOut); + return pszOut; + } + + const void *dupMem(const void *pMem, Size_t len) + { + if (pMem == NULL) { + return NULL; + } + void *pOut = allocMem(len); + memcpy(pOut, pMem, len); + return pOut; + } + + void *allocMem(Size_t len) + { + if (m_pHead == NULL || len > m_pHead->GetLength()) { + Size_t n = m_growSize; + if (n < len) { + n = len; + } + while (m_growSize < n) { + m_growSize = 2 * m_growSize; + } + m_pHead = TempStringBlock::Alloc(m_growSize, m_pHead); + } + return m_pHead->ReserveData(len); + } + + inline const char *setString(const char *pszStr, bool fCopy = true) + { + if (fCopy) { + pszStr = dupStr(pszStr); + } + + return pszStr; + } + + inline const void *setBlob(const void *pMem, Size_t len, bool fCopy = true) + { + if (fCopy) { + pMem = dupMem(pMem, len); + } + + return pMem; + } + +private: + + Size_t m_growSize; + TempStringBlock *m_pHead; + +}; + +// An internalized string pool manages reusable strings that can be located by their hash. Strings are put into the table, and references to the same string +// can be reused as often as desired, as long as the DrInternalizedStringPool is not deleted. 
+// There is no mechanism for deleting entries once they are added, so this class is only useful for string sets that don't grow unbounded (e.g., datacenter machine names, +// volume names, etc.) +// +// This class is threadsafe +class DrInternalizedStringPool : public DrTempStringPool +{ +private: + static const UInt32 k_internalizedStringMagic = (UInt32)'rtSi'; + + class InternalizedStringHeader + { + public: + UInt32 magic; + UInt32 hash; + InternalizedStringHeader *pNext; + }; + +public: + DrInternalizedStringPool(UInt32 hashTableSize = 20011, Size_t growSize = 65536) : DrTempStringPool(growSize) + { + m_hashTableSize = hashTableSize; + m_pBuckets = new InternalizedStringHeader *[m_hashTableSize]; + LogAssert(m_pBuckets != NULL); + memset(m_pBuckets, 0, m_hashTableSize * sizeof(InternalizedStringHeader *)); + m_Mutex = new MSMutex (); + } + + ~DrInternalizedStringPool() + { + delete[] m_pBuckets; + m_pBuckets = NULL; + m_Mutex = NULL; + } + + // Computes the hash of a string, and if pLength is supplied, its length as well + static UInt32 StringHash(const char *pszString, /* out */ Size_t *pLength = NULL) + { + LogAssert(pszString != NULL); + + Size_t length = strlen( pszString ); + + if (pLength != NULL) { + *pLength = length; + } + + return DrHash32::Compute( pszString, length ); + } + + // Returns a pointer to a duplicate of the given string that will be valid as long as this + // pool is not deleted. It is guaranteed that the return value will be the same for any two identical strings. + const char *InternalizeString(const char *pszString); + + // Same as InternalizeString, but normalizes the string to lower case + // before internalizing. + const char *InternalizeStringLowerCase(const char *pszString); + + // Same as InternalizeString, but normalizes the string to upper case + // before internalizing. + const char *InternalizeStringUpperCase(const char *pszString); + + // Verifies that a string is internalized. 
This should only be used for Assertions, since it is dangerous to do for non-internalized strings. + static bool IsInternalized(const char *pszString) + { + return GetHeaderOfInternalizedString(pszString)->magic == k_internalizedStringMagic; + } + + static UInt32 HashOfInternalizedString(const char *pszInternalizedString) + { + LogAssert(pszInternalizedString != NULL); + const InternalizedStringHeader *pHeader = GetHeaderOfInternalizedString(pszInternalizedString); + LogAssert(pHeader->magic == k_internalizedStringMagic); + return pHeader->hash; + } + +protected: + static const InternalizedStringHeader *GetHeaderOfInternalizedString(const char *pszInternalizedString) + { + return ((const InternalizedStringHeader *)(const void *)pszInternalizedString) - 1; + } + +private: + Ptr m_Mutex; + InternalizedStringHeader **m_pBuckets; + UInt32 m_hashTableSize; +}; + +extern DrInternalizedStringPool g_DrInternalizedStrings; + diff --git a/DryadVertex/VertexHost/system/classlib/include/DrTags.h b/DryadVertex/VertexHost/system/classlib/include/DrTags.h new file mode 100644 index 0000000..ca9039b --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrTags.h @@ -0,0 +1,410 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +// Tags used for BeginTag/EndTag + +// There must *only* be DEFINE_DRTAG directives in this file + +DEFINE_DRTAG(DrTag_InvalidTag, 0, "InvalidTag") // An invalid tag value + + +// DRM Tags +DEFINE_DRTAG(Drm_CreateStream, 1, "CreateStream") +DEFINE_DRTAG(Drm_CreateStreamResponse, 2, "CreateStreamResponse") +DEFINE_DRTAG(Drm_AppendStream, 3, "AppendStream") +DEFINE_DRTAG(Drm_AppendStreamResponse, 4, "AppendStreamResponse") +DEFINE_DRTAG(Drm_EnumDirectory, 5, "EnumDirectory") +DEFINE_DRTAG(Drm_EnumDirectoryResponse, 6, "EnumDirectoryResponse") +DEFINE_DRTAG(Drm_GetVolumeInfo, 7, "GetVolumeInfo") +DEFINE_DRTAG(Drm_GetVolumeInfoResponse, 8, "GetVolumeInfoResponse") +DEFINE_DRTAG(Drm_DeleteStream, 9, "DeleteStream") +DEFINE_DRTAG(Drm_DeleteStreamResponse, 10, "DeleteStreamResponse") +DEFINE_DRTAG(DrTag_AppendExtentRequest, 11, "AppendExtentRequest") +DEFINE_DRTAG(DrTag_AppendExtentResponse, 12, "AppendExtentResponse") +DEFINE_DRTAG(Drm_ReadStream, 13, "ReadStream") +DEFINE_DRTAG(Drm_ReadStreamResponse, 14, "ReadStreamResponse") +DEFINE_DRTAG(Drm_UpdateExtentMetadataEvent, 15, "UpdateExtentMetadataEvent") +DEFINE_DRTAG(Drm_UpdateExtentInstanceMetadataEvent, 16, "UpdateExtentInstanceMetadataEvent") +DEFINE_DRTAG(DrTag_ExtentNodeStatusEvent, 17, "ExtentNodeStatusEvent") +DEFINE_DRTAG(DrTag_DrmGarbageCollectEvent, 19, "DrmGarbageCollect") +DEFINE_DRTAG(DrTag_DrmSyncEnEvent, 20, "DrmSyncEN") +DEFINE_DRTAG(Drm_RenameStream, 21, "RenameStream") +DEFINE_DRTAG(Drm_RenameStreamResponse, 22, "RenameStreamResponse") +DEFINE_DRTAG(Drm_NewVolumeEvent, 23, "NewVolumeEvent") +DEFINE_DRTAG(DrTag_AppendExtentRangeRequest, 24, "AppendExtentRangeRequest") +DEFINE_DRTAG(DrTag_AppendExtentRangeResponse, 25, "AppendExtentRangeResponse") +DEFINE_DRTAG(DrTag_DrmSealExtentRequest, 26, "DrmSealExtentRequest") +DEFINE_DRTAG(DrTag_DrmSealExtentResponse, 27, "DrmSealExtentResponse") +DEFINE_DRTAG(DrTag_DrmSetStreamProperties, 28, "SetStreamProperties") 
+DEFINE_DRTAG(DrTag_DrmSetStreamPropertiesResponse, 29, "SetStreamPropertiesResponse") +// the following four tags are for synchronized version of metadata update +DEFINE_DRTAG(Drm_UpdateExtentMetadataRequest, 30, "UpdateExtentMetadataRequest") +DEFINE_DRTAG(Drm_UpdateExtentMetadataResponse, 31, "UpdateExtentMetadataResponse") +DEFINE_DRTAG(Drm_UpdateExtentInstanceMetadataRequest, 32, "UpdateExtentInstanceMetadataRequest") +DEFINE_DRTAG(Drm_UpdateExtentInstanceMetadataResponse, 33, "UpdateExtentInstanceMetadataResponse") +DEFINE_DRTAG(DrTag_AppendExtentRangeRequest2, 34, "AppendExtentRangeRequest2") +DEFINE_DRTAG(DrTag_AppendExtentRangeResponse2, 35, "AppendExtentRangeResponse2") +DEFINE_DRTAG(Drm_SubstituteStreamRequest, 36, "SubstituteStreamRequest") +DEFINE_DRTAG(Drm_SubstituteStreamResponse, 37, "SubstituteStreamResponse") +DEFINE_DRTAG(DrTag_ExtentRangeList, 38, "ExtentRangeList") +DEFINE_DRTAG(DrTag_ExtentRange, 39, "ExtentRange") +DEFINE_DRTAG(Drm_ExtentList, 51, "ExtentList") // array of CEIDs + +// WARNING: DO NOT INSERT REDDOG TAGS HERE! + +// REDDOG DRM extensions (shared with cosmos) +DEFINE_DRTAG(Drm_MultiModifyStreamEntry, 61, "MultiModifyStreamEntry") // a single stream mod entry within a list +DEFINE_DRTAG(Drm_MultiModifyStream, 62, "MultiModifyStream") +DEFINE_DRTAG(Drm_MultiModifyStreamResponse, 63, "MultiModifyStreamResponse") +DEFINE_DRTAG(Drm_MultiModifyExtentRange, 64, "MultiModifyExtentRange") // a single concatenate extent range within a list + +// WARNING: DO NOT INSERT REDDOG TAGS HERE! 
+ +// EN Tags +// First set are for requests from csm/client +DEFINE_DRTAG(DrTag_EnAppendExtentRequest, 100, "EnAppendExtentRequest") +DEFINE_DRTAG(DrTag_EnSealExtentRequest, 101, "EnSealExtentRequest") +DEFINE_DRTAG(DrTag_EnReadExtentRequest, 102, "EnReadExtentRequest") +DEFINE_DRTAG(DrTag_EnCreateExtentRequest, 103, "EnCreateExtentRequest") +DEFINE_DRTAG(DrTag_EnStartSealingExtentInstanceRequest, 104, "EnStartSealingExtentInstanceRequest") +DEFINE_DRTAG(DrTag_EnSyncRequest, 105, "EnSyncRequest") +DEFINE_DRTAG(DrTag_EnReplicateExtentRequest, 106, "EnReplicateExtentRequest") +DEFINE_DRTAG(DrTag_EnRecoverExtentRequest, 107, "EnRecoverExtentRequest") +DEFINE_DRTAG(DrTag_EnCmdRequest, 108, "EnCmdRequest" ) +DEFINE_DRTAG(DrTag_EnReadaheadInfo, 109, "EnReadaheadInfo" ) + +// WARNING: DO NOT INSERT REDDOG TAGS HERE! + +// These are the matching response tags for replies to above requests +DEFINE_DRTAG(DrTag_EnAppendExtentResponse, 200, "EnAppendExtentResponse") +DEFINE_DRTAG(DrTag_EnSealExtentResponse, 201, "EnSealExtentResponse") +DEFINE_DRTAG(DrTag_EnReadExtentResponse, 202, "EnReadExtentResponse") +DEFINE_DRTAG(DrTag_EnCreateExtentResponse, 203, "EnCreateExtentResponse") +DEFINE_DRTAG(DrTag_EnStartSealingExtentInstanceResponse, 204, "EnStartSealingExtentInstanceResponse") +DEFINE_DRTAG(DrTag_EnSyncResponse, 205, "EnSyncResponse") +DEFINE_DRTAG(DrTag_EnReplicateExtentResponse, 206, "EnReplicateExtentResponse") +DEFINE_DRTAG(DrTag_EnRecoverExtentResponse, 207, "EnRecoverExtentResponse") +DEFINE_DRTAG(DrTag_EnCmdResponse, 208, "EnCmdResponse" ) + +// WARNING: DO NOT INSERT REDDOG TAGS HERE! 
+ +// Second set are for requests between ENs to handle replication +DEFINE_DRTAG(DrTag_EnReplicateCreateExtentRequest, 300, "EnReplicateCreateExtentRequest") +DEFINE_DRTAG(DrTag_EnReplicateAppendExtentRequest, 301, "EnReplicateAppendExtentRequest") +//DEFINE_DRTAG(DrTag_EnReplicateExtentLengthRequest, 302, "EnReplicateExtentLengthRequest") +DEFINE_DRTAG(DrTag_EnReplicateSealRequest, 303, "EnReplicateSealRequest") + +// And these are the matching responses to the above requests +DEFINE_DRTAG(DrTag_EnReplicateCreateExtentResponse, 400, "EnReplicateCreateExtentResponse") +DEFINE_DRTAG(DrTag_EnReplicateAppendExtentResponse, 401, "EnReplicateAppendExtentResponse") +DEFINE_DRTAG(DrTag_EnReplicateExtentLengthResponse, 402, "EnReplicateExtentLengthResponse") +DEFINE_DRTAG(DrTag_EnReplicateSealResponse, 403, "EnReplicateSealResponse") + +// WARNING: DO NOT INSERT REDDOG TAGS HERE! + +// Generic messages +DEFINE_DRTAG(DrTag_DrMalformedMessage, 500, "DrMalformedMessage") +DEFINE_DRTAG(DrTag_DrErrorResponseMessage, 501, "DrErrorResponseMessage") + +// WARNING: DO NOT INSERT REDDOG TAGS HERE! + +// Name Resolution tags +DEFINE_DRTAG(DrTag_DrResolveNameRequest, 600, "DrResolveNameRequest") +DEFINE_DRTAG(DrTag_DrResolveNameResponse, 601, "DrResolveNameResponse") +DEFINE_DRTAG(DrTag_DrGetNameResolutionMapRequest, 602, "DrGetNameResolutionMapRequest") +DEFINE_DRTAG(DrTag_DrGetNameResolutionMapResponse, 603, "DrGetNameResolutionMapResponse") +DEFINE_DRTAG(DrTag_DrRegisterNamesRequest, 604, "DrRegisterNamesRequest") +DEFINE_DRTAG(DrTag_DrRegisterNamesResponse, 605, "DrRegisterNamesResponse") +DEFINE_DRTAG(DrTag_DrUpdateNameResolution, 606, "DrUpdateNameResolution") +DEFINE_DRTAG(DrTag_DrUpdateNameResolutionResponse, 607, "DrUpdateNameResolutionResponse") +DEFINE_DRTAG(DrTag_DrHostNameList, 608, "DrHostNameList") + +// WARNING: DO NOT INSERT REDDOG TAGS HERE! 
+ +// ProcessNode tags +DEFINE_DRTAG(DrTag_PnCreateProcessRequest, 700, "PnCreateProcessRequest") +DEFINE_DRTAG(DrTag_PnCreateProcessResponse, 701, "PnCreateProcessResponse") +DEFINE_DRTAG(DrTag_PnGetProcessStatusRequest, 702, "PnGetProcessStatusRequest") +DEFINE_DRTAG(DrTag_PnGetProcessStatusResponse, 703, "PnGetProcessStatusResponse") +DEFINE_DRTAG(DrTag_PnTerminateProcessRequest, 704, "PnTerminateProcessRequest") +DEFINE_DRTAG(DrTag_PnTerminateProcessResponse, 705, "PnTerminateProcessResponse") +DEFINE_DRTAG(DrTag_PnSetProcessPropertyRequest, 706, "PnSetProcessPropertyRequest") +DEFINE_DRTAG(DrTag_PnSetProcessPropertyResponse, 707, "PnSetProcessPropertyResponse") +DEFINE_DRTAG(DrTag_PnGetProcessPropertyRequest, 708, "PnGetProcessPropertyRequest") +DEFINE_DRTAG(DrTag_PnGetProcessPropertyResponse, 709, "PnGetProcessPropertyResponse") +DEFINE_DRTAG(DrTag_PnEnumerateProcessesRequest, 710, "PnEnumerateProcessesRequest") +DEFINE_DRTAG(DrTag_PnEnumerateProcessesResponse, 711, "PnEnumerateProcessesResponse") +DEFINE_DRTAG(DrTag_PnUserData, 712, "PnUserData") +DEFINE_DRTAG(DrTag_PnProcessStat, 713, "PnProcessStat") +DEFINE_DRTAG(DrTag_PnJobStat, 714, "PnJobStat") +DEFINE_DRTAG(DrTag_PnGetJobStatRequest, 715, "PnGetJobStatRequest") +DEFINE_DRTAG(DrTag_PnGetJobStatResponse, 716, "PnGetJobStatResponse") + +// WARNING: DO NOT INSERT REDDOG TAGS HERE! + +// Pipe protocol tags +DEFINE_DRTAG(DrTag_DrRendezvousConnect, 800, "DrRendezvousConnect") +DEFINE_DRTAG(DrTag_DrRendezvousConnected, 801, "DrRendezvousConnected") +DEFINE_DRTAG(DrTag_DrRendezvousConnectedAck, 802, "DrRendezvousConnectedAck") +DEFINE_DRTAG(DrTag_DrRendezvousConnectFailed, 803, "DrRendezvousConnectFailed") +DEFINE_DRTAG(DrTag_DrRendezvousRedirect, 804, "DrRendezvousRedirect") +DEFINE_DRTAG(DrTag_DrRendezvousWait, 805, "DrRendezvousWait") + +// WARNING: DO NOT INSERT REDDOG TAGS HERE! 
+ +// Execution monitoring tags +DEFINE_DRTAG(DrTag_JmJobExecutionStatistics, 900, "JmJobExecutionStatistics") +DEFINE_DRTAG(DrTag_JmVertexClassExecutionStatistics, 901, "JmVertexClassExecutionStatistics") + +// WARNING: DO NOT INSERT REDDOG TAGS HERE! + +// Cache Manager request tags +DEFINE_DRTAG(DrTag_CmSeedRequest, 1000, "CmSeedRequest") +DEFINE_DRTAG(DrTag_CmQueryFileRequest, 1001, "CmQueryFileRequest") + +// WARNING: DO NOT INSERT REDDOG TAGS HERE! + +// Cache Manager response tags +DEFINE_DRTAG(DrTag_CmSeedResponse, 1010, "CmSeedResponse") +DEFINE_DRTAG(DrTag_CmQueryFileResponse, 1011, "CmQueryFileResponse") + +// WARNING: DO NOT INSERT REDDOG TAGS HERE! + +// Cache Manager event tags +DEFINE_DRTAG(DrTag_CmGossipEvent, 1020, "CmGossipEvent") + +// WARNING: DO NOT INSERT REDDOG TAGS HERE! + +// Dryad Task Scheduler tags +DEFINE_DRTAG(DrTag_ScheduleTaskRequest, 1100, "ScheduleTaskRequest") +DEFINE_DRTAG(DrTag_ScheduleTaskResponse, 1101, "ScheduleTaskResponse") +DEFINE_DRTAG(DrTag_TaskJournalStart, 1102, "TaskJournalStart") +DEFINE_DRTAG(DrTag_TaskJournalEnd, 1103, "TaskJournalEnd") +DEFINE_DRTAG(DrTag_GetTaskInfoRequest, 1104, "GetTaskInfoRequest") +DEFINE_DRTAG(DrTag_GetTaskInfoResponse, 1105, "GetTaskInfoResponse") +DEFINE_DRTAG(DrTag_GetTaskJournalRequest, 1106, "GetTaskJournalRequest") +DEFINE_DRTAG(DrTag_GetTaskJournalResponse, 1107, "GetTaskJournalResponse") +DEFINE_DRTAG(DrTag_DeleteTaskRequest, 1108, "DeleteTaskRequest") +DEFINE_DRTAG(DrTag_DeleteTaskResponse, 1109, "DeleteTaskResponse") +DEFINE_DRTAG(DrTag_EnumerateTasksRequest, 1110, "EnumerateTasksRequest") +DEFINE_DRTAG(DrTag_EnumerateTasksResponse, 1111, "EnumerateTasksResponse") +DEFINE_DRTAG(DrTag_EnumeratedTask, 1112, "EnumeratedTask") +DEFINE_DRTAG(DrTag_TaskJournalEntry, 1113, "TaskJournalEntry") +DEFINE_DRTAG(DrTag_EnumerateSchedulerLogRequest, 1114, "EnumerateSchedulerLogRequest") +DEFINE_DRTAG(DrTag_EnumerateSchedulerLogResponse,1115, "EnumerateSchedulerLogResponse") 
+DEFINE_DRTAG(DrTag_SetSchedulerPropertiesRequest,1116, "SetSchedulerPropertiesRequest") +DEFINE_DRTAG(DrTag_SetSchedulerPropertiesResponse,1117, "SetSchedulerPropertiesResponse") +DEFINE_DRTAG(DrTag_SetTaskPropertiesRequest,1118, "SetTaskPropertiesRequest") +DEFINE_DRTAG(DrTag_SetTaskPropertiesResponse,1119, "SetTaskPropertiesResponse") +DEFINE_DRTAG(DrTag_AppendJobDataRequest,1120, "AppendJobDataRequest") +DEFINE_DRTAG(DrTag_AppendJobDataResponse,1121, "AppendJobDataResponse") + +// WARNING: DO NOT INSERT REDDOG TAGS HERE! + +// jforbes: Tags between 1122-1139 reserved for task scheduler + +// Dryad Task scheduler log +DEFINE_DRTAG(DrTag_TaskLogTaskSubmitted, 1140, "TaskSubmitted") +DEFINE_DRTAG(DrTag_TaskLogTaskDeleted, 1141, "TaskDeleted") +DEFINE_DRTAG(DrTag_TaskLogTaskCompleted, 1142, "TaskCompleted") +DEFINE_DRTAG(DrTag_TaskLogJobStarted, 1143, "JobStarted") +DEFINE_DRTAG(DrTag_TaskLogJobEnded, 1144, "JobEnded") +DEFINE_DRTAG(DrTag_TaskLogGainedMasterStatus, 1145, "GainedMasterStatus") +DEFINE_DRTAG(DrTag_TaskLogSetSchedulerProperties, 1146, "SetSchedulerProperties") +DEFINE_DRTAG(DrTag_TaskLogSetTaskProperties, 1147, "SetTaskProperties") +DEFINE_DRTAG(DrTag_TaskLogJobData, 1148, "JobData") + +// This is not actually a log entry, but the custom property bag part of a TaskLogJobData entry +DEFINE_DRTAG(DrTag_TaskLogJobDataContents, 1149, "JobDataContents") + +// This is a tag that will be used within a DrTag_TaskLogJobDataContents +// It describes data that has been uploaded Dryad URL, typically by Dryad, that should be +// preserved. Sub-properties will include a Dryad URI and a flag indicating whether it is +// temporary output (e.g. stdout.txt for the job) that should be cleaned up by the scheduler +// at some later point. +DEFINE_DRTAG(DrTag_JobOutput, 1150, "JobOutput") + +// This entry indicates the task scheduler appended its state to the permanent storage stream +// at a particular sequence number. 
+DEFINE_DRTAG(DrTag_TaskLogArchive, 1151, "Archive") + +// This entry indicates an error or warning related to the scheduler +// This may include being unable to start a job +DEFINE_DRTAG(DrTag_TaskLogError, 1152, "SchedulerError") + +// WARNING: DO NOT INSERT REDDOG TAGS HERE! + +// jforbes: Tags between 1153-1199 reserved for task scheduler + +// Dryad Async RPC Server Protocol tags +DEFINE_DRTAG(DrTag_RpcTunnelTransportRequest, 1200, "RpcTunnelTransportRequest") +DEFINE_DRTAG(DrTag_RpcTunnelTransportRequestHeader, 1201, "RpcTunnelTransportRequestHeader") +DEFINE_DRTAG(DrTag_RpcTunnelTransportRequestBody, 1202, "RpcTunnelTransportRequestBody") +DEFINE_DRTAG(DrTag_RpcTunnelTransportResponse, 1203, "RpcTunnelTransportResponse") +DEFINE_DRTAG(DrTag_RpcTunnelTransportResponseHeader, 1204, "RpcTunnelTransportResponseHeader") +DEFINE_DRTAG(DrTag_RpcTunnelTransportResponseBody, 1205, "RpcTunnelTransportResponseBody") +DEFINE_DRTAG(DrTag_RpcTunnelTransportSessionOpenRequest, 1206, "RpcTunnelTransportSessionOpenRequest") +DEFINE_DRTAG(DrTag_RpcTunnelTransportSessionOpenResponse, 1207, "RpcTunnelTransportSessionOpenResponse") +DEFINE_DRTAG(DrTag_RpcTunnelTransportSessionCloseRequest, 1208, "RpcTunnelTransportSessionCloseRequest") +DEFINE_DRTAG(DrTag_RpcTunnelTransportSessionCloseResponse, 1209, "RpcTunnelTransportSessionCloseResponse") +DEFINE_DRTAG(DrTag_RpcTunnelTransportSessionEnqueueRequest, 1210, "RpcTunnelTransportSessionEnqueueRequest") +DEFINE_DRTAG(DrTag_RpcRequest, 1211, "RpcRequest") +DEFINE_DRTAG(DrTag_RpcTunnelTransportSessionEnqueueResponse, 1212, "RpcTunnelTransportSessionEnqueueResponse") +DEFINE_DRTAG(DrTag_RpcTunnelTransportSessionPollRequest, 1213, "RpcTunnelTransportSessionPollRequest") +DEFINE_DRTAG(DrTag_RpcTunnelTransportSessionPollResponse, 1214, "RpcTunnelTransportSessionPollResponse") +DEFINE_DRTAG(DrTag_RpcResponse, 1215, "RpcResponse") +DEFINE_DRTAG(DrTag_RpcRequestHeader, 1216, "RpcRequestHeader") 
+DEFINE_DRTAG(DrTag_RpcGenerateEmptyResultRequest, 1217, "RpcGenerateEmptyResultRequest") +DEFINE_DRTAG(DrTag_RpcResponseHeader, 1218, "RpcResponseHeader") +DEFINE_DRTAG(DrTag_RpcSessionKey, 1219, "RpcSessionKey") +DEFINE_DRTAG(DrTag_RpcProxySendMessageRequest, 1220, "RpcProxySendMessageRequest") +DEFINE_DRTAG(DrTag_RpcProxySendMessageResponse, 1221, "RpcProxySendMessageResponse") +DEFINE_DRTAG(DrTag_DrWrappedProtocolMessage, 1222, "DrWrappedProtocolMessage") +DEFINE_DRTAG(DrTag_RpcRequestPacketHeader, 1223, "RpcRequestPacketHeader") +DEFINE_DRTAG(DrTag_RpcResponsePacketHeader, 1224, "RpcResponsePacketHeader") +DEFINE_DRTAG(DrTag_CompoundOpaqueKey, 1225, "CompundOpaqueKey") +DEFINE_DRTAG(DrTag_RpcRequestPacket, 1226, "RpcRequestPacket") +DEFINE_DRTAG(DrTag_RpcResponsePacket, 1227, "RpcResponsePacket") + +// WARNING: DO NOT INSERT REDDOG TAGS HERE! + +// +// cosmos resource selector Tags +// +DEFINE_DRTAG(DrTag_GetResourceListRequest, 1300, "DrTag_GetResourceListRequest") +DEFINE_DRTAG(DrTag_GetResourceListResponse, 1301, "DrTag_GetResourceListResponse") + +// WARNING: DO NOT INSERT REDDOG TAGS HERE! 
+ +// base protocol tags +DEFINE_DRTAG(DrTag_DrMessage, 9000, "DrMessage") +DEFINE_DRTAG(DrTag_DrClientMessageHeader, 9001, "DrClientMessageHeader") +DEFINE_DRTAG(DrTag_DrServerMessageHeader, 9002, "DrServerMessageHeader") +DEFINE_DRTAG(DrTag_DrServiceDescriptor, 9003, "DrServiceDescriptor") +DEFINE_DRTAG(DrTag_DrExtentNodeInfo, 9004, "DrExtentNodeInfo") +DEFINE_DRTAG(DrTag_DrStreamInfo, 9005, "DrStreamInfo") +DEFINE_DRTAG(DrTag_DrExtentInfo, 9006, "DrExtentInfo") +DEFINE_DRTAG(DrTag_DrExtentInstanceInfo, 9007, "DrExtentInstanceInfo") +DEFINE_DRTAG(DrTag_DrExtentInstanceMetadata, 9008, "DrExtentInstanceMetadata") +DEFINE_DRTAG(DrTag_DrNameResolutionMapEntry, 9009, "DrNameResolutionMapEntry") +DEFINE_DRTAG(DrTag_DrEnSyncDirective, 9010, "DrEnSyncDirective") +DEFINE_DRTAG(DrTag_DrHostAndPort, 9011, "DrHostAndPort") +DEFINE_DRTAG(DrTag_DrSimpleProcessFile, 9012, "DrSimpleProcessFile") +DEFINE_DRTAG(DrTag_DrPingRequest, 9013, "DrPingRequest") +DEFINE_DRTAG(DrTag_DrPingResponse, 9014, "DrPingResponse") +DEFINE_DRTAG(DrTag_DrProcessConstraints, 9015, "DrProcessConstraints") +DEFINE_DRTAG(DrTag_DrNodeErrorInfo, 9016, "DrNodeErrorInfo") +DEFINE_DRTAG(DrTag_DrProcessInfo, 9017, "DrProcessInfo") +DEFINE_DRTAG(DrTag_DrProcessPropertyInfo, 9018, "DrProcessPropertyInfo") +DEFINE_DRTAG(DrTag_DrUserDescriptor, 9019, "DrUserDescriptor") +DEFINE_DRTAG(DrTag_DrUserTicket, 9020, "DrUserTicket") +DEFINE_DRTAG(DrTag_DrJobDescriptor, 9021, "DrJobDescriptor") +DEFINE_DRTAG(DrTag_DrJobTicket, 9022, "DrJobTicket") +DEFINE_DRTAG(DrTag_DrProcessDescriptor, 9023, "DrProcessDescriptor") +DEFINE_DRTAG(DrTag_DrProcessTicket, 9024, "DrProcessTicket") +DEFINE_DRTAG(DrTag_DrParentProcess, 9025, "DrParentProcess") +DEFINE_DRTAG(DrTag_DrRootProcess, 9026, "DrRootProcess") +DEFINE_DRTAG(DrTag_DrExtentInstanceMetadataSet, 9027, "DrExtentInstanceMetadataSet") +DEFINE_DRTAG(DrTag_DrDetailedError, 9028, "DrDetailedError") + +// WARNING: DO NOT INSERT REDDOG TAGS HERE! 
+ +// BUGBUG: These were added by XStore and should not be here, they will collide with Dryad integrations. +// TODO: consider breaking XStream interop to move them to the right place! + +DEFINE_DRTAG(DrTag_ParamList, 9029, "ParamList" ) +DEFINE_DRTAG(DrTag_ValueList, 9030, "ValueList" ) +DEFINE_DRTAG(DrTag_ParamValueList, 9031, "ParamValueList") +DEFINE_DRTAG(DrTag_DrStreamPolicyInfo, 9032, "StreamPolicyInfo") + +// WARNING: DO NOT INSERT REDDOG TAGS HERE! + +// NOTE NOTE NOTE +// +// 10000-10999 are reserved for Dryad + +// WARNING: DO NOT INSERT REDDOG TAGS HERE! + + +// NOTE NOTE NOTE +// +// 11000-11999 are reserved for xstore +// +// +// XStream Tags +// +// WARNING: DO NOT INSERT REDDOG TAGS BEFORE THIS POINT +// DO NOT REORDER, DELETE, OR CHANGE THE VALUE OF ANY ENTRY IN THIS FILE +DEFINE_DRTAG(XsTag_ClientMessageHeader, 11000, "ClientMessageHeader") +DEFINE_DRTAG(XsTag_ServerMessageHeader, 11001, "ServerMessageHeader") +DEFINE_DRTAG(XsTag_XStoreMessage, 11002, "XStoreMessage") +DEFINE_DRTAG(XsTag_XStoreMessageBody, 11003, "XStoreMessageBody") +DEFINE_DRTAG(DrTag_DrFailureInjectionRequest, 11004, "DrFailureInjectionRequest") +DEFINE_DRTAG(DrTag_DrFailureInjectionResponse, 11005, "DrFailureInjectionResponse") +DEFINE_DRTAG(DrTag_DrEnumFileDirRequest, 11006, "DrEnumFileDirRequest") +DEFINE_DRTAG(DrTag_DrEnumFileDirResponse, 11007, "DrEnumFileDirResponse") +DEFINE_DRTAG(DrTag_DrReadFileRequest, 11008, "DrReadFileRequest") +DEFINE_DRTAG(DrTag_DrReadFileResponse, 11009, "DrReadFileResponse") +DEFINE_DRTAG(DrTag_DrWriteFileRequest, 11010, "DrWriteFileRequest") +DEFINE_DRTAG(DrTag_DrWriteFileResponse, 11011, "DrWriteFileResponse") +DEFINE_DRTAG(DrTag_DrDirectoryEntry, 11012, "DrDirectoryEntry") +DEFINE_DRTAG(DrTag_EnAppendBlockInfo, 11013, "EnAppendBlockInfo" ) +DEFINE_DRTAG(Drm_SetStreamPolicyRequest, 11014, "SetStreamPolicyRequest") +DEFINE_DRTAG(Drm_SetStreamPolicyResponse, 11015, "SetStreamPolicyResponse") + +// SAMMCK: renamed from QueryStreamPolicyRequest to 
match the API +DEFINE_DRTAG(Drm_GetStreamPoliciesRequest, 11016, "GetStreamPolociesRequest") + +// SAMMCK: renamed from QueryStreamPolicyResponse to match the API +DEFINE_DRTAG(Drm_GetStreamPoliciesResponse, 11017, "QueryStreamPolicyResponse") +DEFINE_DRTAG(Drm_PolicyList, 11018, "PolicyList") +DEFINE_DRTAG(Drm_ENConfigChangeEvent, 11019, "ENConfigChangeEvent") +DEFINE_DRTAG(DrTag_DrCommandLineParams, 11020, "DrCommandLineParams" ) +DEFINE_DRTAG(DrTag_DrCommandLineResults, 11021, "DrCommandLineResults" ) +DEFINE_DRTAG(DrTag_XcPsScheduleProcessRequest, 11022, "XcPsScheduleProcessRequest") +DEFINE_DRTAG(DrTag_XcPsScheduleProcessResponse, 11023, "XcPsScheduleProcessResponse") +DEFINE_DRTAG(DrTag_DrExecuteCommandLineRequest, 11024, "DrExecuteCommandLineRequest" ) +DEFINE_DRTAG(DrTag_DrExecuteCommandLineResponse, 11025, "DrExecuteCommandLineResponse" ) +DEFINE_DRTAG(DrTag_DrReferencingStreamInfo, 11033, "ReferencingStreamInfo") +DEFINE_DRTAG(Drm_GetExtentInfoRequest, 11034, "GetExtentInfoRequest") +DEFINE_DRTAG(Drm_GetExtentInfoResponse, 11035, "GetExtentInfoResponse") +DEFINE_DRTAG(DrTag_EmbeddedDrmCommand, 11036, "EmbeddedDrmCommand") +DEFINE_DRTAG(DrTag_EmbeddedRslRequest, 11037, "EmbeddedRslRequest") +DEFINE_DRTAG(Drm_GetLazyReplicationInfo, 11038, "GetLazyReplicationInfo") +DEFINE_DRTAG(Drm_GetLazyReplicationInfoResponse, 11039, "GetLazyReplicationInfoResponse") +DEFINE_DRTAG(DrEnTag_RequestLatenciesDeprecated, 11040, "RequestLatenciesDeprecated") + +// The following two tags are used for extent +// recovery after DRM emergency rollback. +DEFINE_DRTAG(DrTag_RecoverAndAppendExtentsRequest, 11041, "RecoverAndAppendExtentsRequest") +DEFINE_DRTAG(DrTag_RecoverAndAppendExtentsResponse, 11042, "RecoverAndAppendExtentsResponse") + +// The following two tags are used to enumerate unsealable +// extents following a rollback/recovery operation. 
+DEFINE_DRTAG(DrTag_EnumerateUnsealableExtentsRequest, 11043, "EnumerateUnsealableExtentsRequest") +DEFINE_DRTAG(DrTag_EnumerateUnsealableExtentsResponse, 11044, "EnumerateUnsealableExtentsResponse") + +// The following two tags are deprecated and may be reused +// once all known nodes have been upgraded. +DEFINE_DRTAG(DrTag_ProtocolMessageTimestampsDeprecated, 11045, "ProtocolMessageTimestampsDeprecated") +DEFINE_DRTAG(DrEnTag_RequestLatencies, 11046, "RequestLatencies") + +// ADD ALL XSTREAM/XSTORE/REDDOG TAGS ABOVE THIS POINT! + diff --git a/DryadVertex/VertexHost/system/classlib/include/DrTagsDef.h b/DryadVertex/VertexHost/system/classlib/include/DrTagsDef.h new file mode 100644 index 0000000..cd8fa70 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrTagsDef.h @@ -0,0 +1,33 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#ifdef DEFINE_DRTAG +#undef DEFINE_DRTAG +#endif + +#define DEFINE_DRTAG(var, value, tagName) \ + static const UInt16 var = value; + +#include "DrTags.h" + +#undef DEFINE_DRTAG + diff --git a/DryadVertex/VertexHost/system/classlib/include/DrThread.h b/DryadVertex/VertexHost/system/classlib/include/DrThread.h new file mode 100644 index 0000000..6e18b09 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrThread.h @@ -0,0 +1,1036 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#ifndef __DRYADTHREAD_H__ +#define __DRYADTHREAD_H__ + +#pragma once + +//JC#include "XsAutoActivity.h" + +//Handy macro for mapping from memember of struct/class to its container object +//TODO: Isn't there already a standard MACRO defined for this? Replace if so. +//TODO: This should probably be in a common types/macros header rather than this one +#ifndef DR_GET_CONTAINER +#define DR_GET_CONTAINER(type, address, field) ((type *)( \ + (PCHAR)(address) - \ + (UINT_PTR)(&((type *)0)->field))) +#endif // if not defined DR_GET_CONTAINER + + +// Abstraction of a Thread-local storage pointer. These should be created sparingly, because they are a scarce resource. 
+class DrTlsVoidPointer +{ +public: + DrTlsVoidPointer() + { + m_tlsSlot = TlsAlloc(); + LogAssert(m_tlsSlot != TLS_OUT_OF_INDEXES); // running out of TLS slots is treated as fatal + } + + ~DrTlsVoidPointer() + { + TlsFree(m_tlsSlot); + } + + const void *GetConstVoidPtr() const + { + const void *p = TlsGetValue(m_tlsSlot); + LogAssert(p != NULL || GetLastError() == ERROR_SUCCESS); // a NULL slot value is only accepted when TlsGetValue reported no error + return p; + } + + void *GetVoidPtr() + { + return (void *)GetConstVoidPtr(); + } + + void SetVoidPtr(const void *p) + { + BOOL fSuccess = TlsSetValue(m_tlsSlot, (void *)p); + LogAssert(fSuccess); + } + + bool IsNull() const + { + return GetConstVoidPtr() == NULL; + } + +private: + DWORD m_tlsSlot; // TLS index obtained from TlsAlloc in the constructor and released in the destructor +}; + +template<class t> class DrTlsPtr : public DrTlsVoidPointer +{ +public: + DrTlsPtr() + { + } + + operator t *() + { + return (t *)GetVoidPtr(); + } + + operator const t *() const + { + return (const t *)GetConstVoidPtr(); + } + + DrTlsPtr& operator =(const t *p) + { + SetVoidPtr(p); + return *this; + } + +}; + +class DrThreadPool; + +extern DrThreadPool *g_pDrClientThreads; + +class DrThread; +class DrJob; + +// we keep a refcounted pointer to our own thread in TLS + +extern DrTlsPtr<DrThread> *t_ppThread; + +__inline DrTlsPtr<DrThread>& DrGetThreadTlsPtr() +{ + if (t_ppThread == NULL) { + t_ppThread = new DrTlsPtr<DrThread>(); // created lazily on first use; NOTE(review): no synchronization is visible here - confirm the first call cannot race + LogAssert(t_ppThread != NULL); + } + + return *t_ppThread; +} + +void StartDumpCountersThread(const char *componentName); + +#define t_pThread (DrGetThreadTlsPtr()) + +// A DrJobHash is a wrapper for an optional 32-bit hashcode that can be used to serialize +// jobs. If a hashcode is associated with a job, the job will be serialized with other jobs that +// have the same hashcode. 
+class DrJobHash +{ +public: + DrJobHash() + { + m_hashCode = 0; + m_fHasHash = false; + } + + DrJobHash(const DrJobHash& other) + { + m_hashCode = other.m_hashCode; + m_fHasHash = other.m_fHasHash; + } + + DrJobHash(UInt32 hash) + { + m_hashCode = hash; + m_fHasHash = true; + } + + DrJobHash& operator=(const DrJobHash& other) + { + m_hashCode = other.m_hashCode; + m_fHasHash = other.m_fHasHash; + return *this; + } + + DrJobHash& operator=(UInt32 hash) + { + m_hashCode = hash; + m_fHasHash = true; + return *this; + } + + operator UInt32() const + { + return m_hashCode; + } + + bool HasHash() const + { + return m_fHasHash; + } + + UInt32 GetHashcode() const + { + return m_hashCode; + } + + void SetHasHash(bool fHasHash = true) + { + m_fHasHash = fHasHash; + if (!fHasHash) { + m_hashCode = 0; + } + } + + void SetNoHash() + { + m_hashCode = 0; + m_fHasHash = false; + } + + void SetHandleHash(HANDLE h) + { + m_fHasHash = true; + m_hashCode = (UInt32)((ULONG_PTR)h >> 3); + } + + void SetSocketHash(SOCKET h) + { + m_fHasHash = true; + m_hashCode = (UInt32)((ULONG_PTR)h >> 3); + } + + void SetHeapItemHash(const void *p) + { + m_fHasHash = true; + m_hashCode = (UInt32)((ULONG_PTR)p >> 4); + } + + void SetUlongPtrHash(ULONG_PTR p) + { + SetHeapItemHash((const void *)p); + } + + static UInt32 GetSequentialHashcode() + { + LONG val = InterlockedIncrement(&s_nextHash); + return (UInt32)val; + } + + void SetSequentialHash() + { + m_fHasHash = true; + m_hashCode = GetSequentialHashcode(); + } + + // returns -1 if there is no hash code + int GetBucket(UInt32 numBuckets) const + { + if (!m_fHasHash) { + return -1; + } + LogAssert(numBuckets != 0); + return (int)(m_hashCode % numBuckets); + } + +private: + UInt32 m_hashCode; + bool m_fHasHash; + static volatile LONG s_nextHash; +}; + +extern DrJobHash DrNoJobHash; + +class DrThread : public DrRefCounter, public DrLockable +{ + friend class DrThreadPool; + +public: + DrThread(const char *pszThreadClass, const char *pszShortClass, 
DrThreadPool *pPool = NULL, int iBucket = -1); + virtual ~DrThread(); + + HANDLE GetThreadIocp() + { + return m_hIocp; + } + + DrThreadPool *GetThreadPool() + { + return m_pPool; + } + + int GetJobHashBucket() + { + return m_iBucket; + } + + // may be corrupted if called from outside this thread + // This version creates a copy for you + DrJobHash GetCurrentJobHash() + { + return m_currentJobHash; + } + + // may become invalid if called from outside this thread, or if + // the job deletes itself. + DrJob *GetCurrentJob() + { + return m_pCurrentJob; + } + + HANDLE GetThreadHandle() + { + return m_hThread; + } + + DWORD GetThreadId() + { + return m_dwThreadId; + } + + const char *GetDescription() + { + return m_strDescription; + } + + const char *GetShortTag() + { + return m_strTag; + } + + const char *GetShortThreadClass() + { + return m_strShortClass; + } + + const char *GetThreadClass() + { + return m_strClass; + } + + DrError Start( + LPSECURITY_ATTRIBUTES lpThreadAttributes = NULL, + SIZE_T dwStackSize = 0, + DWORD dwCreationFlags = 0 + ); + + bool WaitForTermination(DWORD dwMilliseconds = INFINITE) + { + DWORD ret = WaitForSingleObject(m_hThread, dwMilliseconds); + //m_e2eTransfer.ProcessReceive(); + return (ret == WAIT_OBJECT_0); + } + + // param points to the DrThread. + static DWORD WINAPI ThreadEntryStatic(void * param); + + void AttachToCurrentThread(); + + // This thread object may be deleted if the current thread is the only reference + void DetachFromCurrentThread(); + + // called whenever thread metadata is updated that might affect the description (e.g., threadid) + // by default, builds a simple description and tag from from the long and short class name and the thread id. 
+ virtual void UpdateTagAndDescription(); + +protected: + // subclass gets called here when thread starts + virtual DWORD ThreadEntry() = 0; + +protected: + DrThreadPool *m_pPool; // NULL if not owned by a thread pool + int m_iBucket; + DrJobHash m_currentJobHash; + DrJob *m_pCurrentJob; // May become invalid as the job runs if it deletes itself + HANDLE m_hThread; + DWORD m_dwThreadId; + DrStr64 m_strDescription; + DrStr32 m_strTag; + DrStr16 m_strShortClass; + DrStr32 m_strClass; + HANDLE m_hIocp; // I/O completion port used by this thread. The thread doesn't automatically close this. +}; + +// An unmanaged thread class that can be attached to a running win32 thread +// The ThreadEntry is never called. +class DrGenericThread : public DrThread +{ +public: + DrGenericThread(const char *pszThreadClass = "DrGenericThread", const char *pszShortClass = "????") : DrThread(pszThreadClass, pszShortClass) + { + } + + virtual DWORD ThreadEntry() + { + LogAssert(false); + return (DWORD)-1; + } +}; + +DrThread *DrGenerateCurrentThread(); + +// creates a DrGenericThread object if this thread is not yet managed by cosmos libraries +inline DrThread *DrGetCurrentThread() +{ + DrThread *pThread = t_pThread; + if (pThread == NULL) { + pThread = DrGenerateCurrentThread(); + } + return pThread; +} + +/* + * DrJob + * + * This is an abstract base class for a Dryad job. + * User must implement the JobReady and JobFailed methods. 
+ * See DrEmbeddedJob below for an alternative usage pattern + */ + +class DrJob : protected DryadHeapItem +{ + friend class DrThreadPool; + friend class DrTimerThread; + friend class DrPoolWorkerThread; + +public: + DrJob() + { + ZeroMemory(&m_overlapped, sizeof(m_overlapped)); + m_pDefaultThreadPool = NULL; + m_postedError = DrError_OK; + m_isActiveTimer=false; + m_expiryTime = 0; + } + + virtual ~DrJob() + { + } + + void SetJobHash(const DrJobHash& jobHash) + { + m_jobHash = jobHash; + } + + void SetJobHash(UInt32 hash) + { + m_jobHash = hash; + } + + const DrJobHash& GetJobHash() + { + return m_jobHash; + } + + //Return the overlapped structure associated with this job + //Use this when submitting the job as an IO operation + OVERLAPPED * GetOverlapped() + { + return &m_overlapped; + } + + HANDLE GetOverlappedEvent() + { + return m_overlapped.hEvent; + } + + void SetOverlappedEvent(HANDLE hEvent) + { + m_overlapped.hEvent = hEvent; + } + + UInt64 GetOverlappedOffset() + { + // Assumes little endian + return *(UInt64 *)&(m_overlapped.Offset); + } + + void SetOverlappedOffset(UInt64 offset) + { + // Assumes little endian + *(UInt64 *)&(m_overlapped.Offset) = offset; + } + + //Return the job object associated with a specific overlapped structure + static DrJob * MapOverlappedToJob(OVERLAPPED * pOverlapped) + { + return DR_GET_CONTAINER(DrJob, pOverlapped, m_overlapped); + } + + DrThreadPool *GetDefaultThreadPool() + { + return (m_pDefaultThreadPool != NULL) ? m_pDefaultThreadPool : g_pDrClientThreads; + } + + void SetDefaultThreadPool(DrThreadPool *pThreadPool) + { + m_pDefaultThreadPool = pThreadPool; + } + + //Called when a job is completed OK. + virtual void JobReady(DWORD numBytes, ULONG_PTR key)=0; + + //Called when a job fails. 
errorCode gives the reason + virtual void JobFailed(DWORD numBytes, ULONG_PTR key, DrError errorCode)=0; + +private: + + bool IsHigherPriorityThan(DryadHeapItem *other) + { + DrJob * otherTimer = (DrJob*) other; + + return (((int ) (otherTimer->m_expiryTime - m_expiryTime)) > 0); + } + +protected: + OVERLAPPED m_overlapped; + DrError m_postedError; + DrJobHash m_jobHash; + bool m_isActiveTimer; + DWORD m_expiryTime; + DrThreadPool *m_pDefaultThreadPool; +}; + + +/* + * DrTimer + * + * This is an abstract base class for a Dryad timer. + * User must implement the TimerFired method. + * See DrEmbeddedTimer below for an alternative usage pattern + */ + + +class DrTimer : + public DrJob +{ + friend class DrThreadPool; + +public: + + + DrTimer() + { + } + + virtual ~DrTimer() + { + } + + DWORD GetExpiryTime() + { + return m_expiryTime; + } + + //Implement this to do work when the timer goes off + virtual void TimerFired(DWORD firedTime)=0; + + //Called when a job is completed OK. + virtual void JobReady(DWORD numBytes, ULONG_PTR key) + { + TimerFired(numBytes); + } + + //Called when a job fails. errorCode gives the reason + virtual void JobFailed(DWORD numBytes, ULONG_PTR key, DrError errorCode) + { + LogAssert(false); + } + +}; + + +/* + * DrThreadPoolUser, DrEmbeddedJob and DrEmbeddedTimer + * + * These classes provide an alternative approach to using the timer and job classes. + * Rather than having to implement a new class for each timer and/or job, simply + * make the embedded job/timer a member of your class and inherit from the ThreadPoolUser + * interface class. This is useful when you've got a single object that wants to handle + * multiple timers and jobs running at once. 
+ */ + +class DrThreadPoolUser +{ +public: + + virtual void JobReady(DrJob * job, DWORD numBytes, ULONG_PTR key)=0; + + virtual void JobFailed(DrJob * job, DWORD numBytes, ULONG_PTR key, DrError errorCode)=0; + + virtual void TimerFired(DrTimer * timer, DWORD firedTime)=0; + +}; + +class DrEmbeddedJob : public DrJob +{ +public: + + DrEmbeddedJob() : DrJob() + { m_user=NULL; }; + + //Call this before using it to assign the user object + void Initialize(DrThreadPoolUser * user) + { m_user=user; }; + + void JobReady(DWORD numBytes, ULONG_PTR key) + { m_user->JobReady(this, numBytes, key); }; + + virtual void JobFailed(DWORD numBytes, ULONG_PTR key, DrError errorCode) + { m_user->JobFailed(this, numBytes, key, errorCode); }; + +private: + + DrThreadPoolUser * m_user; +}; + + +class DrEmbeddedTimer : public DrTimer +{ +public: + + DrEmbeddedTimer() : DrTimer() + { m_user=NULL; }; + + //Call this before using it to assign the user object + void Initialize(DrThreadPoolUser * user) + { m_user=user; }; + + void TimerFired(DWORD firedTime) + { m_user->TimerFired(this, firedTime); }; + +private: + + DrThreadPoolUser * m_user; +}; + + +class DrTimerThread : public DrThread +{ +public: + DrTimerThread(DrThreadPool *pPool); + virtual ~DrTimerThread(); + + virtual DWORD ThreadEntry(); + + void ScheduleTimerMs(DrJob * timer, DWORD delay); + bool CancelTimer(DrJob * timer); + + void Signal() + { + BOOL fSuccess = SetEvent(m_timerEventHandle); + LogAssert(fSuccess); + } + +private: + //Heap of timers + DryadHeap m_timerHeap; + + //Event used to signal to timer thread that new timers have been queued + HANDLE m_timerEventHandle; + + //Time the timer thread is due to wake at + DWORD m_timerThreadWakesAt; + + //Set to the period the timer thread is sleeping for + //INFINITY if its never planning to wake + DWORD m_dwTimerThreadSleepPeriod; +}; + +class DrPoolWorkerThread : public DrThread +{ +public: + DrPoolWorkerThread(const char *pszThreadClass, const char *pszShortClass, 
DrThreadPool *pPool = NULL, int iBucket = -1) + : DrThread(pszThreadClass, pszShortClass, pPool, iBucket) + { + } + + DWORD PoolWorkerEntry(); + + virtual DWORD ThreadEntry() = 0; + +private: +}; + + +class DrPoolThread : public DrPoolWorkerThread +{ +public: + DrPoolThread(DrThreadPool *pPool) : DrPoolWorkerThread("DrPoolThread", "POOL", pPool) + { + } + + virtual DWORD ThreadEntry(); + +private: +}; + +class DrHashThread : public DrPoolWorkerThread +{ +public: + DrHashThread(DrThreadPool *pPool, int iBucket) : DrPoolWorkerThread("DrHashThread", "HASH", pPool, iBucket) + { + } + + virtual DWORD ThreadEntry(); + virtual void UpdateTagAndDescription(); + +private: +}; + + +/* + * DrThreadPool + * + * Encapsulates a pool of running Dryad threads + * Allows IO jobs and timers to be queued to the pool + */ + +class DrThreadPool +{ +public: + DrThreadPool(); + ~DrThreadPool(); + + // Create the thread pool + // The initialThreadCount defines the number of threads used for running non-hashed jobs. It can be 0. If it is -1, the number of processors + // on this machine is used. If it is 0, the hashed threads will be used to run non-hashed jobs (a random hash will be assigned to each job). + // numHashedThreads defines the number of threads used for running hashed jobs. It can be 0. If it is -1, the number of processors + // on this machine is used. If it is 0, attempts to schedule hashed jobs on this pool will fail. + DrError Initialize(int initialThreadCount = -1, int numHashedThreads = -1); + + // Tells all threads to quit and returns when they have. + // Returns DrError_Fail if any threads appeared deadlocked and unresponsive + DrError Deinitialize(); + + // Callable on any thread. Enqueue the given job on some thread in this thread pool. + // If err is not DrError_OK, the Job will be called back through JobFailed; otherwise, it will be called through JobReady. 
+ bool EnqueueJobWithStatus(DrJob *job, DWORD numBytes, ULONG_PTR key, DrError err); + + // Callable on any thread. Enqueue the given job on some thread in this thread pool + inline bool EnqueueJob(DrJob *job, DWORD numBytes, ULONG_PTR key) + { + return EnqueueJobWithStatus(job, numBytes, key, DrError_OK); + } + + // Associate a specified handle with the thread pool. This will cause threads + // from the pool to pick up the completions of overlapped IO submitted + // on the specified handle. 'key' will be passed to each job that does overlapped + // IO on the file handle via the JobReady method + bool AssociateHandleWithPool(HANDLE fileHandle, ULONG_PTR key); + + // Associate a specified handle with the thread pool, with a specific hash affinity. This will cause the specified hash thread + // from the pool to pick up the completions of overlapped IO submitted + // on the specified handle. 'key' will be passed to each job that does overlapped + // IO on the file handle via the JobReady method + bool AssociateHandleWithPoolAndHash(HANDLE fileHandle, ULONG_PTR key, const DrJobHash& jobHash); + + //Schedule a job to run a specific number of msec from now + void ScheduleTimerMs(DrJob * timer, DWORD delay); + + //Schedule a job to run a specific time interval from now + void ScheduleTimerInterval(DrJob * timer, DrTimeInterval timeInterval) + { + ScheduleTimerMs(timer, DrGetTimerMsFromInterval(timeInterval)); + } + + //Cancel a job that was scheduled with ScheduleTimerMs or ScheduleTimerInterval. Note this can fail and caller MUST handle that case. + //Failure means the timer has already fired (even if caller hasn't got + //the callback yet), and owner of the timer must stick around until + //the callback is processed + bool CancelTimer(DrJob * timer); + + // Returns the completion port handle associated with the thread pool. This can be used to schedule + // completions of asynchronous operations. 
+ HANDLE GetCompletionPortHandle() + { + return m_completionPortHandle; + } + + UInt32 GenerateRandomHash() + { + LONG n = InterlockedIncrement(&m_nextRandomHash); + return (UInt32)n; + } + + // It is a fatal error to call this with a specified hash when m_numHashBuckets is 0. + // returns -1 if there is no hash. + int GetBucketOfHash(const DrJobHash& jobHash) + { + return jobHash.GetBucket((UInt32)m_numHashBuckets); + } + + // Returns -1 if the job is not hashed + int GetBucketOfJob(DrJob *pJob) + { + int i = GetBucketOfHash(pJob->GetJobHash()); + return i; + } + + // returns NULL if the specified bucket is not supported. + // If iBucket is -1, uses the general completion pool handle. If there is no general completion pool, picks a random hashed thread. + HANDLE GetCompletionHandleForBucket(int iBucket); + + // Returns true if the currently executing thread is an appropriate thread + // for the specified bucket. + bool CurrentThreadIsBucketThread(int iBucket) + { + bool f = true; + if (iBucket < 0) { + // no specific bucket, so we can be in any thread + } else if (iBucket < m_numHashBuckets) { + f = ((DrThread *)m_rgHashedThreads[iBucket] == DrGetCurrentThread()); + } else { + // bucket out of range, so can't be in the bucket + f = false; + } + return f; + } + + // Returns true if the currently executing thread is an appropriate thread + // for the specified job hash. + bool CurrentThreadIsHashThread(const DrJobHash& jobHash) + { + return CurrentThreadIsBucketThread(GetBucketOfHash(jobHash)); + } + + // Returns true if the currently executing thread is an appropriate thread + // for the specified job. 
+ bool CurrentThreadIsJobThread(DrJob *pJob) + { + return CurrentThreadIsBucketThread(GetBucketOfJob(pJob)); // test the job's bucket directly: handing the int bucket to CurrentThreadIsHashThread implicitly built a DrJobHash from it (DrJobHash(UInt32)), so an unhashed job (bucket -1) was treated as a hashed one + } + + bool ShouldQuit() + { + return m_threadsShouldQuit; + } + + OVERLAPPED *GetCommonOv() + { + return &m_ov; + } + +private: + OVERLAPPED m_ov; // the OVERLAPPED returned by GetCommonOv() + + // This is the io completion port handle tied to this thread pool for non-hashed jobs + // Currently this value is given to all cosmos threads in the thread pool + HANDLE m_completionPortHandle; + + // Number of worker threads we've spun up for non-hashed jobs + int m_createdThreadCount; + + // Array of running thread handles for non-hashed jobs. Length is m_createdThreadCount. + DrPoolThread **m_rgWorkerThreads; // refcounted pointers + + int m_numHashBuckets; // modulus GetBucketOfHash uses to map a job hash to a bucket + + // Number of worker threads we've spun up for hashed jobs + int m_hashedThreadCount; + + // Array of thread info for hashed threads. Length is m_hashedThreadCount. + DrHashThread **m_rgHashedThreads; + + //Set to true when all threads should exit + volatile bool m_threadsShouldQuit; + + //Thread that watches the timer heap and farms expired timers off to worker threads + DrRef m_pTimerThread; + + volatile LONG m_nextRandomHash; // counter incremented by GenerateRandomHash() + + //Tells all worker threads to quit, waits for them to do so and + //closes our handles to them. 
Returns DrError_Fail if it thinks + //threads are deadlocked and not responding + DrError CloseWorkerThreads(); +}; + +// this version does not implement IDrRefCounter; use DrRefCountedJob if you do not need multiple inheritance +class IDrRefCountedJob : public DrJob, public IDrRefCounter +{ +public: + IDrRefCountedJob() + { + m_pThreadPool = NULL; + } + + virtual ~IDrRefCountedJob() + { + } + + void SetThreadPool(DrThreadPool *pThreadPool) + { + m_pThreadPool = pThreadPool; + } + + void ScheduleJob(DrThreadPool *pThreadPool, DrTimeInterval t) + { + if (pThreadPool != NULL) { + m_pThreadPool = pThreadPool; + } else { + if (m_pThreadPool == NULL) { + m_pThreadPool = g_pDrClientThreads; + } + } + IncRef(); + m_pThreadPool->ScheduleTimerInterval(this, t); + } + + void ScheduleJob(DrTimeInterval t) + { + ScheduleJob(NULL, t); + } + + bool EnqueueJobWithStatus(DrThreadPool *pThreadPool, DWORD numBytes, ULONG_PTR key, DrError err) + { + if (pThreadPool != NULL) { + m_pThreadPool = pThreadPool; + } else { + if (m_pThreadPool == NULL) { + m_pThreadPool = g_pDrClientThreads; + } + } + IncRef(); + bool fok = m_pThreadPool->EnqueueJobWithStatus(this, numBytes, key, err); + if (!fok) { + DecRef(); + } + return fok; + } + + bool EnqueueJobWithStatus(DWORD numBytes, ULONG_PTR key, DrError err) + { + return EnqueueJobWithStatus(NULL, numBytes, key, err); + } + + bool EnqueueJobWithStatus(DrError err) + { + return EnqueueJobWithStatus(NULL, 0, 0, err); + } + + bool EnqueueJobWithStatus(DrThreadPool *pThreadPool, DrError err) + { + return EnqueueJobWithStatus(pThreadPool, 0, 0, err); + } + + bool EnqueueJob(DrThreadPool *pThreadPool, DWORD numBytes, ULONG_PTR key) + { + return EnqueueJobWithStatus(pThreadPool, numBytes, key, DrError_OK); + } + + bool EnqueueJob(DWORD numBytes, ULONG_PTR key) + { + return EnqueueJobWithStatus(NULL, numBytes, key, DrError_OK); + } + + bool EnqueueJob() + { + return EnqueueJob(NULL, 0, 0); + } + + bool EnqueueJob(DrThreadPool *pThreadPool) + { + 
return EnqueueJob(pThreadPool, 0, 0); + } + + // subclasses implement this method to provide job behavior + virtual void ExecuteJob(DrError err, DWORD numBytes, ULONG_PTR key) = 0; + + //Called when a job is completed OK. + virtual void JobReady(DWORD numBytes, ULONG_PTR key) + { + ExecuteJob(DrError_OK, numBytes, key); + DecRef(); + } + + //Called when a job fails. errorCode gives the reason. + virtual void JobFailed(DWORD numBytes, ULONG_PTR key, DrError errorCode) + { + ExecuteJob(errorCode, numBytes, key); + DecRef(); + } + + +private: + DrThreadPool *m_pThreadPool; +}; + +class DrRefCountedJob : public IDrRefCountedJob +{ +public: + DrRefCountedJob() + { + } + + virtual ~DrRefCountedJob() + { + } + +public: + DRREFCOUNTIMPL; +}; + + +// An easy-to-use refcounted timer base class. +class DrRefCountedTimer : public DrTimer, public DrRefCounter +{ +public: + DrRefCountedTimer() + { + m_pThreadPool = NULL; + } + + virtual ~DrRefCountedTimer() + { + } + + void Schedule(DrThreadPool *pThreadPool, DrTimeInterval t) + { + LogAssert(m_pThreadPool == NULL); + LogAssert(pThreadPool != NULL); + IncRef(); + m_pThreadPool = pThreadPool; + pThreadPool->ScheduleTimerInterval(this, t); + } + + // Either OnTimerFired or OnTimerCancelled is guaranteed to be called after + // Schedule. + virtual void OnTimerFired() = 0; + virtual void OnTimerCancelled() + { + } + + // Don't override this method. Use OnTimerFired instead. + virtual void TimerFired(DWORD firedTime) + { + OnTimerFired(); + DecRef(); + } + + // Causes either OnTimerFired or OnTimerCancelled to be called as soon as possible if + // the timer has been scheduled but one has not already been called. 
+ void Cancel() + { + if (m_pThreadPool != NULL && m_pThreadPool->CancelTimer(this)) { + OnTimerCancelled(); + DecRef(); + } + } + + bool HasBeenScheduled() + { + return (m_pThreadPool != NULL); + } + +private: + DrThreadPool *m_pThreadPool; +}; + +#endif //end if not defined __DRYADTHREAD_H__ diff --git a/DryadVertex/VertexHost/system/classlib/include/DrTypes.h b/DryadVertex/VertexHost/system/classlib/include/DrTypes.h new file mode 100644 index 0000000..a59b307 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DrTypes.h @@ -0,0 +1,118 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include "basic_types.h" + +typedef HRESULT DrError; + +typedef size_t Size_t; +static const Size_t Max_Size_t = (Size_t)((Size_t)0 - (Size_t)1); + +//JCnamespace apsdk +//JC{ + +__forceinline DrError DrGetLastError() +{ + return HRESULT_FROM_WIN32( GetLastError() ); +} + +__forceinline DrError DrGetLastErrorEnforceFail() +{ + DrError hr = HRESULT_FROM_WIN32(GetLastError()); + return FAILED(hr) ? 
hr : E_UNEXPECTED; +} + +__forceinline DrError DrErrorFromWin32(DWORD err) +{ + return HRESULT_FROM_WIN32(err); +} + +/////////////////////////////////// +#pragma warning(push) +#pragma warning (disable: 4201) // nonstandard extension used : nameless struct/union + +/* + A cosmos timestamp is defined as the number of 100-nanosecond intervals that have elapsed + since 12:00 A.M. January 1, 1601 (UTC). It is the representation of choice whenever an + absolute date/time must be used. +*/ +typedef unsigned __int64 DrTimeStamp; + +static const DrTimeStamp DrTimeStamp_Never = (DrTimeStamp)_UI64_MAX; +static const DrTimeStamp DrTimeStamp_LongAgo = (DrTimeStamp)0; + +/* + A cosmos elapsed time is defined as the number of 100-nanosecond intervals between two + points in time. It may be negative. It is what you get when you subtract two DrTimestamp values, and is + the representation of choice whenever a time interval needs to be represented persistently + or in a network protocol. +*/ +typedef __int64 DrTimeInterval; + +static const DrTimeInterval DrTimeInterval_Infinite = (DrTimeInterval)_I64_MAX; +static const DrTimeInterval DrTimeInterval_NegativeInfinite = (DrTimeInterval)_I64_MIN; +static const DrTimeInterval DrTimeInterval_Zero = (DrTimeInterval)0; +static const DrTimeInterval DrTimeInterval_Quantum = (DrTimeInterval)1; +static const DrTimeInterval DrTimeInterval_100ns = DrTimeInterval_Quantum; +static const DrTimeInterval DrTimeInterval_Microsecond = DrTimeInterval_100ns * 10; +static const DrTimeInterval DrTimeInterval_Millisecond = DrTimeInterval_Microsecond * 1000; +static const DrTimeInterval DrTimeInterval_Second = DrTimeInterval_Millisecond * 1000; +static const DrTimeInterval DrTimeInterval_Minute = DrTimeInterval_Second * 60; +static const DrTimeInterval DrTimeInterval_Hour = DrTimeInterval_Minute * 60; +static const DrTimeInterval DrTimeInterval_Day = DrTimeInterval_Hour * 24; +static const DrTimeInterval DrTimeInterval_Week = DrTimeInterval_Day * 7; + +// A 
DrTimeInterval_Year is defined as 52 weeks. It is for convenience, not for computing exact years. +static const DrTimeInterval DrTimeInterval_Year = DrTimeInterval_Week * 52; + +/* An IPV4 IP address in host byte order */ +typedef UInt32 DrIpAddress; + +static const DrIpAddress DrAnyIpAddress = 0; +static const DrIpAddress DrInvalidIpAddress = 0; +static const DrIpAddress DrLocalIpAddress = 0x7F000001; // 127.0.0.1 + + +/* An IP port number in host byte order */ +typedef UInt16 DrPortNumber; + +static const DrPortNumber DrInvalidPortNumber = 0xFFFF; +static const DrPortNumber DrAnyPortNumber = 0; + +#pragma warning(pop) + +//JC} // namespace apsdk +#ifdef USING_APSDK_NAMESPACE +using namespace apsdk; +#endif + +// +// disabling a couple of warnings that show up with /Wall and are pretty much useless -- they come mostly from macros +// +#pragma warning (disable: 4514) // unreferenced inline function has been removed +#pragma warning (disable: 4820) // 'N' bytes padding added after data member 'XXX' +#pragma warning (disable: 4265) // class has virtual functions, but destructor is not virtual +#pragma warning (disable: 4668) // XXX is not defined as preprocessor macro, replacing with 0 +#pragma warning (disable: 4711) // function XXX selected for automatic inline expansion +#pragma warning (disable: 4548) // malloc.h(245) & STL: expression before comma has no effect; expected expression with side-effect +#pragma warning (disable: 4127) // conditional expression is constant + diff --git a/DryadVertex/VertexHost/system/classlib/include/Dryad.h b/DryadVertex/VertexHost/system/classlib/include/Dryad.h new file mode 100644 index 0000000..8ed8530 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/Dryad.h @@ -0,0 +1,2185 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#pragma warning( push ) +/* 'X' bytes padding added after member 'Y' */ +#pragma warning( disable: 4820 ) + + +#pragma pack( push, 8 ) + + +#if !defined(_PCVOID_DEFINED) +typedef const void* PCVOID; +#define _PCVOID_DEFINED +#endif + + + +#if defined(__cplusplus) +extern "C" { +#endif + + +#define DRYADAPI_EXT +#define DRYADAPI __stdcall + + +typedef struct tagDRHANDLE +{ + ULONG_PTR Unused; +} *DRHANDLE, **PDRHANDLE; + + +/*++ + +DR_ASYNC_INFO structure + +Each function that may be executed asynchronously takes pointer +to DR_ASYNC_INFO structure as last parameter. + +If NULL is passed then function completes in synchronous manner +and error code is returned as return value. + +If parameter is not NULL then operation is carried on in asynchronous manner. +If asynchronous operation has been successfully started then function terminates +immediately with HRESULT_FROM_WIN32(ERROR_IO_PENDING) return value. +Any other return value indicates that it was impossible to start asynchronous operation. + +Fields: + + cbSize Size of structure in bytes. Set to sizeof(DR_ASYNC_INFO). + + pOperationState Pointer to error code returned by completed operation. + While operation is in progress value is set to HRESULT_FROM_WIN32(ERROR_IO_PENDING). + Before completion is reported value is set to an error code of completed operation. + Cannot be NULL. + + Event Handle to event. Event is set once operation is completed. May be NULL. + + IOCP Handle to IO completion port. 
If not NULL then upon completion status is posted + to specified completion port. + + pOverlapped Pointer to OVERLAPPED structure. Used in conjunction with IOCP parameter + to post status to IOCP. + Should be null if IOCP is NULL, cannot be NULL if IOCP is not NULL. + + CompletionKey Used in conjunction with IOCP parameter to post status to IOCP. + Should be 0 if IOCP is NULL. + + unusedX Fields reserved for future use. + +Note that DR_ASYNC_INFO structure is not required to be available for the duration of asynchronous call +(for example, this allows to allocate DR_ASYNC_INFO structure on stack). +In contrast variable specified by pOperationState pointer is required to be available +for the duration of the asynchronous call. + +--*/ +typedef struct tagDR_ASYNC_INFO { + SIZE_T cbSize; + + DrError* pOperationState; + + HANDLE Event; + + HANDLE IOCP; + LPOVERLAPPED pOverlapped; + UINT_PTR CompletionKey; + + UINT64 unused0; + UINT64 unused1; +} DR_ASYNC_INFO, *PDR_ASYNC_INFO; +typedef const DR_ASYNC_INFO* PCDR_ASYNC_INFO; + +//JC +#if 0 +typedef struct tagDRSESSIONHANDLE +{ + ULONG_PTR Unused; +} *DRSESSIONHANDLE, **PDRSESSIONHANDLE; + +/*++ + +DR_ASYNC_INFO structure + +Each function that may be executed asynchronously takes pointer +to DR_ASYNC_INFO structure as last parameter. + +If NULL is passed then function completes in synchronous manner +and error code is returned as return value. + +If parameter is not NULL then operation is carried on in asynchronous manner. +If asynchronous operation has been successfully started then function terminates +immediately with HRESULT_FROM_WIN32(ERROR_IO_PENDING) return value. +Any other return value indicates that it was impossible to start asynchronous operation. + +Fields: + + cbSize Size of structure in bytes. Set to sizeof(DR_ASYNC_INFO). + + pOperationState Pointer to error code returned by completed operation. + While operation is in progress value is set to HRESULT_FROM_WIN32(ERROR_IO_PENDING). 
+ Before completion is reported value is set to an error code of completed operation. + Cannot be NULL. + + Event Handle to event. Event is set once operation is completed. May be NULL. + + IOCP Handle to IO completion port. If not NULL then upon completion status is posted + to specified completion port. + + pOverlapped Pointer to OVERLAPPED structure. Used in conjunction with IOCP parameter + to post status to IOCP. + Should be null if IOCP is NULL, cannot be NULL if IOCP is not NULL. + + CompletionKey Used in conjunction with IOCP parameter to post status to IOCP. + Should be 0 if IOCP is NULL. + + unusedX Fields reserved for future use. + +Note that DR_ASYNC_INFO structure is not required to be available for the duration of asynchronous call +(for example, this allows to allocate DR_ASYNC_INFO structure on stack). +In contrast variable specified by pOperationState pointer is required to be available +for the duration of the asynchronous call. + +--*/ +typedef struct tagDR_ASYNC_INFO { + SIZE_T cbSize; + + DrError* pOperationState; + + HANDLE Event; + + HANDLE IOCP; + LPOVERLAPPED pOverlapped; + UINT_PTR CompletionKey; + + UINT64 unused0; + UINT64 unused1; +} DR_ASYNC_INFO, *PDR_ASYNC_INFO; +typedef const DR_ASYNC_INFO* PCDR_ASYNC_INFO; + + +/*++ + +DR_INIT_PARAMS structure + +Contains optional parameters for cosmos initialization. +Default value of all fields is zero. + +Fields: + + cbSize Size of structure in bytes. Set to sizeof(DR_INIT_PARAMS). + + SuppressConsoleLogOutput Set to TRUE to suppress autopilot init output to console. 
+ +--*/ +typedef struct tagDR_INIT_PARAMS { + SIZE_T cbSize; + + BOOL SuppressConsoleLogOutput; + BOOL fLogDisableMillisecondTimestamps; + + UINT64 Reserved0; + UINT64 Reserved1; + UINT64 Reserved2; + UINT64 Reserved3; + UINT64 Reserved4; + UINT64 Reserved5; + UINT64 Reserved6; + UINT64 Reserved7; +} DR_INIT_PARAMS, *PDR_INIT_PARAMS; +typedef const DR_INIT_PARAMS* PCDR_INIT_PARAMS; + + +/*++ + +DR_STREAM_PROPERTIES structure + +Used in DrOpenStream and DrSetStreamProperties. + +Fields: + + cbSize Size of structure in bytes. Set to sizeof(DR_STREAM_PROPERTIES). + + ExpirePeriod Expiration period in 100-ns intervals. On server side at the time + of request execution stream expiration time is set to curent time + plus ExpirePeriod. + + ReadOffsetHint Optimization hint. Currently not used. + + Flags Specifies new values of modified stream flags. + Set to 0 for DrOpenStream. + + FlagsMask Specifies stream flags to be modified. + Set to 0 for DrOpenStream. + +--*/ +typedef struct tagDR_STREAM_PROPERTIES { + SIZE_T cbSize; + INT64 ExpirePeriod; + UINT64 ReadOffsetHint; + UINT Flags; + UINT FlagsMask; + UINT64 unused0; // four attributes reserved for future extension + UINT64 unused1; + UINT64 unused2; + UINT64 unused3; +} DR_STREAM_PROPERTIES, *PDR_STREAM_PROPERTIES; +typedef const DR_STREAM_PROPERTIES* PCDR_STREAM_PROPERTIES; + + +/*++ + + Old name for DR_STREAM_PROPERTIES structure + +--*/ +typedef DR_STREAM_PROPERTIES DR_STREAM_PARAMS; +typedef PDR_STREAM_PROPERTIES PDR_STREAM_PARAMS; +typedef PCDR_STREAM_PROPERTIES PCDR_STREAM_PARAMS; + + +typedef struct tagDR_EXTENT_INSTANCE { + SIZE_T cbSize; + + UINT Flags; + PCSTR ExtentNodeName; + UINT64 unused0; // two attributes reserved for future extension + UINT64 unused1; +} DR_EXTENT_INSTANCE, *PDR_EXTENT_INSTANCE; + + +typedef struct tagDR_EXTENT { + SIZE_T cbSize; + + UINT Flags; + + UINT64 ModificationTime; + + UINT64 Length; + + UINT64 Crc64; + + GUID Id; + + UINT NumberOfInstances; + PDR_EXTENT_INSTANCE Instances; + 
UINT64 unused0; // four attributes reserved for future extension + UINT64 unused1; + UINT64 unused2; + UINT64 unused3; +} DR_EXTENT, *PDR_EXTENT; + +/*++ +Fields: + + Flags DR_OBJECT_STREAM or DR_OBJECT_DIRECTORY +--*/ +typedef struct tagDR_STREAM { + SIZE_T cbSize; + + PCSTR pName; + + UINT Flags; + + UINT64 CreationTime; + UINT64 ExpirationTime; + UINT64 ModificationTime; + + UINT64 Length; + UINT TotalNumberOfExtents; + + GUID Id; + + UINT StartExtentIndex; + UINT64 StartExtentOffset; + UINT ExtentsCount; + PDR_EXTENT pExtents; + PCSTR pPath; + UINT64 unused0; // four attributes reserved for future extension + UINT64 unused1; + UINT64 unused2; + UINT64 unused3; +} DR_STREAM, *PDR_STREAM; + + +typedef struct tagDR_DIRECTORY { + SIZE_T cbSize; + + PDR_STREAM pStreams; + SIZE_T StreamsCount; + UINT64 unused0; // four attributes reserved for future extension + UINT64 unused1; + UINT64 unused2; + UINT64 unused3; +} DR_DIRECTORY, *PDR_DIRECTORY; + + +typedef struct tagDR_EXTENT_NODE { + SIZE_T cbSize; + + UINT ExtentNodeVersion; + PCSTR pExtentNodeVersionText; + + PCSTR ExtentNodeName; + UINT Flags; + UINT64 StartupTime; + UINT64 LastSyncTime; + PCSTR ReasonText; + + UINT NumberOfSealedExtents; + UINT64 SealedExtentsLength; + UINT64 SealedExtentsSizeOnDisk; + + UINT NumberOfUnsealedExtents; + UINT64 UnsealedExtentsLength; + UINT64 UnsealedExtentsSizeOnDisk; + + UINT NumberOfUnvacatedExtents; + UINT64 UnvacatedExtentsLength; + + UINT64 FreeSpaceSize; + + UINT16 ScaleUnit; + UINT64 unused0; // four attributes reserved for future extension + UINT64 unused1; + UINT64 unused2; + UINT64 unused3; +} DR_EXTENT_NODE, *PDR_EXTENT_NODE; + + +typedef struct tagDR_HOST { + PCSTR HostName; + UINT16 Port; + PCSTR PodName; + UINT16 ScaleUnit; +} DR_HOST, *PDR_HOST; + + +typedef struct tagDR_VOLUME { + SIZE_T cbSize; + + UINT DrmVersion; + PCSTR pDrmVersionText; + + UINT64 CurrentTime; + UINT64 StartupTime; + + UINT64 NumberOfStreams; + UINT64 StreamsLength; + UINT64
NumberOfCreatedStreams; + + UINT64 NumberOfSealedExtents; + UINT64 SealedExtentsLength; + + UINT64 NumberOfUnsealedExtents; + UINT64 UnsealedExtentsLength; + + PCSTR pPrimaryDRM; + UINT NumberOfDrmHosts; + PDR_HOST pDrmHostList; + UINT64 DrmRslSequence; + GUID VolumeId; + + UINT64 NumberOfCreatedExtents; + UINT NumberOfExtentNodes; + PDR_EXTENT_NODE pExtentNodes; + UINT64 unused0; // four attributes reserved for future extension + UINT64 unused1; + UINT64 unused2; + UINT64 unused3; +} DR_VOLUME, *PDR_VOLUME; + + +typedef struct tagDR_STREAM_POSITION { + UINT ExtentIndex; + UINT64 Offset; +} DR_STREAM_POSITION, *PDR_STREAM_POSITION; +typedef const DR_STREAM_POSITION* PCDR_STREAM_POSITION; + +/*++ + +DR_MULTIMODIFY_EXTENT_RANGE structure + +Used in DrMultiModifyStream to describe a contiguous range of extents to +be concatenated onto a target stream. + +Fields: + + cbSize Size of structure in bytes. Set to sizeof(DR_MULTIMODIFY_EXTENT_RANGE). + + MultiModifyStreamEntryIndex + The 0-based index of the DR_MULTIMODIFY_STREAM_ENTRY, provided in the call to DrMultiModifyStream, which describes + the source stream from which to copy extents. + + ExtentIndex + The index in the source stream of the first extent to be copied. 0 means beginning of stream. + + ExtentCount + The number of contiguous extents in the range. + If 0, no extents will be copied, but the operation will succeed. + If DR_ALL_EXTENTS, all remaining extents in the source stream will be concatenated (may be 0). +--*/ + +typedef struct tagDR_MULTIMODIFY_EXTENT_RANGE { + SIZE_T cbSize; + UINT MultiModifyStreamEntryIndex; + UINT ExtentIndex; + UINT ExtentCount; +} DR_MULTIMODIFY_EXTENT_RANGE, *PDR_MULTIMODIFY_EXTENT_RANGE; +typedef const DR_MULTIMODIFY_EXTENT_RANGE *PCDR_MULTIMODIFY_EXTENT_RANGE; + +/*++ + +DR_MULTIMODIFY_STREAM_ENTRY structure + +Used in DrMultiModifyStream to describe one of a set of +atomically applied stream modifications. 
+ +Combines the capabilities of DrDeleteStream, DrRenameStream, DrConcatenateStream, and +DrSetStreamProperties, with atomic compare/exchange semantics. + +Note that it is explicitly possible to construct an entry that verifies the existence or size of a particular +stream without actually making any changes to the stream. This is particularly +useful when concatenating append ranges, since each source stream must be described +by a DR_MULTIMODIFY_STREAM_ENTRY even if it is not to be touched. + +If this structure is set to all zeroes, and only pOldStreamName is initialized, the effect will be to +verify existence of stream pOldStreamName, but make no changes to it. + +Fields: + + cbSize Size of structure in bytes. Set to sizeof(DR_MULTIMODIFY_STREAM_ENTRY). + + ModifyFlags Options for the entry. Includes: + DR_MULTIMODIFY_ENTRY_SET_EXPIRE_PERIOD - the expire time is to be updated from pStreamProperties + DR_MULTIMODIFY_ENTRY_DELETE_STREAM -- the stream is to be deleted or recycled according to deleteFlags + DR_MULTIMODIFY_ENTRY_ENFORCE_EXTENT_COUNT -- fail the request with DrError_StreamChanged if requiredExtentCount does not match + DR_MULTIMODIFY_ENTRY_CREATE_STREAM -- Create the new stream named in pNewStreamName + + pOldStreamName + Fully qualified name of the existing stream to be verified or modified. + + pNewStreamName + If renaming or creating a new stream, the new fully qualified stream name. NULL otherwise. Must + be on the same volume. + + DeleteFlags + If deleting (DR_MULTIMODIFY_ENTRY_DELETE_STREAM), options for the delete (See DrDeleteStream). Must be + 0 if not deleting. + + pStreamProperties + If setting stream properties (including updating expire time), the new + stream properties. NULL otherwise. See DrSetStreamProperties. + + RequiredStreamId + if not NULL_GUID, causes the entire request to fail with DrError_StreamChanged + if the stream does not have the specified stream ID.
+ + RequiredExtentCount + if DR_MULTIMODIFY_ENTRY_ENFORCE_EXTENT_COUNT is 1, causes the request to fail with DrError_StreamChanged if + the stream does not contain the specified number of extents (prior to concatenation). + + ConcatenateExtentRangeCount + The number of entries in pConcatenateExtentRanges. 0 if no concatenation is to be performed. + + pConcatenateExtentRanges + An array of extent range descriptors of length ConcatenateExtentRangeCount. Each entry describes a contiguous block + of extents to be appended onto the end of this stream, in order. NULL if no concatenation is to be performed. + See DrConcatenateStream. +--*/ +typedef struct tagDR_MULTIMODIFY_STREAM_ENTRY { + SIZE_T cbSize; + UINT ModifyFlags; + + PCSTR pOldStreamName; + PCSTR pNewStreamName; + UINT DeleteFlags; + + PCDR_STREAM_PROPERTIES pStreamProperties; + GUID RequiredStreamId; + UINT RequiredExtentCount; + UINT ConcatenateExtentRangeCount; + PCDR_MULTIMODIFY_EXTENT_RANGE pConcatenateExtentRanges; + + UINT64 unused0; // four attributes reserved for future extension + UINT64 unused1; + UINT64 unused2; + UINT64 unused3; +} DR_MULTIMODIFY_STREAM_ENTRY, *PDR_MULTIMODIFY_STREAM_ENTRY; +typedef const DR_MULTIMODIFY_STREAM_ENTRY *PCDR_MULTIMODIFY_STREAM_ENTRY; + +// +// MultiModifyStream Entry flags +// + +// +// DR_MULTIMODIFY_ENTRY_SET_EXPIRE_PERIOD If 1, stream's expiration period will be reset from pStreamProperties +// Append, concatenate, delete, rename and shrink expiration time operations +// fail for such streams.
+// +#define DR_MULTIMODIFY_ENTRY_SET_EXPIRE_PERIOD 0x00000001u + +// +// DR_MULTIMODIFY_ENTRY_DELETE_STREAM If 1, stream will be deleted according to deleteFlags +// +#define DR_MULTIMODIFY_ENTRY_DELETE_STREAM 0x00000002u + +// +// DR_MULTIMODIFY_ENTRY_ENFORCE_EXTENT_COUNT If 1, request will fail with DrError_StreamChanged if requiredExtentCount does not match current stream size +// +#define DR_MULTIMODIFY_ENTRY_ENFORCE_EXTENT_COUNT 0x00000004u + +// +// DR_MULTIMODIFY_ENTRY_CREATE_STREAM If 1, creates stream pNewStreamName. pOldStreamName should be NULL. +// +#define DR_MULTIMODIFY_ENTRY_CREATE_STREAM 0x00000008u + + +typedef struct tagDR_RANGE +{ + PCSTR pStreamName; + UINT ExtentIndex; + UINT ExtentCount; +} DR_RANGE; +typedef DR_RANGE* PDR_RANGE; +typedef const DR_RANGE* PCDR_RANGE; + + +#define DR_CURRENT_VERSION 0x00000009u + +// +// Stream open modes +// +#define DR_APPEND 0x00000001u +#define DR_READ 0x00000002u + +// +// Creation disposition +// +#define DR_OPEN 0x00000010u +#define DR_CREATE 0x00000020u +#define DR_CREATE_OR_OPEN (DR_OPEN | DR_CREATE) + +// +// Metadata refresh policy flags for DrOpenStream and DrReadStream +// Refresh policy is set during DrOpenStream, and then may be overridden by DrReadStream +// If DrReadStream does not override, options from DrOpenStream affect DrReadStream as well +// +#define DR_REFRESH_AGGRESSIVE 0x10000000u +#define DR_REFRESH_PASSIVE 0x20000000u +#define DR_REFRESH_FROM_CACHE 0x30000000u +#define DR_REFRESH_NO_INSTANCES 0x40000000u // do not refresh instances (length can always be refreshed) + +#define DR_REFRESH_MASK (DR_REFRESH_AGGRESSIVE | DR_REFRESH_PASSIVE | DR_REFRESH_NO_INSTANCES) // all refresh bits +#define DR_REFRESH_MASK_MUST_HAVE (DR_REFRESH_AGGRESSIVE | DR_REFRESH_PASSIVE) // at least one of these bits must be on +#define DR_REFRESH_DEFAULT (DR_REFRESH_AGGRESSIVE) // default for DrOpenStream +#define DR_REFRESH_DEFAULT_STREAMINFO (DR_REFRESH_AGGRESSIVE | DR_REFRESH_NO_INSTANCES) // default for
DrOpenStreamFromStreamInfo + +// +// Cache policy settings for DrOpenStream (possible future TODO: allow to be passed directly to DrReadStream) +// +#define DR_CACHE_SEQUENTIAL 0x01000000u +#define DR_CACHE_READ_AHEAD 0x02000000u // same as DR_CACHE_SEQUENTIAL, but with optimistic read ahead in background +#define DR_CACHE_RANDOM_ACCESS 0x03000000u +#define DR_CACHE_NO_CACHE 0x04000000u // ineffective if read boundaries do not exactly match append boundaries + +#define DR_CACHE_DEFAULT DR_CACHE_SEQUENTIAL +#define DR_CACHE_MASK 0x07000000u + +// +// Object selection options +// +#define DR_OBJECT_STREAM 0x00000100u +#define DR_OBJECT_DIRECTORY 0x00000200u +#define DR_OBJECT_ANY (DR_OBJECT_STREAM | DR_OBJECT_DIRECTORY) +#define DR_OBJECT_LOCAL 0x00000400u +#define DR_OBJECT_RECYCLED 0x00000800u + +// +// Stream flags +// + +// +// DR_STREAM_READ_ONLY Stream is read-only. +// Append, concatenate, delete, rename and shrink expiration time operations +// fail for such streams. +// +#define DR_STREAM_READ_ONLY 0x00000008u + +// +// DR_STREAM_SEALED Stream is sealed. +// Append and concatenate operations +// fail for such streams. This bit cannot be cleared once set. Does not effect +// append operations to unsealed extents within the stream. 
+// +#define DR_STREAM_SEALED 0x00000010u + +// +// Extent flags +// +#define DR_EXTENT_SEALED 0x00001000u +#define DR_EXTENT_NEEDS_SEAL 0x00002000u + +// +// Extent node flags +// +#define DR_EXTENT_NODE_AVAILABLE 0x00000001u + +// +// Stream append options +// +#define DR_FIXED_OFFSET_APPEND 0x01000000u +#define DR_PARTIAL_APPEND 0x02000000u +#define DR_SEAL 0x04000000u +#define DR_UPDATE_DRM 0x08000000u + +// +// Stream delete options +// +#define DR_DELETE_FORCE 0x00000001u + +// +// Stream concatenation options +// +#define DR_CONCATENATE_ALLOW_CREATE_NEW 0x00100000u +#define DR_CONCATENATE_REQUIRE_CREATE_NEW 0x00200000u + +#define DR_CONCATENATE_MASK (DR_CONCATENATE_ALLOW_CREATE_NEW | DR_CONCATENATE_REQUIRE_CREATE_NEW) + +// +// Retrieve or append all extents +// +#define DR_ALL_EXTENTS 0xFFFFFFFFu + +#define DR_INVALID_EXTENT_INDEX 0xFFFFFFFFu +#define DR_UNKNOWN_EXTENT_INDEX 0xFFFFFFFFu + +#define DR_UNKNOWN_OFFSET 0xFFFFFFFFFFFFFFFFui64 +#define DR_UNKNOWN_LENGTH 0xFFFFFFFFFFFFFFFFui64 + +#define DR_NEVER 0xFFFFFFFFFFFFFFFFui64 +#define DR_INFINITE 0x7FFFFFFFFFFFFFFFi64 + +// +// Stream renaming flags +// +#define DR_RENAME_ALLOW_SUBSTITUTION 0x00000001u +#define DR_RENAME_REQUIRE_SUBSTITUTION 0x00000002u + +#define DR_RENAME_MASK (DR_RENAME_ALLOW_SUBSTITUTION | DR_RENAME_REQUIRE_SUBSTITUTION) + +// +// Get position flags +// +#define DR_ABSOLUTE_POSITION 0x00000001u + +// +// Stream option constants +// + +// +// DR_STREAM_OPTION_MAX_EXTENT_SIZE UINT64 +// Maximum extent size client passes to DRM and EN, +// if DRM returned smaller maximum size, the smaller +// size is used to pass to EN. +// Pass zero to use built-in default (DR_STREAM_OPTION_MAX_EXTENT_SIZE_FALLBACK). +// Set to value from the configuration file if there is any. 
+// +#define DR_STREAM_OPTION_MAX_EXTENT_SIZE 1u +#define DR_STREAM_OPTION_MAX_EXTENT_SIZE_DEFAULT 0ui64 +#define DR_STREAM_OPTION_MAX_EXTENT_SIZE_FALLBACK 0x6400000ui64 // 100 Mb +#define DR_STREAM_OPTION_MAX_EXTENT_SIZE_CONFIG "ExtentSize" + +// +// DR_STREAM_OPTION_MAX_PHYSICAL_EXTENT_SIZE UINT64 +// Maximum physical (compressed) extent size client passes to EN. +// Pass zero to use built-in default (DR_STREAM_OPTION_MAX_PHYSICAL_EXTENT_SIZE_FALLBACK). +// Set to value from the configuration file if there is any. +// +#define DR_STREAM_OPTION_MAX_PHYSICAL_EXTENT_SIZE 2u +#define DR_STREAM_OPTION_MAX_PHYSICAL_EXTENT_SIZE_DEFAULT 0ui64 +#define DR_STREAM_OPTION_MAX_PHYSICAL_EXTENT_SIZE_FALLBACK 0xFFFFFFFFFFFFFFFFui64 // unlimited +#define DR_STREAM_OPTION_MAX_PHYSICAL_EXTENT_SIZE_CONFIG "PhysicalExtentSize" + + +// +// DR_STREAM_OPTION_COMPRESSION_LEVEL UINT +// Specifies compression level for data being appended. +// Value from 0 (fastest, no compression) to 5 (slowest, best compression ratio). +// Set to value from the configuration file if there is any. +// +#define DR_STREAM_OPTION_COMPRESSION_LEVEL 3u +#define DR_STREAM_OPTION_COMPRESSION_LEVEL_DEFAULT 2u +#define DR_STREAM_OPTION_COMPRESSION_LEVEL_MIN 0u +#define DR_STREAM_OPTION_COMPRESSION_LEVEL_MAX 5u +#define DR_STREAM_OPTION_COMPRESSION_LEVEL_CONFIG "CompressionLevel" + + +// +// DR_STREAM_OPTION_MAX_APPEND_RETRY UINT +// Specifies number of retries for unsuccessful stream append. +// Set to value from the configuration file if there is any. +// +#define DR_STREAM_OPTION_MAX_APPEND_RETRY 4u +#define DR_STREAM_OPTION_MAX_APPEND_RETRY_DEFAULT 16u +#define DR_STREAM_OPTION_MAX_APPEND_RETRY_MIN 1u +#define DR_STREAM_OPTION_MAX_APPEND_RETRY_MAX 1023u +#define DR_STREAM_OPTION_MAX_APPEND_RETRY_CONFIG "AppendRetry" + + +// +// DR_STREAM_OPTION_MAX_READ_RETRY UINT +// Specifies number of retries for unsuccessful stream read. +// Set to value from the configuration file if there is any.
+// +#define DR_STREAM_OPTION_MAX_READ_RETRY 5u +#define DR_STREAM_OPTION_MAX_READ_RETRY_DEFAULT 16u +#define DR_STREAM_OPTION_MAX_READ_RETRY_MIN 1u +#define DR_STREAM_OPTION_MAX_READ_RETRY_MAX 1023u +#define DR_STREAM_OPTION_MAX_READ_RETRY_CONFIG "ReadRetry" + + +// +// DR_STREAM_OPTION_FLAGS UINT +// Contains flags value for given handle. +// Read-only. +// +#define DR_STREAM_OPTION_FLAGS 6u + + +// +// DR_STREAM_OPTION_MAX_APPEND_SIZE UINT64 +// Maximum block size for single append. +// Set to value from the configuration file if there is any. +// +#define DR_STREAM_OPTION_MAX_APPEND_SIZE 7u +#define DR_STREAM_OPTION_MAX_APPEND_SIZE_DEFAULT 0x400000ui64 // 4 Mb +#define DR_STREAM_OPTION_MAX_APPEND_SIZE_MIN 1ui64 +#define DR_STREAM_OPTION_MAX_APPEND_SIZE_MAX 0x2000000ui64 // 32 Mb +#define DR_STREAM_OPTION_MAX_APPEND_SIZE_CONFIG "MaxAppendSize" + + +// +// DR_STREAM_OPTION_VERIFY_COMPRESSION BOOL +// Set to TRUE to enable additional pass to verify compressed data +// for corruption on append. +// Should be either TRUE or FALSE. +// Set to value from the configuration file if there is any. +// +#define DR_STREAM_OPTION_VERIFY_COMPRESSION 8u +#define DR_STREAM_OPTION_VERIFY_COMPRESSION_DEFAULT FALSE +#define DR_STREAM_OPTION_VERIFY_COMPRESSION_CONFIG "VerifyCompression" + + +// +// DR_STREAM_OPTION_GUID GUID +// Returns stream GUID for given handle. +// Read-only. +// +#define DR_STREAM_OPTION_GUID 9u + + +// +// DR_STREAM_OPTION_SHORT_NAME String +// Returns stream name without path for given handle. +// Read-only. +// +#define DR_STREAM_OPTION_SHORT_NAME 10u + + +// +// DR_STREAM_OPTION_PATH String +// Returns stream path for given handle. +// Read-only. +// +#define DR_STREAM_OPTION_PATH 11u + + +// +// DR_STREAM_OPTION_VOLUME String +// Returns stream volume name for given handle. +// Read-only. +// +#define DR_STREAM_OPTION_VOLUME 12u + + +// +// DR_STREAM_OPTION_CLUSTER String +// Returns stream cluster name for given handle. +// Read-only.
+// +#define DR_STREAM_OPTION_CLUSTER 13u + + +// +// DR_STREAM_OPTION_NAME String +// Returns complete stream name in URI form for given handle. +// Read-only. +// +#define DR_STREAM_OPTION_NAME 14u + + +// +// DR_STREAM_OPTION_GUID_PATH String +// Returns .streamid stream path for given handle. +// Read-only. +// +#define DR_STREAM_OPTION_GUID_PATH 15u + + +// +// DR_STREAM_OPTION_GUID_URI String +// Returns .streamid stream name in URI form for given handle. +// Read-only. +// +#define DR_STREAM_OPTION_GUID_URI 16u + + +// +// DR_STREAM_OPTION_NAME_GUID_PATH String +// Returns combined .streamid stream path and namespace location for given handle. +// Read-only. +// +#define DR_STREAM_OPTION_NAME_GUID_PATH 17u + + +// +// DR_STREAM_OPTION_NAME_GUID_URI String +// Returns combined .streamid stream name and namespace location in URI form for given handle. +// Read-only. +// +#define DR_STREAM_OPTION_NAME_GUID_URI 18u + +// +// DR_STREAM_OPTION_MAX_CONCURRENT_READS UINT +// Maximum number of concurrent read requests for a single DrReadStream call. +// If first read attempt does not complete in DR_STREAM_OPTION_CONCURRENT_READ_TIMEOUT_1_MS +// time, second read request for the same data is sent to a different storage node. +// DR_STREAM_OPTION_CONCURRENT_READ_TIMEOUT_2_MS timeout is used for the second +// and subsequent read attempts. If this option is set to 1 no concurrent read requests are +// performed. While this option limits number of concurrent retries, +// DR_STREAM_OPTION_MAX_READ_RETRY option limits total number of retries for single API call. +// Aggressive settings for the maximum concurrent reads number and timeouts may result in +// increased network load and cause load convoy effects during peak loads. +// Set to value from the configuration file if there is any.
+// +#define DR_STREAM_OPTION_MAX_CONCURRENT_READS 19u +#define DR_STREAM_OPTION_MAX_CONCURRENT_READS_DEFAULT 2u +#define DR_STREAM_OPTION_MAX_CONCURRENT_READS_MIN 1u +#define DR_STREAM_OPTION_MAX_CONCURRENT_READS_MAX 3u +#define DR_STREAM_OPTION_MAX_CONCURRENT_READS_CONFIG "MaxConcurrentReads" + +// +// DR_STREAM_OPTION_CONCURRENT_READ_TIMEOUT_1_MS UINT +// Wait time in milliseconds for making second concurrent read request +// if first read request is not complete. Maximum number of concurrent read +// requests is limited by DR_STREAM_OPTION_MAX_CONCURRENT_READS option. +// Set to value from the configuration file if there is any. +// +#define DR_STREAM_OPTION_CONCURRENT_READ_TIMEOUT_1_MS 20u +#define DR_STREAM_OPTION_CONCURRENT_READ_TIMEOUT_1_MS_DEFAULT 10000u +#define DR_STREAM_OPTION_CONCURRENT_READ_TIMEOUT_1_MS_MIN 0u +#define DR_STREAM_OPTION_CONCURRENT_READ_TIMEOUT_1_MS_MAX 60000u +#define DR_STREAM_OPTION_CONCURRENT_READ_TIMEOUT_1_MS_CONFIG "ConcurrentReadTimeout1Ms" + +// +// DR_STREAM_OPTION_CONCURRENT_READ_TIMEOUT_2_MS UINT +// Wait time in milliseconds after starting second read request for making third concurrent +// read request if either first or second read request is not complete. Maximum number of +// concurrent read requests is limited by DR_STREAM_OPTION_MAX_CONCURRENT_READS option. +// Set to value from the configuration file if there is any. +// +#define DR_STREAM_OPTION_CONCURRENT_READ_TIMEOUT_2_MS 21u +#define DR_STREAM_OPTION_CONCURRENT_READ_TIMEOUT_2_MS_DEFAULT 10000u +#define DR_STREAM_OPTION_CONCURRENT_READ_TIMEOUT_2_MS_MIN 0u +#define DR_STREAM_OPTION_CONCURRENT_READ_TIMEOUT_2_MS_MAX 60000u +#define DR_STREAM_OPTION_CONCURRENT_READ_TIMEOUT_2_MS_CONFIG "ConcurrentReadTimeout2Ms" + +// +// DR_STREAM_OPTION_MAX_OUTSTANDING_APPENDS UINT +// Maximum number of append requests sent to network at same time per stream. +// Set to value from the configuration file if there is any. 
+// +#define DR_STREAM_OPTION_MAX_OUTSTANDING_APPENDS 22u +#define DR_STREAM_OPTION_MAX_OUTSTANDING_APPENDS_DEFAULT 32u +#define DR_STREAM_OPTION_MAX_OUTSTANDING_APPENDS_MIN 1u +#define DR_STREAM_OPTION_MAX_OUTSTANDING_APPENDS_MAX 128u +#define DR_STREAM_OPTION_MAX_OUTSTANDING_APPENDS_CONFIG "MaxOutstandingAppends" + +// +// DR_STREAM_OPTION_MAX_OUTSTANDING_APPEND_SIZE UINT64 +// Maximum size of data contained in append requests sent to network at same time per stream. +// Set to value from the configuration file if there is any. +// +#define DR_STREAM_OPTION_MAX_OUTSTANDING_APPEND_SIZE 23u +#define DR_STREAM_OPTION_MAX_OUTSTANDING_APPEND_SIZE_DEFAULT 0x2000000ui64 // 32 Mb +#define DR_STREAM_OPTION_MAX_OUTSTANDING_APPEND_SIZE_MIN 0x200000ui64 // 2 Mb +#define DR_STREAM_OPTION_MAX_OUTSTANDING_APPEND_SIZE_MAX 0x4000000ui64 // 64 Mb +#define DR_STREAM_OPTION_MAX_OUTSTANDING_APPEND_SIZE_CONFIG "MaxOutstandingAppendSize" + +// +// DR_STREAM_OPTION_READ_AHEAD_BLOCK_SIZE UINT64 +// Read block size to be used if handle is opened with DR_CACHE_READ_AHEAD flag. +// Set to value from the configuration file if there is any. +// +#define DR_STREAM_OPTION_READ_AHEAD_BLOCK_SIZE 24u +#define DR_STREAM_OPTION_READ_AHEAD_BLOCK_SIZE_DEFAULT 0x200000ui64 // 2 Mb +#define DR_STREAM_OPTION_READ_AHEAD_BLOCK_SIZE_MIN 0x4000ui64 // 16 Kb +#define DR_STREAM_OPTION_READ_AHEAD_BLOCK_SIZE_MAX 0x2000000ui64 // 32 Mb +#define DR_STREAM_OPTION_READ_AHEAD_BLOCK_SIZE_CONFIG "ReadAheadBlockSize" + +// +// DR_STREAM_OPTION_READ_AHEAD_OUTSTANDING_COUNT UINT +// Maximum of simultaneous optimistic read requests to issue for handles opened with DR_CACHE_READ_AHEAD flag. +// Set to value from the configuration file if there is any. 
+// +#define DR_STREAM_OPTION_READ_AHEAD_OUTSTANDING_COUNT 25u +#define DR_STREAM_OPTION_READ_AHEAD_OUTSTANDING_COUNT_DEFAULT 8u +#define DR_STREAM_OPTION_READ_AHEAD_OUTSTANDING_COUNT_MIN 0u +#define DR_STREAM_OPTION_READ_AHEAD_OUTSTANDING_COUNT_MAX 32u +#define DR_STREAM_OPTION_READ_AHEAD_OUTSTANDING_COUNT_CONFIG "ReadAheadOutstandingCount" + +// +// DR_STREAM_OPTION_CACHE_DEFAULT_MODE UINT +// Default cache operation mode for reads. +// Set to value from the configuration file if there is any. +// NOTE: this is global-only setting and it can't be changed for individual handles. +// +#define DR_STREAM_OPTION_CACHE_DEFAULT_MODE 26u +#define DR_STREAM_OPTION_CACHE_DEFAULT_MODE_DEFAULT DR_CACHE_DEFAULT +#define DR_STREAM_OPTION_CACHE_DEFAULT_MODE_MIN DR_CACHE_SEQUENTIAL +#define DR_STREAM_OPTION_CACHE_DEFAULT_MODE_MAX DR_CACHE_NO_CACHE +#define DR_STREAM_OPTION_CACHE_DEFAULT_MODE_MASK DR_CACHE_MASK +#define DR_STREAM_OPTION_CACHE_DEFAULT_MODE_CONFIG "CacheDefaultMode" + +// +// DR_STREAM_OPTION_CACHE_SOFT_SIZE_LIMIT UINT64 +// Soft limit for cache size measured in bytes of uncompressed data. +// After soft limit is reached sequential and read-ahead blocks are discarded +// immediately after use while random-read blocks are kept up to hard limit. +// Set to value from the configuration file if there is any. +// NOTE: this is global-only setting and it can't be changed for individual handles. +// +#define DR_STREAM_OPTION_CACHE_SOFT_SIZE_LIMIT 27u +#define DR_STREAM_OPTION_CACHE_SOFT_SIZE_LIMIT_DEFAULT 0x10000000ui64 // 256 Mb +#define DR_STREAM_OPTION_CACHE_SOFT_SIZE_LIMIT_MIN 0x0ui64 +#define DR_STREAM_OPTION_CACHE_SOFT_SIZE_LIMIT_MAX MAX_UINT64 +#define DR_STREAM_OPTION_CACHE_SOFT_SIZE_LIMIT_CONFIG "CacheSoftSizeLimit" + +// +// DR_STREAM_OPTION_CACHE_HARD_SIZE_LIMIT UINT64 +// Hard limit for cache size measured in bytes of uncompressed data. +// Set to value from the configuration file if there is any. 
+// NOTE: this is global-only setting and it can't be changed for individual handles. +// +#define DR_STREAM_OPTION_CACHE_HARD_SIZE_LIMIT 28u +#define DR_STREAM_OPTION_CACHE_HARD_SIZE_LIMIT_DEFAULT 0x20000000ui64 // 512 Mb +#define DR_STREAM_OPTION_CACHE_HARD_SIZE_LIMIT_MIN 0x0ui64 +#define DR_STREAM_OPTION_CACHE_HARD_SIZE_LIMIT_MAX MAX_UINT64 +#define DR_STREAM_OPTION_CACHE_HARD_SIZE_LIMIT_CONFIG "CacheHardSizeLimit" + + +// +// Init APIs +// + +/*++ + +DrInitialize + +This function behaves like DrInitializeEx with pInitParams set to NULL. +See description of DrInitializeEx below. + +--*/ +DRYADAPI_EXT +DrError +DRYADAPI +DrInitialize( + IN UINT Version, + IN PCSTR pDryadIniFile + ); + + +/*++ + +DrInitializeEx + +Parameters: + + Version Requested version number of the library. Pass DR_CURRENT_VERISON. + + pDryadIniFile Pointer to string containing path and file name for cosmos ini file. + Pass NULL to use default. Default is cosmos.ini in current directory. + + pInitParams Pointer to parameters block. May be NULL. + +--*/ +DRYADAPI_EXT +DrError +DRYADAPI +DrInitializeEx( + IN UINT Version, + IN PCSTR pDryadIniFile, + IN PCDR_INIT_PARAMS pInitParams + ); + + +DRYADAPI_EXT +void +DRYADAPI +DrLogPreInitialize ( + void +); + + +DRYADAPI_EXT +DrError +DRYADAPI +DrUninitialize( void ); + + +// +// Free for data structures allocated by Dryad client lib +// + +DRYADAPI_EXT +DrError +DRYADAPI +DrFreeMemory( + IN PVOID ptr + ); + + +/*++ + +DrOpenStream + +Opens or creates cosmos or local file system stream. + +Parameters: + + pStreamName UTF8 encoded fully qualified cosmos URI, file URI, UNC name, or local file name. + Cannot be NULL. + + Flags Specifies if stream should be open or created and open mode.
+ Combines next values: + + DR_APPEND open for append | exclusive, required + DR_READ open for read | + (one of above should be specified) + + DR_OPEN open existing stream | can be combined, required + DR_CREATE create new | + DR_CREATE_OR_OPEN create or open existing + (defined as DR_OPEN | DR_CREATE) + (at least one of above should be specified) + + DR_REFRESH_AGGRESSIVE (default) Refresh metadata from DRM if EOS | exclusive, optional + is detected. | + DR_REFRESH_PASSIVE Do not refresh metadata from DRM if EOS | + is detected. | + (refresh flags are optional) + DR_REFRESH_NO_INSTANCES Do not refresh instances in case of read error + + + DR_CACHE_SEQUENTIAL (default) Optimize data cache for | exclusive, optional + sequential access. Reading same item twice | + is not efficient. | + DR_CACHE_READ_AHEAD Optimize data cache for | + sequential access and enable optimistic | + background read-ahead. | + DR_CACHE_RANDOM_ACCESS Optimize data cache for random | + access. Reading same items twice is | + efficient. | + DR_CACHE_NO_CACHE Do not cache data. Consumes least memory. | + Could lead to data reread if read boundaries| + do not exactly match append boundaries. | + (cache flags are optional) + + pStreamParams Pointer to DR_STREAM_PROPERTIES structure. DR_STREAM_PROPERTIES structure + contains additional optional parameters. + + pStreamHandle Pointer to variable that receives handle for opened stream. + NOTE: if operation is asynchronous, this location should be valid + until completion of operation is reported. + + pAsyncInfo Pointer to DR_ASYNC_INFO structure. + See description of to DR_ASYNC_INFO structure. 
+ +Examples: + +// create new or open existing cosmos stream for append synchronously +DRHANDLE h; +DrError ret = DrOpenStream( "cosmos://store/volume/dir1/dir2/stream1.ext", + DR_CREATE_OR_OPEN | DR_APPEND, NULL, &h, NULL ); + +// asynchronously open existing stream for read +// use event for completion notification +DrError ret; +DRHANDLE h; +DrError asyncRet; +DR_ASYNC_INFO asyncInfo; +ZeroMemory( &asyncInfo, sizeof( asyncInfo ) ); +asyncInfo.cbSize = sizeof( asyncInfo ); +asyncInfo.pOperationState = &asyncRet; +asyncInfo.Event = CreateEvent( NULL, TRUE, FALSE, NULL ); +ret = DrOpenStream( "cosmos://store/volume/dir1/dir2/stream2.ext", + DR_OPEN | DR_READ, NULL, &h, &asyncInfo ); +if ( ret == HRESULT_FROM_WIN32( ERROR_IO_PENDING ) ) { + WaitForSingleObject( asyncInfo.Event, INFINITE ); + if ( SUCCEEDED( asyncRet ) ) { + // Open completed successfully + } +} + +--*/ +DRYADAPI_EXT +DrError +DRYADAPI +DrOpenStream( + IN PCSTR pStreamName, + IN UINT Flags, + IN PCDR_STREAM_PROPERTIES pStreamParams, + OUT PDRHANDLE pStreamHandle, + IN PCDR_ASYNC_INFO pAsyncInfo + ); + +// +// Extended version of open stream which accepts a session handle as a parameter +// + +DRYADAPI_EXT +DrError +DRYADAPI +DrOpenStreamEx( + IN DRSESSIONHANDLE Session, + IN PCSTR pStreamName, + IN UINT Flags, + IN PCDR_STREAM_PROPERTIES pStreamParams, + OUT PDRHANDLE pStreamHandle, + IN PCDR_ASYNC_INFO pAsyncInfo + ); + +/*++ + +DrOpenStreamFromStreamInfo + +Binary version of open stream: opening based on DR_STREAM +Data structure has the same format as the data returned by DrGetStreamInformation +Upon return from DrOpenStreamFromStreamInfo that data structure may be freed. + +Parameters: + + pStreamName UTF8 encoded fully qualified cosmos URI, file URI, UNC name, or local file name. + Cannot be NULL. + + pStreamInfo Pointer to DR_STREAM structure.
Cannot be NULL + + Flags Same as for DrOpenStream, except for DR_CREATE, which is unacceptable in this API + Note: DR_REFRESH_NO_INSTANCES is set by default (if no DR_REFRESH_ is provided) + + + pStreamHandle Pointer to variable that receives handle for opened stream. + NOTE: if operation is asynchronous, this location should be valid + until completion of operation is reported. + + pReserved Reserved. Must be zero. + + pAsyncInfo Pointer to DR_ASYNC_INFO structure. + See description of DR_ASYNC_INFO structure. + +Example: + +DrError ret; + +// retrieve streaminfo +PDR_STREAM streaminfo; +ret = DrGetStreamInformation("cosmos://store/volume/dir1/dir2/stream2.ext", + 0, 0, DR_ALL_EXTENTS, &streaminfo, NULL); + +// asynchronously open existing stream for read +// use event for completion notification +DRHANDLE h; +DrError asyncRet; +DR_ASYNC_INFO asyncInfo; +ZeroMemory( &asyncInfo, sizeof( asyncInfo ) ); +asyncInfo.cbSize = sizeof( asyncInfo ); +asyncInfo.pOperationState = &asyncRet; +asyncInfo.Event = CreateEvent( NULL, TRUE, FALSE, NULL ); +ret = DrOpenStreamFromStreamInfo (NULL, "cosmos://store/volume/dir1/dir2/stream2.ext", + streaminfo, DR_OPEN | DR_READ, &h, NULL, &asyncInfo); +if ( ret == HRESULT_FROM_WIN32( ERROR_IO_PENDING ) ) { + WaitForSingleObject( asyncInfo.Event, INFINITE ); + if ( SUCCEEDED( asyncRet ) ) { + // Open completed successfully + } +} + +// release stream info +DrFreeMemory (streaminfo); + + +--*/ +DRYADAPI_EXT +DrError +DRYADAPI +DrOpenStreamFromStreamInfo ( + IN DRSESSIONHANDLE Session, // may be NULL + IN PCSTR pStreamName, + IN PDR_STREAM pStreamInfo, + IN UINT Flags, + OUT PDRHANDLE pStreamHandle, + IN PVOID pReserved, // for future extensions, currently NULL + IN PCDR_ASYNC_INFO pAsyncInfo + ); + +/*++ +DrOpenSession + +Remarks: + + This API lets the user create an empty session handle.
+ +*/ + +DRYADAPI_EXT +DrError +DRYADAPI +DrOpenSession( + OUT PDRSESSIONHANDLE pSession + ); + +/*++ +DrCloseSession + +Remarks: + + This API lets the user close an existing session handle. + +*/ + +DRYADAPI_EXT +DrError +DRYADAPI +DrCloseSession( + IN DRSESSIONHANDLE Session + ); + +/*++ + +DrReadStream + +Parameters: + + StreamHandle Stream handle returned by DrOpenStream. + + pBuffer Pointer to the buffer that receives the data read from the stream. + + pBytesRead Pointer to variable containing size of the buffer. On return this + variable receives number of bytes read. + + Flags Optional, may be 0. + + DR_REFRESH_AGGRESSIVE (default) Refresh metadata from DRM if EOS + is detected. + DR_REFRESH_PASSIVE Do not refresh metadata from DRM if EOS + is detected. + + pReadPosition Pointer to DR_STREAM_POSITION structure. May be NULL. + Specifies position in the stream of the block to be read from. + If set to DR_INVALID_EXTENT_INDEX and DR_UNKNOWN_OFFSET then read is + performed from the current stream position and on return this variable + receives the position of the block read. + + pAsyncInfo Pointer to DR_ASYNC_INFO structure. + See description of DR_ASYNC_INFO structure. + +Remarks: + + If pReadPosition is NULL, then current stream position is used as read position. + Current stream position is advanced if read is successful. + + If pReadPosition is not NULL, extent index in position structure is set to DR_INVALID_EXTENT_INDEX + and offset is set to DR_UNKNOWN_OFFSET, then current stream position is used as read position. + If read is successful, then read position is saved into variable specified by pReadPosition and + current stream position is advanced. + + If pReadPosition is not NULL, extent index in position structure is set to a valid extent index + and offset is set to a valid offset, then current stream position is set to position specified by + pReadPosition, then current read position is used as read position.
+ Current stream position is advanced if read is successful. + + Queueing TODO + + If none of the refresh flags is specified in Flags, then value passed to DrOpenStream is used. + +--*/ +DRYADAPI_EXT +DrError +DRYADAPI +DrReadStream( + IN DRHANDLE StreamHandle, + OUT PVOID pBuffer, + IN OUT PSIZE_T pBytesRead, + IN UINT Flags, + IN OUT PDR_STREAM_POSITION pReadPosition, + IN PCDR_ASYNC_INFO pAsyncInfo + ); + + +/*++ + +DrAppendStream + +Parameters: + + StreamHandle Stream handle returned by DrOpenStream. + + pData Pointer to the buffer containing data to be appended to the stream. + + DataSize Number of bytes to be appended to the stream. + + Flags Specifies if append should use fixed offset and if partial append is + allowed. Flags can be combined. + + + DR_FIXED_OFFSET_APPEND append at current offset + //TODO: DR_SEAL, DR_UPDATE_DRM + + + pAppendPosition Pointer to DR_STREAM_POSITION, may be NULL. + If not NULL and append succeeds, then variable receives offset and base extent + of appended data block. + + pBytesAppended Pointer to variable to receive number of bytes appended to the stream. + May be NULL. + + pAsyncInfo Pointer to DR_ASYNC_INFO structure. + See description of DR_ASYNC_INFO structure. + +Remarks: + + If DR_FIXED_OFFSET_APPEND flag is not specified, then data block is appended at the end of stream, + possibly more than one time. If there is a need to seal an extent during this process and + pAppendPosition is NULL, then data block is appended to the next extent without waiting + for the sealing of previous extent to be completed. + + If DR_FIXED_OFFSET_APPEND flag is specified, then data block is appended to the stream at offset set + by DrSetStreamPosition. Operation fails if offset is not equal to current length of the stream + since it's not possible to overwrite data appended to the stream or to write beyond end of stream. + If append succeeds then current stream position is incremented by the number of bytes appended.
+ + If DR_PARTIAL_APPEND flag is specified, then partial append is enabled and possibly only first part of + the specified data block is appended. Variable pointed at by pBytesAppended receives number of bytes + actually appended to the stream. + + TODO: Queueing + + TODO: IN/OUT append pos + +--*/ +DRYADAPI_EXT +DrError +DRYADAPI +DrAppendStream( + IN DRHANDLE StreamHandle, + IN PCVOID pData, + IN SIZE_T DataSize, + IN UINT Flags, + IN OUT PDR_STREAM_POSITION pAppendPosition, // non-NULL pAppendPosition must be held around until completion + OUT PSIZE_T pBytesAppended, // non-NULL pBytesAppended must be held around until completion + IN PCDR_ASYNC_INFO pAsyncInfo + ); + +typedef struct tagDR_BUFFERLIST +{ + PCVOID pData; + SIZE_T DataSize; +} DR_BUFFERLIST; +typedef DR_BUFFERLIST *PDR_BUFFERLIST; +typedef const DR_BUFFERLIST *PCDR_BUFFERLIST; + +DRYADAPI_EXT +DrError +DRYADAPI +DrAppendStreamBufferList( + IN DRHANDLE StreamHandle, + IN PCDR_BUFFERLIST pBufferList, + IN SIZE_T BufferCount, + IN UINT Flags, + IN OUT PDR_STREAM_POSITION pAppendPosition, // non-NULL pAppendPosition must be held around until completion + OUT PSIZE_T pBytesAppended, // non-NULL pBytesAppended must be held around until completion + IN OUT void *pReserved, // must be set to NULL for now + IN PCDR_ASYNC_INFO pAsyncInfo + ); + + +DRYADAPI_EXT +DrError +DRYADAPI +DrSetStreamPosition( + IN DRHANDLE StreamHandle, + IN UINT ExtentIndex, + IN UINT64 Offset + ); + + +DRYADAPI_EXT +DrError +DRYADAPI +DrGetStreamPosition( + IN DRHANDLE StreamHandle, + IN UINT Flags, + OUT PDR_STREAM_POSITION pPosition + ); + +/*++ + +DrTranslateStreamPosition + + Translates a stream position so that it is relative to another extent's base offset (or 0) + +Parameters + + StreamHandle Stream handle returned by DrOpenStream + + pSourcePosition Position you want to translate + + DestExtentIndex The extent index to translate relative to. 0 for offset relative to the stream.
+ + pDestOffset Pointer to the offset that will be output. + Must not be NULL. + + Flags Reserved. Must be 0. + +Remarks: + + To translate into absolute stream offset pass 0 for DestExtentIndex. + + NOTE: this operation retrieves information from internal metadata cache. If you use this API to + translate extent offset returned by DrAppendStream, then it should always return correct result + since cache is up-to-date. In other scenarios, if the API fails, you may need to refresh stream + metadata by calling DrGetStreamLength with DR_REFRESH_AGGRESSIVE flag and retry DrTranslateStreamPosition + after DrGetStreamLength succeeds. +--*/ +DRYADAPI_EXT +DrError +DRYADAPI +DrTranslateStreamPosition( + IN DRHANDLE StreamHandle, + IN PDR_STREAM_POSITION pSourcePosition, + IN UINT DestExtentIndex, + OUT PUINT64 pDestOffset, + IN UINT Flags + ); + + +/*++ + +DrGetStreamLength + + Gets the length of the stream. + +Parameters + + StreamHandle Stream handle returned by DrOpenStream + + pLength Pointer to the output length variable. + Must not be NULL. + + Flags One of (or pass 0 for default value): + + DR_REFRESH_AGGRESSIVE (default) + - visit server to find out latest known length + DR_REFRESH_PASSIVE + - return length from local cache if available otherwise + visit server to find out latest known length + DR_REFRESH_FROM_CACHE + - return length from local cache + fail if not available + + pAsyncInfo Pointer to DR_ASYNC_INFO structure. + See description of DR_ASYNC_INFO structure. + +Remarks: + + Call with DR_REFRESH_FROM_CACHE always returns immediately. + Note that if there are ongoing appends to this stream issued from this or + other clients, this function may not return precise length.
+ +--*/ +DRYADAPI_EXT +DrError +DRYADAPI +DrGetStreamLength( + IN DRHANDLE StreamHandle, + OUT PUINT64 pLength, + IN UINT Flags, + IN PCDR_ASYNC_INFO pAsyncInfo + ); + + +DRYADAPI_EXT +DrError +DRYADAPI +DrCloseHandle( + IN DRHANDLE StreamHandle + ); + + +/*++ + +DrDelete + + Deletes or recycles stream. + +Parameters: + + pStreamName Pointer to UTF8 name of the stream to be deleted. + + Flags Specifies if stream should be deleted or recycled. + + DR_DELETE_FORCE delete stream + + pAsyncInfo Pointer to DR_ASYNC_INFO structure. + See description of DR_ASYNC_INFO structure. + +Remarks: + + By default this operation recycles stream. Stream is renamed to + streamName#versionNumber, where versionNumber is integer, and + stream expiration time is set to recycle stream timeout. Recycled + stream timeout is specified in DRM configuration file. + + To immediately delete stream set Flags to DR_DELETE_FORCE. + + @TODO: create DrDeleteEx API to return new recycled stream name. + +--*/ +DRYADAPI_EXT +DrError +DRYADAPI +DrDelete( + IN PCSTR pStreamName, + IN UINT Flags, + IN PCDR_ASYNC_INFO pAsyncInfo + ); + +DRYADAPI_EXT +DrError +DRYADAPI +DrDeleteEx( + IN DRSESSIONHANDLE Session, + IN PCSTR pStreamName, + IN UINT Flags, + IN PCDR_ASYNC_INFO pAsyncInfo + ); + +/*++ + +DrRename and DrRenameEx + +Parameters: + + Session Session handle. May be NULL. + + pOldName Pointer to UTF8 name of the stream to be renamed to pNewName. + + pNewName Pointer to new UTF8 name of the pOldName stream. + + Flags A combination of the following bit values or zero. + + DR_RENAME_ALLOW_SUBSTITUTION Delete pNewName stream if it exists + before renaming pOldName stream. + + DR_RENAME_REQUIRE_SUBSTITUTION Replace existing pNewName stream + with pOldName stream. + + pAsyncInfo Pointer to DR_ASYNC_INFO structure. + See description of DR_ASYNC_INFO structure.
+ +Remarks: + + If DR_RENAME_ALLOW_SUBSTITUTION and DR_RENAME_REQUIRE_SUBSTITUTION flags are not set then + operation fails if pNewName stream already exists. + + If either DR_RENAME_ALLOW_SUBSTITUTION or DR_RENAME_REQUIRE_SUBSTITUTION flag is set then + this function atomically deletes pNewName stream and renames pOldName stream to pNewName. + If DR_RENAME_REQUIRE_SUBSTITUTION flag is set and pNewName stream does not exist then + operation fails with DrError_StreamNotFound. + + In substitution mode destination stream may be specified by GUID using .streamid name. + In this case source stream is renamed to original destination stream name. + +--*/ +DRYADAPI_EXT +DrError +DRYADAPI +DrRename( + IN PCSTR pOldName, + IN PCSTR pNewName, + IN UINT Flags, + IN PCDR_ASYNC_INFO pAsyncInfo + ); + +DRYADAPI_EXT +DrError +DRYADAPI +DrRenameEx( + IN DRSESSIONHANDLE Session, + IN PCSTR pOldName, + IN PCSTR pNewName, + IN UINT Flags, + IN PCDR_ASYNC_INFO pAsyncInfo + ); + +/*++ + +DrMultiModifyStream + +Parameters: + + pEntries Array of DR_MULTIMODIFY_STREAM_ENTRY structures, one for each stream operation + to perform + + nEntries Number of entries in pEntries + + Flags Reserved. Must be 0. + + pAsyncInfo Pointer to DR_ASYNC_INFO structure. + See description of to DR_ASYNC_INFO structure. + +Remarks: + + This API allows you to atomically rename, delete, and set the properties of multiple + streams on a single volume. + + It is allowed to rename streams over each other concurrently; e.g., A->B and B->A, or + A->B, B->C, C->A. + + NOTE: in future cosmos versions, the partitioning of volumes will be less clear. To ensure + compatibility with future versions, you should limit operations to a single directory. 
+--*/ +DrError +DRYADAPI +DrMultiModifyStreamEx( + IN DRSESSIONHANDLE Session, + IN PCDR_MULTIMODIFY_STREAM_ENTRY pEntries, + IN UINT nEntries, + IN UINT Flags, + IN PCDR_ASYNC_INFO pAsyncInfo + ); + +DrError +DRYADAPI +DrMultiModifyStream( + IN PCDR_MULTIMODIFY_STREAM_ENTRY pEntries, + IN UINT nEntries, + IN UINT Flags, + IN PCDR_ASYNC_INFO pAsyncInfo + ); + +/*++ + +DrConcatenateStream + +Parameters: + + pDestName Pointer to UTF8 name of the stream to receive extents from source stream. + + pSrcName Pointer to UTF8 name of the stream to share extents with the destination stream. + + ExtentIndex Index of extent range to share. Pass 0 to start from the beginning of the stream. + + ExtentCount Number of extents in the range to share. Pass DR_ALL_EXTENTS to share all extents from + the ExtentIndex to the end of the source stream. + + pDestExtentIndex If not NULL, pointer to the expected extent index where the new extents will be appended to the destination + stream. If NULL, or if the value is DR_UNKNOWN_EXTENT_INDEX, then the current end of the stream is used; otherwise, + the request fails with DrError_AppendNotAtEnd if the given value does not match the current length of the + stream. On return, if not NULL, the value is updated with the actual index at which the extents were appended. + Note: On older clusters, the returned value may be DR_UNKNOWN_EXTENT_INDEX if the DRM does not + support the feature. + + Flags Reserved. Must be 0. + + pAsyncInfo Pointer to DR_ASYNC_INFO structure. + See description of DR_ASYNC_INFO structure. + +Remarks: + + Appends extent range specified by ExtentIndex and ExtentCount for pSrcName stream to the end of the + pDestName stream. No actual copy is done, both streams point to the same extents. Extent is kept + available while all streams referencing it are present.
+ +--*/ +DRYADAPI_EXT +DrError +DRYADAPI +DrConcatenateStream( + IN PCSTR pDestName, + IN PCSTR pSrcName, + IN UINT ExtentIndex, + IN UINT ExtentCount, + IN OUT PUINT pDestExtentIndex, + IN UINT Flags, + IN PCDR_ASYNC_INFO pAsyncInfo + ); + +/*++ + +DrConcatenateStreamEx + +Parameters: + + Session Session handle, may be NULL. + + pDestName Pointer to UTF8 name of the stream to receive extents from source stream. + + pDestExtentIndex If not NULL, pointer to the expected extent index where the new extents will be appended to the destination + stream. If NULL, or if the value is DR_UNKNOWN_EXTENT_INDEX, then the current end of the stream is used; otherwise, + the request fails with DrError_AppendNotAtEnd if the given value does not match the current length of the + stream. On return, if not NULL, the value is updated with the actual index at which the extents were appended. + Note: On older clusters, the returned value may be DR_UNKNOWN_EXTENT_INDEX if the DRM does not + support the feature. + + pSrcRanges Pointer to array of DR_RANGE structures describing source ranges. See DR_RANGE. + There is no need to keep this data in memory for the duration of async call, i.e. it can be allocated on the stack. + + cSrcRanges Number of DR_RANGE structures in pSrcRanges array. + + Flags Zero or combination of the following values: + + DR_CONCATENATE_ALLOW_CREATE_NEW create new stream if pDestName stream does not exist. + + DR_CONCATENATE_REQUIRE_CREATE_NEW create new stream, fails if pDestName stream already exists. + + pAsyncInfo Pointer to DR_ASYNC_INFO structure. + See description of DR_ASYNC_INFO structure. + +Remarks: + + Appends the extent ranges described by pSrcRanges and cSrcRanges to the end of the + pDestName stream. No actual copy is done, the streams point to the same extents. Extent is kept + available while all streams referencing it are present.
+ +--*/ +DRYADAPI_EXT +DrError +DRYADAPI +DrConcatenateStreamEx( + IN DRSESSIONHANDLE Session, + IN PCSTR pDestName, + IN OUT PUINT pDestExtentIndex, + IN PCDR_RANGE pSrcRanges, + IN SIZE_T cSrcRanges, + IN UINT Flags, + IN PCDR_ASYNC_INFO pAsyncInfo +); + +/*++ + +DrGetStreamInformation +DrGetStreamInformationEx + +Retrieves stream information for stream with given name. + +Parameters: + + pStreamName UTF8 encoded fully qualified cosmos URI, file URI, UNC name, or local file name. + Cannot be NULL. + + Reserved0 Reserved. Must be zero. + + StartExtentIndex Index of first stream extent to return information about. + Pass 0 for all extents. + + NumberOfExtentsToFill Number of extents to return information about. + Pass 0 for no extent information, DR_ALL_EXTENTS for information about all + extents in the stream. + + ppStreamInfo Pointer to a variable to receive pointer to DR_STREAM structure. + Returned structure must be freed by call to DrFreeMemory. + + pAsyncInfo Pointer to DR_ASYNC_INFO structure. + See description of to DR_ASYNC_INFO structure. +--*/ +DRYADAPI_EXT +DrError +DRYADAPI +DrGetStreamInformation( + IN PCSTR pStreamName, + IN UINT64 Reserved0, + IN UINT StartExtentIndex, + IN UINT NumberOfExtentsToFill, + OUT PDR_STREAM* ppStreamInfo, + IN PCDR_ASYNC_INFO pAsyncInfo +); + +DRYADAPI_EXT +DrError +DRYADAPI +DrGetStreamInformationEx( + IN DRSESSIONHANDLE Session, + IN PCSTR pStreamName, + IN UINT64 Reserved0, + IN UINT StartExtentIndex, + IN UINT NumberOfExtentsToFill, + OUT PDR_STREAM* ppStreamInfo, + IN PCDR_ASYNC_INFO pAsyncInfo +); + +/*++ +DrGetStreamInformationByHandle + +Retrieves stream information for stream with given handle. + +See DrGetStreamInformation for parameter description. 
+--*/ +DRYADAPI_EXT +DrError +DRYADAPI +DrGetStreamInformationByHandle( + IN DRHANDLE StreamHandle, + IN UINT64 Reserved0, + IN UINT StartExtentIndex, + IN UINT NumberOfExtentsToFill, + OUT PDR_STREAM* ppStreamInfo, + IN PCDR_ASYNC_INFO pAsyncInfo +); + + +/*++ + +DrSetStreamProperties + +Parameters: + + pStreamName Pointer to UTF8 name of the stream. + + pStreamProperties Pointer to DR_STREAM_PROPERTIES structure containing properties to change. + + pAsyncInfo Pointer to DR_ASYNC_INFO structure. + See description of DR_ASYNC_INFO structure. + +Remarks: + + Attempts to change stream properties such as expiration time or read-only attribute. + + Set pStreamProperties->ExpirePeriod to number of 100 nanoseconds ticks to change stream's + expiration time (resulting expiration time is calculated as server current time + pStreamProperties->ExpirePeriod). + Set to 0 to leave stream's expiration time intact. + Set to DR_INIFINTE to reset stream expiration time to never expire. + NOTE: for read only stream it's only possible to increase expiration period. Attempt to shrink it fails + with DrError_StreamReadOnly error. + + pStreamProperties->ReadOffsetHint is not used and should be set to 0. + + pStreamProperties->FlagsMask specifies flags to be modified and pStreamProperties->Flags specifies new flag values. + Set pStreamProperties->FlagsMask to 0 to skip flag modification. + + Available flags: + DR_STREAM_READ_ONLY + DR_STREAM_SEALED + + NOTE: Flag values are applied before expiration period. If you attempt to set read-only flag and shrink expiration + period in one call, stream becomes read-only and expiration period is not modified. Such call fails + with DrError_StreamReadOnly error.
+ +--*/ +DRYADAPI_EXT +DrError +DRYADAPI +DrSetStreamProperties( + IN PCSTR pStreamName, + IN PCDR_STREAM_PROPERTIES pStreamProperties, + IN PCDR_ASYNC_INFO pAsyncInfo +); + +// +// Extended version of the set stream properties API to which the caller can attach a session handle +// + +DRYADAPI_EXT +DrError +DRYADAPI +DrSetStreamPropertiesEx( + IN DRSESSIONHANDLE Session, + IN PCSTR pStreamName, + IN PCDR_STREAM_PROPERTIES pStreamProperties, + IN PCDR_ASYNC_INFO pAsyncInfo +); + +/*++ + +DrSetStreamPropertiesByHandle + +Parameters: + + StreamHandle Stream handle returned by DrOpenStream. + + pStreamProperties Pointer to DR_STREAM_PROPERTIES structure containing properties to change. + + pAsyncInfo Pointer to DR_ASYNC_INFO structure. + See description of DR_ASYNC_INFO structure. + +Remarks: + + See remarks for DrSetStreamProperties function. + +--*/ +DRYADAPI_EXT +DrError +DRYADAPI +DrSetStreamPropertiesByHandle( + IN DRHANDLE StreamHandle, + IN PCDR_STREAM_PROPERTIES pStreamProperties, + IN PCDR_ASYNC_INFO pAsyncInfo +); + + +DRYADAPI_EXT +DrError +DRYADAPI +DrSetStreamOption( + IN DRHANDLE StreamHandle, + IN UINT StreamOption, + IN PCVOID pData, + IN SIZE_T cbData +); + + +DRYADAPI_EXT +DrError +DRYADAPI +DrGetStreamOption( + IN DRHANDLE StreamHandle, + IN UINT StreamOption, + OUT PVOID pBuffer, + IN SIZE_T cbBuffer +); + + +// +// Directory APIs +// + +/*++ + +DrEnumerateDirectory + +Enumerates content of directory. + +Parameters: + + pDirectoryName UTF8 encoded fully qualified cosmos directory URI, file system directory URI, + UNC name, or local directory name. + + Flags + + + ppDirectoryInfo Pointer to variable that receives pointer to DR_DIRECTORY structure. + + pAsyncInfo Pointer to DR_ASYNC_INFO structure. + See description of DR_ASYNC_INFO structure.
+--*/ +DRYADAPI_EXT +DrError +DRYADAPI +DrEnumerateDirectory( + IN PCSTR pDirectoryName, + IN UINT Flags, + OUT PDR_DIRECTORY* ppDirectoryInfo, + IN PCDR_ASYNC_INFO pAsyncInfo + ); + +// +// Extended version of enumerating directory API which caller can attach a session handle +// + +DRYADAPI_EXT +DrError +DRYADAPI +DrEnumerateDirectoryEx( + IN DRSESSIONHANDLE Session, + IN PCSTR pDirectoryName, + IN UINT Flags, + OUT PDR_DIRECTORY* ppDirectoryInfo, + IN PCDR_ASYNC_INFO pAsyncInfo + ); + +// +// Volume APIs +// + +DRYADAPI_EXT +DrError +DRYADAPI +DrGetVolumeInformation( + IN PCSTR pVolumeName, + OUT PDR_VOLUME* ppVolumeInfo, + IN PCDR_ASYNC_INFO pAsyncInfo +); + +// +// Extended version of Volume API which caller can attach a session handle +// + +DRYADAPI_EXT +DrError +DRYADAPI +DrGetVolumeInformationEx( + IN DRSESSIONHANDLE Session, + IN PCSTR pVolumeName, + OUT PDR_VOLUME* ppVolumeInfo, + IN PCDR_ASYNC_INFO pAsyncInfo +); + +// +// Auxiliary APIs +// + +/*++ + +DrGetErrorMessage + +Parameters: + + StatusCode Status code + + pBuffer Pointer to a buffer + + cbBuffer Size of the buffer in bytes + +Remarks: + + Fills in pBuffer with UTF8 encoded error message corresponding to StatusCode. + If cbBuffer bytes is not enough to save the whole message + returns HRESULT_FROM_WIN32( ERROR_INSUFFICIENT_BUFFER ). + +--*/ +DRYADAPI_EXT +DrError +DRYADAPI +DrGetErrorMessage( + IN DrError StatusCode, + OUT PCHAR pBuffer, + IN SIZE_T cbBuffer + ); + + +/*++ + +DrSetParentId + +Parameters: + + ParentId Parent operation Id for the next client API call. + +Remarks: + + Sets parent operation id for the subsequent call to the other client API function + on the same thread. + Provided id is logged at the start of the next client API call. It makes it possible to + relate log statements for user operations and client library operations. 
+ +--*/ +DRYADAPI_EXT +void +DRYADAPI +DrSetParentId( + IN UINT ParentId + ); + +/*++ + +DrSetOpId + +Parameters: + + OpId Operation Id for the next client API call. + +Remarks: + + Sets operation id for the subsequent call to the other client API function on the same thread. + + MAKE SURE that you always call DrOpidNew to generate the operation id, so that it is + guaranteed to be unique in the process. + + NEVER generate the operation id by yourself or any other way. +--*/ +DRYADAPI_EXT +void +DRYADAPI +DrSetOpId( + IN UINT OpId + ); + +/*++ + +DrComposeExtentStreamInfo + +Parameters: + + pCluster, + pVolume cluster and volume names where the extent belongs + + pExtentInfo Extent info of the extent to become stream + + ppStreamName Pointer to a variable to receive .extentid stream name + + ppStreamInfo Pointer to a variable to receive pointer to DR_STREAM structure. + + Returned structure must be freed by call to DrFreeMemory, + which also frees *ppStreamName. + DO NOT free *ppStreamName alone! + +Remarks: + + Allocate and fill in fake stream information for the extent information provided here. + *ppStreamName and *ppStreamInfo may be passed as parameters straight into + DrOpenStreamFromStreamInfo + +--*/ +DRYADAPI_EXT +DrError +DRYADAPI +DrComposeExtentStreamInfo ( + IN PCSTR pCluster, + IN PCSTR pVolume, + IN PDR_EXTENT pExtentInfo, + OUT PCSTR * ppStreamName, + OUT PDR_STREAM * ppStreamInfo + ); + + +#endif // JC if 0 + + +#pragma pack( pop ) + +#pragma warning( pop ) + +#if defined(__cplusplus) +} +#endif + diff --git a/DryadVertex/VertexHost/system/classlib/include/DryadTags.h b/DryadVertex/VertexHost/system/classlib/include/DryadTags.h new file mode 100644 index 0000000..9979a3a --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DryadTags.h @@ -0,0 +1,57 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// Tags used for BeginTag/EndTag + +// There must *only* be DEFINE_DRYADTAG directives in this file + +DEFINE_DRYADTAG(DryadTag_InvalidTag, 0, "InvalidTag", MetaData) // An invalid tag value + +DEFINE_DRYADTAG(DryadTag_HostNameList, 608, "HostNameList", MetaData) +DEFINE_DRYADTAG(DryadTag_HostAndPort, 9011, "HostAndPort", MetaData) + +DEFINE_DRYADTAG(DryadTag_InputChannelDescription, 10000, "InputChannelDescription", InputChannelDescription) +DEFINE_DRYADTAG(DryadTag_OutputChannelDescription, 10001, "OutputChannelDescription", OutputChannelDescription) +DEFINE_DRYADTAG(DryadTag_VertexProcessStatus, 10002, "VertexProcessStatusBlock", VertexProcessStatus) +DEFINE_DRYADTAG(DryadTag_VertexStatus, 10003, "VertexStatusBlock", VertexStatus) +DEFINE_DRYADTAG(DryadTag_VertexCommand, 10004, "VertexCommandBlock", VertexCommandBlock) +DEFINE_DRYADTAG(DryadTag_ItemStart, 10005, "ItemStartBlock", MetaData) +DEFINE_DRYADTAG(DryadTag_ItemEnd, 10006, "ItemEndBlock", MetaData) +DEFINE_DRYADTAG(DryadTag_ChannelMetaData, 10007, "ChannelMetaData", MetaData) +DEFINE_DRYADTAG(DryadTag_VertexMetaData, 10008, "VertexMetaData", MetaData) +DEFINE_DRYADTAG(DryadTag_ArgumentArray, 10009, "ArgumentArray", MetaData) +DEFINE_DRYADTAG(DryadTag_VertexArray, 10010, "VertexArray", MetaData) +DEFINE_DRYADTAG(DryadTag_VertexInfo, 10011, "VertexInfo", MetaData) +DEFINE_DRYADTAG(DryadTag_EdgeArray, 10012, 
"EdgeArray", MetaData) +DEFINE_DRYADTAG(DryadTag_EdgeInfo, 10013, "EdgeInfo", MetaData) +DEFINE_DRYADTAG(DryadTag_GraphDescription, 10014, "GraphDescription", MetaData) +DEFINE_DRYADTAG(DryadTag_RSCAReturnMachine, 10015, "ReturnMachine", MetaData) +DEFINE_DRYADTAG(DryadTag_RSCAEnqueueProcess, 10016, "EnqueueProcess", MetaData) +DEFINE_DRYADTAG(DryadTag_RSCAReportFailedMachine, 10017, "ReportFailedMachine", MetaData) +DEFINE_DRYADTAG(DryadTag_RSCADiscardProcess, 10018, "DiscardProcess", MetaData) +DEFINE_DRYADTAG(DryadTag_RSClientResponse, 10019, "RSClientResponse", MetaData) +DEFINE_DRYADTAG(DryadTag_RSClientRootProcessRequest, 10020, "RSClientRootProcessRequest", MetaData) +DEFINE_DRYADTAG(DryadTag_RSClientRootProcessResponse, 10021, "RSClientRootProcessResponse", MetaData) +DEFINE_DRYADTAG(DryadTag_RSClientInitializeRequest, 10022, "RSClientInitializeRequest", MetaData) +DEFINE_DRYADTAG(DryadTag_RSClientActionRequest, 10023, "RSClientActionRequest", MetaData) +DEFINE_DRYADTAG(DryadTag_RSClientStatusRequest, 10024, "RSClientStatusRequest", MetaData) +DEFINE_DRYADTAG(DryadTag_RSClientMatch, 10025, "RSClientMatch", MetaData) +DEFINE_DRYADTAG(DryadTag_RSRevocation, 10026, "RSRevocation", MetaData) +DEFINE_DRYADTAG(DryadTag_RSClientCommand, 10027, "RSClientCommand", MetaData) diff --git a/DryadVertex/VertexHost/system/classlib/include/DryadTagsDef.h b/DryadVertex/VertexHost/system/classlib/include/DryadTagsDef.h new file mode 100644 index 0000000..483468e --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/DryadTagsDef.h @@ -0,0 +1,33 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#ifdef DEFINE_DRYADTAG +#undef DEFINE_DRYADTAG +#endif + +#define DEFINE_DRYADTAG(var, value, tagName, tagType) \ + static const UInt16 var = value; + +#include "dryadtags.h" + +#undef DEFINE_DRYADTAG + diff --git a/DryadVertex/VertexHost/system/classlib/include/Interlocked.h b/DryadVertex/VertexHost/system/classlib/include/Interlocked.h new file mode 100644 index 0000000..47e7dfd --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/Interlocked.h @@ -0,0 +1,150 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once +#include + +// The Interlocked64 APIs are only available on Windows Server 2003 +// and later. Even though the Interlocked64 operation are declared in +// winbase.h, they are not in kernel32, so the application breaks at +// runtime. 
+ +// Also, AMD64 machines do not have InterlockedCompareExchange64() in +// kernel32, since the function is an intrinsic, so it must be called +// directly. +// +// This file defines Interlocked64 operations on 32 bit platforms. +// On 64 bit platforms (AMD or IA64), the inlined functions +// call the intrinsic functions. + +namespace Interlocked +{ +#if defined(_M_AMD64) || defined(_M_IA64) + + LONGLONG inline CompareExchange64(LONGLONG volatile* destination, + LONGLONG comparand, + LONGLONG exchange) + { + return ::InterlockedCompareExchange64(destination, comparand, exchange); + } + + LONGLONG inline Increment64(LONGLONG volatile *Addend) + { + return ::InterlockedIncrement64(Addend); + } + + LONGLONG inline Decrement64(LONGLONG volatile *Addend) + { + return InterlockedDecrement64(Addend); + } + + LONGLONG inline Exchange64(LONGLONG volatile *Target, LONGLONG Value) + { + return ::InterlockedExchange64(Target, Value); + } + + LONGLONG inline ExchangeAdd64(LONGLONG volatile *Addend, LONGLONG Value) + { + return ::InterlockedExchangeAdd64(Addend, Value); + } + + LONGLONG inline Read64(LONGLONG volatile *target) + { + return *target; + } + +#else + + LONGLONG inline __cdecl CompareExchange64(LONGLONG volatile* destination, + LONGLONG comparand, + LONGLONG exchange) + { + __asm + { + mov esi, [destination] + mov ebx, dword ptr [comparand] + mov ecx, dword ptr [comparand + 4] + mov eax, dword ptr [exchange] + mov edx, dword ptr [exchange + 4] + lock cmpxchg8b [esi] + }; + } + + // Copied from winbase.h + LONGLONG inline Increment64(LONGLONG volatile *Addend) + { + LONGLONG Old; + + do { + Old = *Addend; + } while (CompareExchange64(Addend, Old + 1, Old) != Old); + return Old + 1; + } + + // Copied from winbase.h + LONGLONG inline Decrement64(LONGLONG volatile *Addend) + { + LONGLONG Old; + + do { + Old = *Addend; + } while (CompareExchange64(Addend, Old - 1, Old) != Old); + return Old - 1; + } + + // Copied from winbase.h + LONGLONG inline Exchange64(LONGLONG 
volatile *Target, LONGLONG Value) + { + LONGLONG Old; + + do { + Old = *Target; + } while (CompareExchange64(Target, Value, Old) != Old); + + return Old; + } + + + // Copied from winbase.h + LONGLONG inline ExchangeAdd64(LONGLONG volatile *Addend, + LONGLONG Value) + { + LONGLONG Old; + + do { + Old = *Addend; + } while (CompareExchange64(Addend, Old + Value, Old) != Old); + + return Old; + } + + LONGLONG inline Read64(LONGLONG volatile *target) + { + // As far as I know, this is the only way to atomically read a + // 64 bit value on a 32 bit platform. + // InterlockedCompareExchange64 reads the target value + // atomically. If the value is 0, it sets the value to 0 and + // returns 0 (no-op) If the value is non-zero, it does not + // change the value and returns old value. + return CompareExchange64(target, 0, 0); + } + +#endif +}; diff --git a/DryadVertex/VertexHost/system/classlib/include/LogIds.h b/DryadVertex/VertexHost/system/classlib/include/LogIds.h new file mode 100644 index 0000000..824d9ff --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/LogIds.h @@ -0,0 +1,846 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +//JC Check this file for unnecessary content. 
+ +//JCnamespace apsdk +//JC{ + +// Add your log entries (areas) here +// This list must be kept in sync with g_LogIDNames[] below +// The list must start at 0 and increase sequentially +typedef enum +{ + LogID_Logging, // The logging system can log stuff too! + LogID_MsnbotFileGen, // shared module used by different process like chunkbuilder, fex, RTDiscovery + LogID_Crawler, + LogID_CrawlerReceiver, + LogID_CrawlerFetcher, + LogID_CrawlerScheduler, + LogID_CrawlerDownloadLog, + LogID_Counters, + LogID_Parser, + LogID_IndexBuild, + LogID_IndexGeneration, + LogID_QueryProcessor, + LogID_IndexMerge, + LogID_IndexGeneral, + LogID_IndexQueryLog, + LogID_IndexServer, + LogID_MCPClient, + LogID_FileSync, + LogID_GetAndSyncChunkManifest, + LogID_ServiceManager, + LogID_DynamicRank, + LogID_Caption, + LogID_Federation, + LogID_APRunService, + LogID_GetAndRunServices, + LogID_Categorization, + LogID_StaticRanker, + LogID_StaticRankerBackup, + LogID_DRClient, + LogID_DRServer, + LogID_DRLib, + LogID_Common, + LogID_Aggregator, + LogID_Netlib, + LogID_CatCFSpecific, + LogID_SpamDetection, + LogID_CrawlerUrlTracker, + LogID_StaticRankerDetail, + LogID_FixRegKeys, + LogID_CreateCounter, + LogID_CRCHashComputation, + LogID_FEX, + LogID_Moonshot, + LogID_Fcslite, + LogID_FcsliteCache, + LogID_QualityOfPage, + LogID_UST, + LogID_SEARCH, + LogID_FSServer, + LogID_FSClient, + LogID_DPClient, + LogID_ENCARTA, + LogID_MUSIC, + LogID_LeaseVote, + LogID_Test, + LogID_CreateSocEvent, + LogID_SyncFiles, + LogID_Doodad, + LogID_DoodadFCtoLC, + LogID_DoodadLCtoFC, + LogID_DoodadWatchdog, + LogID_DOSDog, + LogID_StaticRankerTester, + LogID_LogCollector, + LogID_LogConvertor, + LogID_LogUploader, + LogID_LocalWD, + LogID_MachineWatchdog, + LogID_IPBlock, + LogID_DisableServerHeader, + LogID_IAdmin, + LogID_GENERICAP, + LogID_PSGateway, + LogID_HttpProxy, + LogID_CacheInfo, + LogID_FeedBuilder, + LogID_FeedBuilderData, + LogID_SysWD, + LogID_AlertService, + LogID_FexRLog, + LogID_FexHLog, + 
LogID_FexPreLog, + LogID_FexPostDataLog, + LogID_PsTFTP, + LogID_PsGatewayWatchdog, + LogID_CacheManager, + LogID_RSLLIB, + LogID_BurninTest, + LogID_QpsWatchdog, + LogID_ClusterBuilder, + LogID_MSRATermExtractor, + LogID_RealTimeCrawl, + LogID_RealTimeBuild, + LogID_DPS, + LogID_PerfCollector, + LogID_PerfAggregator, + LogID_ReplicatedDataSet, + LogID_WdWatcher, + LogID_DigiWatchdog, + LogID_OutBandMappingWatchdog, + LogID_PhantomWatchdog, + LogID_DigiProvisioning, + LogID_MiniSwitchProvisioning, + LogID_TLALoadBalancer, + LogID_Shopping, + LogID_CacheLogsProcessor, + LogID_DsSync, + LogID_DM, + LogID_DMWD, + LogID_DMAudit, + LogID_DMSQL, + LogID_PhraseTokenExtractor, + LogID_CQAClient, + LogID_CQABuilder, + LogID_CQAOQI, + LogID_CQAUserStore, + LogID_CQATagStore, + LogID_CQAQAStore, + LogID_TermExtractor, + LogID_EMS, + LogID_EMSStat, + LogID_Macro, + LogID_DRM, + LogID_EN, + LogID_EnWd, + LogID_ENTest, + LogID_Cosmos, + LogID_DrClient, + LogID_DrNameServer, + LogID_PN, + LogID_Dryad, + LogID_DryadTest, + LogID_DryadTestJournal, + LogID_DryadClient, + LogID_CockpitServer, + LogID_APWebServer, + LogID_QueryService, + LogID_MultiEnv, + LogID_Election, + LogID_APTextProtocol, + LogID_APMPClient, + LogID_APManagementProxy, + LogID_APProxyCommand, + LogID_FexLiveLog, + LogID_Wssync, + LogID_WmiLib, + LogID_PhantomLib, + LogID_AssetTool, + LogID_DeviceScanner, + LogID_DeadMachineMacScanner, + LogID_ServerTools, + LogID_BOOTP, + LogID_DHCP, + LogID_PXE, + LogID_BOOTPServer, + LogID_Duff, + LogID_IdCollect, + LogID_DryadWebServer, + LogID_DryadLogLoader, + LogID_DrmReplication, + LogID_CQACloseQuestions, + LogID_CQAPresser, + LogID_CQAPoller, + LogID_CQAStaticRanker, + LogID_CQATagTextProcessor, + LogID_CQAAlert, + LogID_CQAAlertStore, + LogID_FEX_CQA, + LogID_CQAWatchdogs, + LogID_MultiMedia, + LogID_ImageFetcher, + LogID_ImageFetcherCrawler, + LogID_ImageFetcherThumbnailer, + LogID_LiveSearchPane, + LogID_PsServer, + LogID_PsClient, + LogID_PsAgent, + LogID_RMAUtils, + 
LogID_LCDService, + LogID_DRV, + LogID_DRVParser, + LogID_CreateImage, + LogID_CreateImageFiles, + LogID_PhantomPowerStateWatchdog, + LogID_WebFeedDiscovery, + LogID_Answers, + LogID_Newsgroup, + LogID_TLARemoting, + LogID_HttpClient, + LogID_DiskTest, + LogID_DiskTestSequential, + LogID_DiskTestRandom, + LogID_MemoryTest, + LogID_WebFeedDiscoveryReader, + LogID_WebFeedDiscoveryProcessor, + LogID_WebFeedDiscoveryCrawlfileGenerator, + LogID_WebFeedDiscoveryRSSFetcher, + LogID_WebFeedDiscoveryProvider, + LogID_FileSyncIgnoreCRC, + LogID_FileSyncWrongCRC, + LogID_DSSlaveSync, + LogID_OSUpgrade, + LogID_DeviceUpgrade, + LogID_SyncOSImage, + LogID_ReliableRebootService, + LogID_GenericRepair, + LogID_NetlibCorruptPacket, + LogID_QuerySuggestion, + LogID_ISNManager, + LogID_ChunkLocator, + LogID_ISAgent, + LogID_ConfigViews, + LogID_CDGDominantImage, + LogID_DrJobManager, + LogID_DrJobManagerWd, + LogID_CachePropagator, + LogID_CachePropagatorClient, + LogID_MlToHosts, + LogID_TidyFS, + LogID_UrlTracker, + LogID_DNSWatchdog, + LogID_PsWatchdog, + LogID_WebFeedDiscoveryPinger, + LogID_MacroSuggestion, + LogID_AppAlertService, + LogID_TLARowStatus, + LogID_Sputnik, + LogID_GeneralClient, + LogID_QueryAugmenter, + LogID_Dictionary, + LogID_QueryStatistics, + LogID_QueryISNStatistics, + LogID_AnswersRLog, + LogID_AnswersALog, + LogID_AnswersPreLog, + LogID_IPTable, + LogID_DNSService, + LogID_FexMissingSnippetLog, + LogID_XifBuilder, + LogID_SharedModules, + LogID_DryadProfiler, + LogID_DryadProxy, + LogID_RTIndexCoverage, + LogID_Hardware, + LogID_PowerstripProvisioning, + LogID_PowerstripWatchdog, + LogID_Watchdog, + LogID_ApCommClientServer, + LogID_AdminGroup, + LogID_DMClient, + LogID_SvcMgrClient, + LogID_Speller, + LogID_FDR, + LogID_MT_HttpServer, + LogID_MT_Distributor, + LogID_MT_Cache, + LogID_MT_Translator, + LogID_MT_ResearchSdk, + LogID_MT_ModelServer, + LogID_MT_DB, + LogID_Localization, + LogID_MachineStatusClient, + LogID_ExpensiveQueryMonitor, + 
LogID_FeedChunkCleaner, + LogID_Hello, + LogID_AnswersPerformanceMonitor, + LogID_APMResult, + LogID_BackEndMachines, + LogID_LogService, + LogID_UserEvent, + LogID_FexImpressionLog, + LogID_RMAService, + LogID_RMAProtocol, + LogID_CockpitWatchdog, + LogID_DeviceCounterCollector, + LogID_DhcpMonitor, + LogID_DeviceCounter, + LogID_InstallService, + LogID_MinidumpSummary, + LogID_DeviceValidater, + LogID_ExpressRanker, + LogID_CrawlerUseChunk, + LogID_CrawlerDropLog, + LogID_CrawlerToBuilderLog, + LogID_ChunkBuilderArchive, + LogID_DocConverter, + LogID_ChunkBuilderCanary, + LogID_DynamicCrawler, + LogID_CacheClient, + LogID_CacheCommon, + LogID_QueryAlteration, + LogID_ResultAlteration, + LogID_ChunkSyncManager, + LogID_ChunkSyncManagerVerbose, + LogID_Extractor, + LogID_StaticRankManager, + LogID_ISNMonitorWatchdog, + LogID_VoxPopuliRatingLog, + LogID_KickServices, + LogID_DnsServer, + LogID_DnsServerRequest, + LogID_UpdateSecurityGroups, + LogID_SearchRepository, + LogID_SearchRepositoryCommon, + LogID_SearchRepositoryFELib, + LogID_SearchRepositoryProtocol, + LogID_SearchRepositoryLocator, + LogID_SearchRepositoryReadNode, + LogID_SearchRepositoryTest, + LogID_SearchRepositoryLog, + LogID_SearchRepositoryHttpServer, + LogID_SearchRepositoryWatchdog, + LogID_SearchRepositoryBackDoor, + LogID_SearchRepositoryMergeMgr, + LogID_SearchRepositoryMerger, + LogID_SearchRepositoryClient, + LogID_BackendQueryResult, + LogID_TLAPreLog, + LogID_APM_AlertLog, + LogID_WebAnswer, + LogID_QueryLog, + LogID_WatchDogClient, + LogID_WatchDogServer, + LogID_FcsXml, + LogID_FcsPostLog, + LogID_FcsErrorQueriesLog, + LogID_FcsLostQueriesLog, + LogID_DUI, + LogID_StufSync, + LogID_ISMerge, + LogID_WebMerge, + LogID_UserData, + LogID_Commerce, + LogID_XRank, + LogID_DocFetcher, + LogID_VoxPopuliGeneralLog, + LogID_VoxPopuliDecisionLog, + LogID_ChunkPublisher, + LogID_MediaProcessor, + LogID_KeywordExtractor, + LogID_IndexTracker, + LogID_Syslog, + LogID_Environment, + 
LogID_GenericAudit, + LogID_FcsAnswerQueryLog, + LogID_FeedsImportClient, + LogID_FeedsImportServer, + LogID_FeedsCapacityManager, + LogID_RealtimeVideo, + LogID_ThreadPoolLib, + LogID_TSFDR, + LogID_TSClient, + LogID_TSServer, + LogID_TSPipeline, + LogID_API, + LogID_API_Request, + LogID_API_Response, + LogID_ThrottleManagement, + LogID_SRCache, + LogID_DrmMirroring, + LogID_FrontDoor, + LogID_News, + LogID_FexCrash, + LogID_Webmaster, + LogID_WMSitemapSubmit, + LogID_WMLinkDataAggregator, + LogID_SqlRepository, + LogID_SqlRepositoryWatchDog, + LogID_LinkDataAggregator, + LogID_LinkDataAggregatorPreLog, + LogID_LinkDataAggregatorPostLog, + LogID_LinkDataAggregatorErrorLog, + LogID_WebmasterSitemapService, + LogID_WebmasterClog, + LogID_PhonebookResult, + LogID_PhonebookClassifier, + LogID_QueryDiagnostic, + LogID_DMExhibitCounter, + LogID_SyncAutopilotData, + LogID_AuthProxyServer, + LogID_AnswersXifBuilder, + LogID_FeedsDataProvider, + LogID_Scrounger, + LogID_Kif, + LogID_VariantConfigParser, + + // This will include all the logids added + // by partners that are using the logging API +#ifdef APSDK_CUSTOMIZED_LOGIDS_STRING +#error APSDK_CUSTOMIZED_LOGIDS_STRING has already be defined. 
+#endif +#define APSDK_CUSTOMIZED_LOGIDS_STRING(x) LogIDEx_##x +#include "LogIdsCustomized.h" +#undef APSDK_CUSTOMIZED_LOGIDS_STRING + + // This must be the last entry + LogID_Count +} LogID; + +extern char *g_LogIDNames[]; + +#ifdef DECLARE_DATA + +// This is an array of possible log entries (areas) +// This must be kept in sync with the LogID enumeration above +char *g_LogIDNames[] = +{ + "Logging", + "MsnbotFileGen", + "Crawler", + "CrawlerReceiver", + "CrawlerFetcher", + "CrawlerScheduler", + "CrawlerDownloadLog", + "Counters", + "Parser", + "IndexBuild", + "IndexGeneration", + "QueryProcessor", + "IndexMerge", + "IndexGeneral", + "IndexQueryLog", + "IndexServer", + "MCPClient", + "FileSync", + "GetAndSyncChunkManifest", + "ServiceManager", + "DynamicRank", + "Caption", + "Federation", + "APRunService", + "GetAndRunServices", + "Categorization", + "StaticRanker", + "StaticRankerBackup", + "DRClient", + "DRServer", + "DRLib", + "Common", + "Aggregator", + "Netlib", + "CatCFSpecific", + "SpamDetection", + "CrawlerUrlTracker", + "StaticRankerDetail", + "FixRegKeys", + "CreateCounter", + "CRCHashComputation", + "FEX", + "Moonshot", + "Fcslite", + "FcsliteCache", + "QualityOfPage", + "UrlSubmissionTool", + "Search", + "FSServer", + "FSClient", + "DPClient", + "Encarta", + "Music", + "LeaseVote", + "Test", + "CreateSocEvent", + "SyncFiles", + "Doodad", + "DoodadFCtoLC", + "DoodadLCtoFC", + "DoodadWatchdog", + "DOSDog", + "StaticRankerTester", + "Collector", + "Convertor", + "Uploader", + "LocalWD", + "MachineWatchdog", + "IPBlock", + "DisableServerHeader", + "IAdmin", + "GenericAP", + "PSGateway", + "HttpProxy", + "CacheInfo", + "FeedBuilder", + "FeedBuilderData", + "SysWatchdog", + "AlertService", + "FexRLog", + "FexHLog", + "FexPreLog", + "FexPostDataLog", + "PsTFTP", + "PsGatewayWatchdog", + "CacheManager", + "RSL", + "BurninTest", + "QpsWatchdog", + "ClusterBuilder", + "MSRATermExtractor", + "RealTimeCrawl", + "RealTimeBuild", + "ProfileStore", + "PerfCollector", 
+ "PerfAggregator", + "ReplicatedDataSet", + "WatchdogsWatcher", + "DigiWatchdog", + "OutBandMappingWatchdog", + "PhantomWatchdog", + "DigiProvisioning", + "MiniSwitchProvisioning", + "TLALoadBalancer", + "Shopping", + "CacheLogsProcessor", + "DatasetSync", + "DeviceManager", + "DeviceManagerIncomingWatchdog", + "DMAudit", + "DeviceManagerSQL", + "PhraseTokenExtractor", + "CQAClient", + "CQABuilder", + "CQAOQI", + "CQAUserStore", + "CQATagStore", + "CQAQAStore", + "TermExtractor", + "EMS", + "EMSStat", + "Macro", + "DRM", + "EN", + "EnWd", + "ENTest", + "Dryad", + "DrClient", + "DrNameServer", + "PN", + "Dryad", + "DryadTest", + "DryadTestJournal", + "DryadClient", + "CockpitServer", + "APWebServer", + "QueryService", + "MultiEnv", + "Election", + "APTextProtocol", + "APMPClient", + "APProxy", + "APProxyCommand", + "FexLiveLog", + "wssynclog", + "WmiLib", + "PhantomLib", + "AssetTool", + "DeviceScanner", + "DeadMachineMacScanner", + "ServerTools", + "BOOTP", + "DHCP", + "PXE", + "BOOTPServer", + "Duff", + "IdCollect", + "DryadWebServer", + "DryadLogLoader", + "DrmReplication", + "CQACloseQuestions", + "CQAPresser", + "CQAPoller", + "CQAStaticRanker", + "CQATagTextProcessor", + "CQAAlert", + "CQAAlertStore", + "FexCQA", + "CQAWatchdogs", + "MultiMedia", + "ImageFetcher", + "ImageFetcherCrawler", + "ImageFetcherThumbnailer", + "LiveSearchPane", + "PsServer", + "PsClient", + "PsAgent", + "RMAUtils", + "LCDService", + "DRV", + "DRVParser", + "CreateImage", + "CreateImageFiles", + "PhantomPowerStateWatchdog", + "WebFeedDiscovery", + "Answers", + "Newsgroup", + "TLARemoting", + "HttpClient", + "DiskTest", + "DiskTestSequential", + "DiskTestRandom", + "MemoryTest", + "WebFeedDiscoveryReader", + "WebFeedDiscoveryProcessor", + "WebFeedDiscoveryCrawlfileGenerator", + "WebFeedDiscoveryRSSFetcher", + "WebFeedDiscoveryProvider", + "FileSyncIgnoreCRC", + "FileSyncWrongCRC", + "DSSlaveSync", + "OSUpgrade", + "DeviceUpgrade", + "SyncOSImage", + "ReliableRebootService", + 
"GenericRepair", + "NetlibCorruptPacket", + "QuerySuggestion", + "ISNManager", + "ChunkLocator", + "ISAgent", + "ConfigViews", + "CDGDominantImage", + "DryadJobManager", + "DryadJobManagerWd", + "CachePropagator", + "CachePropagatorClient", + "MlToHosts", + "TidyFS", + "UrlTracker", + "DNSWatchdog", + "PsWatchdog", + "WebFeedDiscoveryPinger", + "MacroSuggestion", + "AppAlertService", + "RowStatus", + "Sputnik", + "GeneralClient", + "QueryAugmenter", + "Dictionary", + "QueryStatistics", + "QueryISNStatistics", + "AnswersRLog", + "AnswersALog", + "AnswersPreLog", + "IPTable", + "DNSService", + "FexMissingSnippetLog", + "XifBuilder", + "SharedModules", + "DryadProfiler", + "DryadProxy", + "RTIndexCoverage", + "Hardware", + "PowerstripProvision", + "PowerstripWatchdog", + "Watchdog", + "ApCommClientServer", + "LocalAdminGroup", + "DMClient", + "SvcMgrClient", + "Speller", + "FDR", + "MT_HttpServer", + "MT_Distributor", + "MT_Cache", + "MT_Translator", + "MT_ResearchSdk", + "MT_ModelServer", + "MT_DB", + "Localization", + "MachineStatusClient", + "ExpensiveQueryMonitor", + "FeedChunkCleaner", + "Hello", + "AnswersPerformanceMonitor", + "APMResult", + "BackEndMachines", + "LogService", + "UserEvent", + "FexImpressionLog", + "RMAService", + "RMAProtocol", + "CockpitWatchdog", + "DeviceCounterCollector", + "DhcpMonitor", + "DeviceCounter", + "InstallService", + "MinidumpSummary", + "DeviceValidater", + "ExpressRanker", + "CrawlerUseChunk", + "CrawlerDropLog", + "CrawlerToBuilderLog", + "ChunkBuilderArchive", + "DocumentConverter", + "ChunkBuilderCanary", + "DynamicCrawler", + "CacheClient", + "CacheCommon", + "QueryAlteration", + "ResultAlteration", + "ChunkSyncManager", + "ChunkSyncManagerVerbose", + "Extractor", + "StaticRankManager", + "ISNMonitorWatchdog", + "VoxPopuliRatingLog", + "KickServices", + "DnsServer", + "DnsServerRequest", + "UpdateSecurityGroups", + "SearchRepository", + "SearchRepositoryCommon", + "SearchRepositoryFELib", + "SearchRepositoryProtocol", + 
"SearchRepositoryLocator", + "SearchRepositoryReadNode", + "SearchRepositoryTest", + "SearchRepositoryLog", + "SearchRepositoryHttpServer", + "SearchRepositoryWatchdog", + "SearchRepositoryBackDoor", + "SearchRepositoryMergeMgr", + "SearchRepositoryMerger", + "SearchRepositoryClient", + "BackendQueryResult", + "TLAPreLog", + "APMAlertLog", + "WebAnswer", + "QueryLog", + "WatchDogClient", + "WatchDogServer", + "FcsXml", + "FcsPostLog", + "FcsErrorQueriesLog", + "FcsLostQueriesLog", + "DUIProcess", + "StufSync", + "ISMerge", + "WebMerge", + "UserData", + "Commerce", + "XRank", + "DocumentFetcherService", + "VoxPopuliGeneralLog", + "VoxPopuliDecisionLog", + "ChunkPublisher", + "MediaProcessor", + "KeywordExtractor", + "IndexTracker", + "Syslog", + "Environment", + "GenericAudit", + "FcsAnswerQueryLog", + "FeedsImportClient", + "FeedsImportServer", + "FeedsCapacityManager", + "RTVideo", + "ThreadPoolLib", + "TSFDR", + "TSClient", + "TSServer", + "TSPipeline", + "API", + "APIRequest", + "APIResponse", + "ThrottleManagement", + "SearchRepositoryCache", + "DrmMirroring", + "FrontDoor", + "News", + "FexCrash", + "Webmaster", + "WMSitemapSubmit", + "WMLinkDataAggregator", + "SqlRepository", + "SqlRepositoryWatchDog", + "LinkDataAggregator", + "LinkDataAggregatorPreLog", + "LinkDataAggregatorPostLog", + "LinkDataAggregatorErrorLog", + "WebmasterSitemapService", + "WebmasterClog", + "PhonebookResult", + "PhonebookClassifier", + "QueryDiag", + "DMExhibitCounter", + "SyncAutopilotData", + "AuthProxyServer", + "AnswersXifBuilder", + "FeedsDataProvider", + "Scrounger", + "Kif", + "VariantConfigParser", + + // This will include all the logids added + // by partners that are using the logging API +#ifdef APSDK_CUSTOMIZED_LOGIDS_STRING +#error APSDK_CUSTOMIZED_LOGIDS_STRING has already be defined. 
+#endif +#define APSDK_CUSTOMIZED_LOGIDS_STRING(x) #x +#include "LogIdsCustomized.h" +#undef APSDK_CUSTOMIZED_LOGIDS_STRING + + // Last entry must be a NULL (for sanity checking) + NULL +}; + +#endif + +//JC} // namespace apsdk +//JC + +#ifdef USING_APSDK_NAMESPACE +using namespace apsdk; +#endif diff --git a/DryadVertex/VertexHost/system/classlib/include/LogIdsCustomized.h b/DryadVertex/VertexHost/system/classlib/include/LogIdsCustomized.h new file mode 100644 index 0000000..7987a1e --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/LogIdsCustomized.h @@ -0,0 +1,91 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +//JC Check this file for unnecessary content. + +// +// This header file is designed to be included twice by the +// old autopilot logging code, once by new logging. +// +// This file meant to be used by +// non-autopilot code to define logids without having to +// modify any the autopilot SDK files (not really necessary +// with new logging). +// + +#ifndef USE_DRTRACE + +#if !defined(APSDK_CUSTOMIZED_LOGIDS_STRING) +#error This file is not meant to be included directly. APSDK_CUSTOMIZED_LOGIDS_STRING is not defined. 
+#endif + +// To add a new log id, append a line below in the form of: +// APSDK_CUSTOMIZED_LOGIDS_STRING(Foo), +// This will end up creating two things, an enum LogIDEx_Foo of enum type LogID and a string "Foo" +// as the value for g_LogIDNames[LogIDEx_Foo]. + + APSDK_CUSTOMIZED_LOGIDS_STRING(SDKBasicSampleWatchdog), + APSDK_CUSTOMIZED_LOGIDS_STRING(SDKSample), + APSDK_CUSTOMIZED_LOGIDS_STRING(AnswersMatchLog), + APSDK_CUSTOMIZED_LOGIDS_STRING(AnswersRemoteTLAPreLog), + APSDK_CUSTOMIZED_LOGIDS_STRING(VoxPopuliRatingLog), + APSDK_CUSTOMIZED_LOGIDS_STRING(DryadAudit), + APSDK_CUSTOMIZED_LOGIDS_STRING(DryadRS), + APSDK_CUSTOMIZED_LOGIDS_STRING(DryadCache), + APSDK_CUSTOMIZED_LOGIDS_STRING(DryadSimulator), + APSDK_CUSTOMIZED_LOGIDS_STRING(QUERY_PROCESSING), + APSDK_CUSTOMIZED_LOGIDS_STRING(QUERY_PARSING), + APSDK_CUSTOMIZED_LOGIDS_STRING(QUERY_REWRITING), + APSDK_CUSTOMIZED_LOGIDS_STRING(QUERY_CLASSIFICATION), + APSDK_CUSTOMIZED_LOGIDS_STRING(WatchDogClient), + APSDK_CUSTOMIZED_LOGIDS_STRING(WatchDogServer), + APSDK_CUSTOMIZED_LOGIDS_STRING(CacheSync), + APSDK_CUSTOMIZED_LOGIDS_STRING(AdServiceHttpRequestLog), + APSDK_CUSTOMIZED_LOGIDS_STRING(AdServiceHttpResponseLog), + APSDK_CUSTOMIZED_LOGIDS_STRING(AdServiceXMLResponseLog), + APSDK_CUSTOMIZED_LOGIDS_STRING(AnswersFrameworkDebug), + APSDK_CUSTOMIZED_LOGIDS_STRING(NoCodeUser), +#else + +// +// With new logging, just define the value (or, better yet, do it somewhere in your project) +// + +#define LogIDEx_SDKBasicSampleWatchdog "SDKBasicSampleWatchdog" +#define LogIDEx_SDKSample "SDKSample" +#define LogIDEx_AnswersMatchLog "AnswersMatchLog" +#define LogIDEx_AnswersRemoteTLAPreLog "AnswersRemoteTLAPreLog" +#define LogIDEx_VoxPopuliRatingLog "VoxPopuliRatingLog" +#define LogIDEx_DryadAudit "DryadAudit" +#define LogIDEx_DryadRS "DryadRS" +#define LogIDEx_DryadCache "DryadCache" +#define LogIDEx_DryadSimulator "DryadSimulator" +#define LogIDEx_QUERY_PROCESSING "QUERY_PROCESSING" +#define LogIDEx_QUERY_PARSING 
"QUERY_PARSING" +#define LogIDEx_QUERY_REWRITING "QUERY_REWRITING" +#define LogIDEx_QUERY_CLASSIFICATION "QUERY_CLASSIFICATION" +#define LogIDEx_AdServiceHttpRequestLog "AdServiceHttpRequestLog" +#define LogIDEx_AdServiceHttpResponseLog "AdServiceHttpResponseLog" +#define LogIDEx_AdServiceXMLResponseLog "AdserviceXMLResponseLog" +#define LogIDEx_AnswersFrameworkDebug "AnswersFrameworkDebug" +#define LogIDEx_NoCodeUser "NoCodeUser" + +#endif + diff --git a/DryadVertex/VertexHost/system/classlib/include/LogTagIds.h b/DryadVertex/VertexHost/system/classlib/include/LogTagIds.h new file mode 100644 index 0000000..6f2bb18 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/LogTagIds.h @@ -0,0 +1,612 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +//JC Check this file for unnecessary content. + +// +// This file lists all the known tag values for logging. +// +// This file is included twice, with DeclareTag() defined to do different things, to initialize +// different data structures with the logging code. +// +// Add any new named log tags to this file +// The IDs MUST be sequential +// Update LogTag_End if you add a new one +// +// The value in quotes is the name under which it will appear in the log file; +// e.g. 
Filename="Foo" +// + +// Use these as generic log tags parameters if you don't feel it makes sense to add a new +// log tag entry. There really is no real overhead for adding more, though, so if you're +// going to log the same data item in more than a few places, you probably want to consider +// making a log tag for it, in case we want WatchDog to be able to key off it. +DeclareTag(LogTag_Int1, "Int1", LogTagType_Int32), +DeclareTag(LogTag_Int2, "Int2", LogTagType_Int32), +DeclareTag(LogTag_Int64_1, "Int64_1", LogTagType_Int64), +DeclareTag(LogTag_Int64_2, "Int64_2", LogTagType_Int64), +DeclareTag(LogTag_String1, "String1", LogTagType_String), +DeclareTag(LogTag_String2, "String2", LogTagType_String), +DeclareTag(LogTag_WString1, "WString1", LogTagType_WideString), +DeclareTag(LogTag_WString2, "WString2", LogTagType_WideString), +DeclareTag(LogTag_UInt1, "UInt1", LogTagType_UInt32), +DeclareTag(LogTag_UInt2, "UInt2", LogTagType_UInt32), +DeclareTag(LogTag_U32X1, "Hex1", LogTagType_Hex32), +DeclareTag(LogTag_U32X2, "Hex2", LogTagType_Hex32), +DeclareTag(LogTag_U64X1, "Hex64_1", LogTagType_Hex64), +DeclareTag(LogTag_U64X2, "Hex64_2", LogTagType_Hex64), +DeclareTag(LogTag_Float1, "Float1", LogTagType_Float), +DeclareTag(LogTag_Float2, "Float2", LogTagType_Float), + +DeclareTag(LogTag_Filename, "Filename", LogTagType_String), +DeclareTag(LogTag_TargetServer, "TargetServer", LogTagType_String), +DeclareTag(LogTag_QueryString, "Q", LogTagType_String), +DeclareTag(LogTag_TraceID, "TraceID", LogTagType_String), +DeclareTag(LogTag_URL, "URL", LogTagType_String), +DeclareTag(LogTag_ErrorCode, "ErrorCode", LogTagType_UInt32), +DeclareTag(LogTag_ThreadID, "ThreadID", LogTagType_Int32), +DeclareTag(LogTag_Description, "Desc", LogTagType_String), + +// ids for ini error +DeclareTag(LogTag_Section, "Section", LogTagType_String), +DeclareTag(LogTag_Param1, "Param1", LogTagType_String), +DeclareTag(LogTag_Param2, "Param2", LogTagType_String), +DeclareTag(LogTag_Param3, "Param3", 
LogTagType_String), +DeclareTag(LogTag_Param4, "Param4", LogTagType_String), + +// ids for SQL ODBC error +DeclareTag(LogTag_SQLMessage, "SQLMessage", LogTagType_String), +DeclareTag(LogTag_SQLErrorCode, "SQLErrorCode", LogTagType_Int32), +DeclareTag(LogTag_SQLState, "SQLState", LogTagType_String), + +// ids for SQL BCP error +DeclareTag(LogTag_SQLBCPFunction, "SQLBCPFunction", LogTagType_String), +DeclareTag(LogTag_SQLBCPToTable, "SQLBCPToTable", LogTagType_String), +DeclareTag(LogTag_SQLBCPColumn, "SQLBCPColumn", LogTagType_Int32), +DeclareTag(LogTag_SQLBCPRowCountSum, "SQLBCPRowCountSum", LogTagType_Int32), +DeclareTag(LogTag_SQLBCPRowCountInc, "SQLBCPRowCountInc", LogTagType_Int32), +DeclareTag(LogTag_SQLBCPRowCountCur, "SQLBCPRowCountCur", LogTagType_Int32), + +// ids for aggregator machine status notifications +DeclareTag(LogTag_ISNName, "ISNName", LogTagType_String), +DeclareTag(LogTag_ISNService, "ISNService", LogTagType_String), +DeclareTag(LogTag_ISNFailureType, "ISNFailureType", LogTagType_String), + +// ids for Caption Generator +DeclareTag(LogTag_DocId, "DocId", LogTagType_Hex64), +DeclareTag(LogTag_ContentChunkId, "ContentChunkId", LogTagType_Int32), + +// ids for Netlib +DeclareTag(LogTag_Port, "Port", LogTagType_UInt32), +DeclareTag(LogTag_IP, "IP", LogTagType_String), +DeclareTag(LogTag_NumericIP, "NumericIP", LogTagType_Hex32), +DeclareTag(LogTag_LastError, "LastError", LogTagType_UInt32), + +// ids for Fcslite +DeclareTag(LogTag_PacketId, "PacketId", LogTagType_UInt32), +DeclareTag(LogTag_ConfigName, "ConfigName", LogTagType_String), +DeclareTag(LogTag_ResultBase, "ResultBase", LogTagType_UInt32), +DeclareTag(LogTag_ResultsCount, "ResultsCount", LogTagType_UInt32), +DeclareTag(LogTag_MaxResultsPerHost, "MaxResultsPerHost", LogTagType_UInt32), +DeclareTag(LogTag_FcsOptions, "FcsOptions", LogTagType_Hex64), +DeclareTag(LogTag_AggregatorOptions, "AggregatorOptions", LogTagType_Hex64), +DeclareTag(LogTag_QueryOptions, "QueryOptions", 
LogTagType_Hex64), +DeclareTag(LogTag_CaptionOptions, "CaptionOptions", LogTagType_Hex64), +DeclareTag(LogTag_NextTierToQuery, "NextTierToQuery", LogTagType_UInt32), +DeclareTag(LogTag_MaxTiersInIndex, "MaxTiersInIndex", LogTagType_UInt32), +DeclareTag(LogTag_RowID, "RowID", LogTagType_UInt32), + +// These should have been in the generic section but are here to be in the proper range +DeclareTag(LogTag_UInt641, "UInt641", LogTagType_UInt64), +DeclareTag(LogTag_UInt642, "UInt642", LogTagType_UInt64), + +#if defined(_M_AMD64) +DeclareTag(LogTag_Sizet1, "Size_t1", LogTagType_UInt64), +DeclareTag(LogTag_Sizet2, "Size_t2", LogTagType_UInt64), +#else +DeclareTag(LogTag_Sizet1, "Size_t1", LogTagType_UInt32), +DeclareTag(LogTag_Sizet2, "Size_t2", LogTagType_UInt32), +#endif + +DeclareTag(LogTag_ActiveMCP, "Active", LogTagType_Int32), +DeclareTag(LogTag_VoteCount, "VoteCount", LogTagType_Int32), + +#if defined(_M_AMD64) +DeclareTag(LogTag_Ptr1, "Ptr1", LogTagType_Hex64), +DeclareTag(LogTag_Ptr2, "Ptr2", LogTagType_Hex64), +#else +DeclareTag(LogTag_Ptr1, "Ptr1", LogTagType_Hex32), +DeclareTag(LogTag_Ptr2, "Ptr2", LogTagType_Hex32), +#endif + +DeclareTag(LogTag_Latitude, "Latitude", LogTagType_Float), +DeclareTag(LogTag_Longitude, "Longitude", LogTagType_Float), + +DeclareTag(LogTag_VarString1, "VarString1", LogTagType_VarString), +DeclareTag(LogTag_VarString2, "VarString2", LogTagType_VarString), +DeclareTag(LogTag_VarWString1, "VarWString1", LogTagType_VarWideString), +DeclareTag(LogTag_VarWString2, "VarWString2", LogTagType_VarWideString), + +// ids for service manager +DeclareTag(LogTag_PID, "ProcessID", LogTagType_UInt32), +DeclareTag(LogTag_PPID, "ParentProcessID", LogTagType_UInt32), + +// ids for fex +DeclareTag(LogTag_FEXLatency, "FEXLatency", LogTagType_Float), +DeclareTag(LogTag_TLALatency, "TLALatency", LogTagType_Float), +DeclareTag(LogTag_CDGLatency, "CDGLatency", LogTagType_Float), +DeclareTag(LogTag_ISNLatency, "ISNLatency", LogTagType_Float), 
+DeclareTag(LogTag_FcsStatus, "FcsStatus", LogTagType_Hex64), +DeclareTag(LogTag_MachineNeeded, "MachineNeeded", LogTagType_UInt32), +DeclareTag(LogTag_MachineComplete, "MachineComplete", LogTagType_UInt32), +DeclareTag(LogTag_MachineTimedOut, "MachineTimedOut", LogTagType_UInt32), +DeclareTag(LogTag_Federator, "Federator", LogTagType_String), +DeclareTag(LogTag_RawUrl, "RawUrl", LogTagType_VarString), +DeclareTag(LogTag_StatusCode, "StatusCode", LogTagType_UInt32), + +// ISN Total Time +DeclareTag(LogTag_ISNMaxTotalMachine, "ISNMaxTotalMachine", LogTagType_String), +DeclareTag(LogTag_ISNMaxTotalLatency, "ISNMaxTotalLatency", LogTagType_Float), +DeclareTag(LogTag_ISNMaxTotalStartTime, "ISNMaxTotalStartTime", LogTagType_Time), +DeclareTag(LogTag_ISNAvgTotalLatency, "ISNAvgTotalLatency", LogTagType_Float), + +// ISN Queue Time +DeclareTag(LogTag_ISNMaxQueueMachine, "ISNMaxQueueMachine", LogTagType_String), +DeclareTag(LogTag_ISNMaxQueueLatency, "ISNMaxQueueLatency", LogTagType_Float), +DeclareTag(LogTag_ISNMaxQueueStartTime, "ISNMaxQueueStartTime", LogTagType_Time), +DeclareTag(LogTag_ISNAvgQueueLatency, "ISNAvgQueueLatency", LogTagType_Float), + +// ISN Ranker Time +DeclareTag(LogTag_ISNMaxRankerMachine, "ISNMaxRankerMachine", LogTagType_String), +DeclareTag(LogTag_ISNMaxRankerLatency, "ISNMaxRankerLatency", LogTagType_Float), +DeclareTag(LogTag_ISNMaxRankerStartTime, "ISNMaxRankerStartTime", LogTagType_Time), +DeclareTag(LogTag_ISNAvgRankerLatency, "ISNAvgRankerLatency", LogTagType_Float), + +// ISN Pages Scored Time +DeclareTag(LogTag_ISNMaxPagesScoredMachine, "ISNMaxPagesScoredMachine", LogTagType_String), +DeclareTag(LogTag_ISNMaxPagesScored, "ISNMaxPagesScored", LogTagType_UInt64), +DeclareTag(LogTag_ISNMaxPagesScoredStartTime, "ISNMaxPagesScoredStartTime", LogTagType_Time), +DeclareTag(LogTag_ISNAvgPagesScored, "ISNAvgPagesScored", LogTagType_Float), + +// ISN Pages Matched Time +DeclareTag(LogTag_ISNMaxPagesMatchedMachine, "ISNMaxPgsMatchedMachine", 
LogTagType_String), +DeclareTag(LogTag_ISNMaxPagesMatched, "ISNMaxPgsMatched", LogTagType_UInt64), +DeclareTag(LogTag_ISNMaxPagesMatchedStartTime, "ISNMaxPgsMatchedStartTime", LogTagType_Time), +DeclareTag(LogTag_ISNAvgPagesMatched, "ISNAvgPgsMatched", LogTagType_Float), + +// ISN Pages In Corpus Time +DeclareTag(LogTag_ISNMaxPagesInCorpusMachine, "ISNMaxPgsInCorpusMachine", LogTagType_String), +DeclareTag(LogTag_ISNMaxPagesInCorpus, "ISNMaxPgsInCorpus", LogTagType_UInt64), +DeclareTag(LogTag_ISNMaxPagesInCorpusStartTime, "ISNMaxPgsInCorpusStartTime", LogTagType_Time), +DeclareTag(LogTag_ISNAvgPagesInCorpus, "ISNAvgPagesInCorpus", LogTagType_Float), + +// ISN QueueLength Length +DeclareTag(LogTag_ISNMaxQueueLengthMachine, "ISNMaxQueueLengthMachine", LogTagType_String), +DeclareTag(LogTag_ISNMaxQueueLength, "ISNMaxQueueLength", LogTagType_Float), +DeclareTag(LogTag_ISNMaxQueueLengthStartTime, "ISNMaxQueueLengthStartTime", LogTagType_Time), +DeclareTag(LogTag_ISNAvgQueueLength, "ISNAvgQueueLength", LogTagType_Float), + +// Reissue +DeclareTag(LogTag_ReissueLatency, "ReissueLatency", LogTagType_Float), +DeclareTag(LogTag_ReissueCount, "ReissueCount", LogTagType_UInt32), +DeclareTag(LogTag_MaxCDLatency, "MaxCDLatency", LogTagType_Float), +DeclareTag(LogTag_MachineAnswered, "MachineAnswered", LogTagType_UInt32), +DeclareTag(LogTag_MachineQueried, "MachineQueried", LogTagType_UInt32), +DeclareTag(LogTag_MaxTierQueried, "MaxTierQueried", LogTagType_UInt32), + +// Version +DeclareTag(LogTag_CDVersion, "CDVersion", LogTagType_UInt32), +DeclareTag(LogTag_CDVersionCount, "CDVersionCount", LogTagType_UInt32), +DeclareTag(LogTag_IndexVersion, "IndexVersion", LogTagType_UInt64), +DeclareTag(LogTag_RankVersion, "RankVersion", LogTagType_UInt64), + +// Watchdog +DeclareTag(LogTag_Property, "Property", LogTagType_String), +DeclareTag(LogTag_Level, "Level", LogTagType_String), +DeclareTag(LogTag_Machinename, "machinename", LogTagType_String), + +// Caching 
+DeclareTag(LogTag_HitCount, "HitCount", LogTagType_UInt64), +DeclareTag(LogTag_TotalEstimatedMatches, "TotalEstimatedMatches", LogTagType_UInt64), + +// fex filter tags +DeclareTag(LogTag_Mkt, "Mkt", LogTagType_String), +DeclareTag(LogTag_Flight, "Flight", LogTagType_String), +DeclareTag(LogTag_Brand, "Brand", LogTagType_String), +DeclareTag(LogTag_VariantID, "Variant", LogTagType_UInt32), + +DeclareTag(LogTag_RequestTime, "RequestTime", LogTagType_String), +DeclareTag(LogTag_Method, "Method", LogTagType_String), +DeclareTag(LogTag_Host, "Host", LogTagType_String), +DeclareTag(LogTag_BytesRecv, "BytesRecv", LogTagType_UInt32), +DeclareTag(LogTag_BytesSent, "BytesSent", LogTagType_UInt32), +DeclareTag(LogTag_BlockedStatus, "BlockedStatus", LogTagType_UInt32), +DeclareTag(LogTag_RequestCost, "RequestCost", LogTagType_UInt32), +DeclareTag(LogTag_CacheHeader, "CacheHeader", LogTagType_String), +DeclareTag(LogTag_Latency, "Latency", LogTagType_Float), +DeclareTag(LogTag_UserAgent, "UserAgent", LogTagType_String), +DeclareTag(LogTag_Referer, "Referer", LogTagType_String), +DeclareTag(LogTag_Cookies, "Cookies", LogTagType_String), +DeclareTag(LogTag_GetHeaderXUpSubno, "GetHeaderXUpSubno", LogTagType_String), +DeclareTag(LogTag_GetHeaderXUpUpLink, "GetHeaderXUpUpLink", LogTagType_String), + +DeclareTag(LogTag_RSLMsg, "Msg", LogTagType_String), +DeclareTag(LogTag_RSLMsgLen, "MsgLen", LogTagType_UInt32), +DeclareTag(LogTag_Offset, "Offset", LogTagType_UInt64), +DeclareTag(LogTag_RSLState, "State", LogTagType_UInt32), +DeclareTag(LogTag_RSLMemberId, "MemberId", LogTagType_UInt64), +DeclareTag(LogTag_RSLBallotId, "BallotId", LogTagType_UInt32), +DeclareTag(LogTag_RSLDecree, "Decree", LogTagType_UInt64), +DeclareTag(LogTag_RSLBallot, "Ballot", LogTagType_String), +DeclareTag(LogTag_RSLMsgVersion, "MsgVersion", LogTagType_UInt32), + +DeclareTag(LogTag_OldIndexVersion, "OldIndexVersion", LogTagType_UInt64), +DeclareTag(LogTag_NewIndexVersion, "NewIndexVersion", 
LogTagType_UInt64), + +// Clusterbuilder +DeclareTag(LogTag_NumArticlesLoaded, "NumArticlesLoaded", LogTagType_UInt32), +DeclareTag(LogTag_NumArticlesRefs, "NumArticleRefs", LogTagType_UInt32), +DeclareTag(LogTag_NumArticles, "NumArticles", LogTagType_UInt32), +DeclareTag(LogTag_NumClusterRefs, "NumClusterRefs", LogTagType_UInt32), +DeclareTag(LogTag_NumClusters, "NumClusters", LogTagType_UInt32), +DeclareTag(LogTag_NumLanguageModelRowRefs, "NumLanguageModelRowRefs", LogTagType_UInt32), +DeclareTag(LogTag_NumLanguageModelRows, "NumLanguageModelRows", LogTagType_UInt32), +DeclareTag(LogTag_NumChunksExpired, "NumChunksExpired", LogTagType_UInt32), +DeclareTag(LogTag_NumClustersExpired, "NumClustersExpired", LogTagType_UInt32), +DeclareTag(LogTag_NumArticlesExpired, "NumArticlesExpired", LogTagType_UInt32), +DeclareTag(LogTag_NumStopTokensLoaded, "NumStopTokensLoaded", LogTagType_UInt32), +DeclareTag(LogTag_NumNoClusterTokensLoaded, "NumNoClusterTokensLoaded", LogTagType_UInt32), + +DeclareTag(LogTag_PrefixString, "Prefix", LogTagType_String), +DeclareTag(LogTag_PrefixOptions, "PrefixOptions", LogTagType_Hex64), + +// ids for fex (latencies in ms) +DeclareTag(LogTag_FederationLatency, "FederationLatency", LogTagType_Float), +DeclareTag(LogTag_HttpSysLatency, "HttpSysLatency", LogTagType_Float), +DeclareTag(LogTag_TotalLatency, "TotalLatency", LogTagType_Float), + +// Dryad +DeclareTag(LogTag_Cluster, "Cluster", LogTagType_String), +DeclareTag(LogTag_Namespace, "Namespace", LogTagType_String), +DeclareTag(LogTag_NodeName, "Node", LogTagType_String), +DeclareTag(LogTag_ServiceType, "SvcType", LogTagType_String), + +DeclareTag(LogTag_Command, "Command", LogTagType_String), +DeclareTag(LogTag_Service, "Service", LogTagType_String), + +DeclareTag(LogTag_OID, "OID", LogTagType_Hex64), +DeclareTag(LogTag_EID, "ExtentID", LogTagType_String), +DeclareTag(LogTag_RefCount, "RefCount", LogTagType_UInt32), + +DeclareTag(LogTag_RemoteMachine, "RemoteMachine", LogTagType_String ), + 
+// Other (non-cosmos) tags +DeclareTag(LogTag_APProxyCommandID, "ID", LogTagType_UInt64), + +//Logging for Fex C Logs +DeclareTag(LogTag_QLocation, "QLoc", LogTagType_String), +DeclareTag(LogTag_QLatitude, "QLat", LogTagType_Float), +DeclareTag(LogTag_QLongitude, "QLong", LogTagType_Float), + +//Windows Live Searchpane Action tracing +DeclareTag(LogTag_P4_ActionID, "ActionID", LogTagType_UInt32), +DeclareTag(LogTag_P4_SessionID, "SessionID", LogTagType_String), +DeclareTag(LogTag_P4_ActionTime, "ActionTime", LogTagType_String), +DeclareTag(LogTag_P4_ResultType, "ResultType", LogTagType_UInt32), +DeclareTag(LogTag_P4_Market, "Market", LogTagType_String), +DeclareTag(LogTag_P4_SearchSource, "SearchSource", LogTagType_UInt32), +DeclareTag(LogTag_P4_TargetPage, "TargetPage", LogTagType_UInt32), +DeclareTag(LogTag_P4_ExceptionID, "ExceptionID", LogTagType_UInt64), +DeclareTag(LogTag_P4_SearchTerm, "SearchTerm", LogTagType_String), +DeclareTag(LogTag_P4_ActiveURL, "ActiveURL", LogTagType_String), +DeclareTag(LogTag_P4_ExceptionMessage, "ExceptionMessage", LogTagType_String), +DeclareTag(LogTag_P4_QuickSearch, "QuickSearch", LogTagType_UInt32), +DeclareTag(LogTag_P4_AnswerType, "AnswerType", LogTagType_String), +DeclareTag(LogTag_P4_TutorialMode, "TutorialMode", LogTagType_UInt32), + +//AppID Tracing for SOAP API +DeclareTag(LogTag_AppID, "AppID", LogTagType_String), + +//Logging of Reverse IP Loc, Lat and Long for Fex C Logs +DeclareTag(LogTag_IPLocation, "IPLoc", LogTagType_String), +DeclareTag(LogTag_IPLatitude, "IPLat", LogTagType_Float), +DeclareTag(LogTag_IPLongitude, "IPLong", LogTagType_Float), + +//adding one more latency type for logging +DeclareTag(LogTag_DPSLatency, "DPSLatency", LogTagType_Float), +DeclareTag(LogTag_RawQuery, "RawQuery", LogTagType_String), + +// Answers stuff +DeclareTag(LogTag_Market, "Market", LogTagType_String), +DeclareTag(LogTag_Environment, "Environment", LogTagType_String), +DeclareTag(LogTag_QueryTokenID, "QueryTokenID", 
LogTagType_UInt64), +DeclareTag(LogTag_GrammarTokenID, "GrammarTokenID", LogTagType_UInt64), +DeclareTag(LogTag_AnswerRequest, "AnsRequest", LogTagType_String), +DeclareTag(LogTag_AnswerResponse, "AnsResponse", LogTagType_String), +DeclareTag(LogTag_AnswerServiceStatus, "AnsServiceStatus", LogTagType_String), +DeclareTag(LogTag_AnswerLogVersion, "LogVersion", LogTagType_UInt32), +DeclareTag(LogTag_MatchDuration, "MatchDuration", LogTagType_UInt32), +DeclareTag(LogTag_FulfillDuration, "FulfillDuration", LogTagType_UInt32), + +// Latencies for C-logs +DeclareTag(LogTag_TotalLatency_C, "TotLat", LogTagType_Float), +DeclareTag(LogTag_DPSLatency_C, "DPSLat", LogTagType_Float), +DeclareTag(LogTag_HttpSysLatency_C, "HttpSysLat", LogTagType_Float), +DeclareTag(LogTag_FederationLatency_C, "FedLat", LogTagType_Float), +DeclareTag(LogTag_FEXLatency_C, "FEXLat", LogTagType_Float), +DeclareTag(LogTag_Latency_C, "Latency", LogTagType_Float), + +//General CLogging +DeclareTag(LogTag_CLogVersion, "CLogVersion", LogTagType_String), +DeclareTag(LogTag_FEXBuild, "FexBuild", LogTagType_String), +DeclareTag(LogTag_DataCenter, "DataCenter", LogTagType_String), + +// Speller request parameters +DeclareTag(LogTag_SpellerTimeout, "Timeout", LogTagType_UInt32), +DeclareTag(LogTag_SpellerTargetCorrection, "TargetCorrection", LogTagType_String), +DeclareTag(LogTag_SpellerConfig, "Config", LogTagType_String), +DeclareTag(LogTag_SpellerOptions, "Options", LogTagType_Hex64), + +// Tag to identify spilling in indexserve +DeclareTag(LogTag_SpillStatus, "SpillStatus", LogTagType_UInt32), + +// Tags for TLA Query Log +DeclareTag(LogTag_ISNSourceEnvironment, "ISNSourceEnvironment", LogTagType_String), +DeclareTag(LogTag_CDGSourceEnvironment, "CDGSourceEnvironment", LogTagType_String), +DeclareTag(LogTag_LocalEnvironment, "LocalEnvironment", LogTagType_String), +DeclareTag(LogTag_NumDocuments, "NumDocuments", LogTagType_UInt32), +DeclareTag(LogTag_NumDocumentsActual, "NumDocumentsActual", 
LogTagType_UInt32), +DeclareTag(LogTag_Tier1ISNLatency, "Tier1ISNLatency", LogTagType_Float), +DeclareTag(LogTag_Tier2ISNLatency, "Tier2ISNLatency", LogTagType_Float), + +// tag for clusterbuilder +DeclareTag(LogTag_NumNoAutotermTokensLoaded, "NumNoAutotermTokens", LogTagType_UInt32), +DeclareTag(LogTag_NumEntityGazetteersLoaded, "NumEntityGazetteersLoaded", LogTagType_UInt32), +DeclareTag(LogTag_NumPeakTermsLoaded, "NumPeakTerms", LogTagType_UInt32), +DeclareTag(LogTag_NumHighCTRTermsLoaded, "NumHighCTRTerms", LogTagType_UInt32), +DeclareTag(LogTag_NumDailyPeakTermsLoaded, "NumDailyHighCTRTerms", LogTagType_UInt32), + +// Answers Performance Monitor Results +DeclareTag(LogTag_APM_ExecutionTime, "ExecutionTime", LogTagType_String ), +DeclareTag(LogTag_APM_TestNumber, "TestNumber", LogTagType_String ), +DeclareTag(LogTag_APM_ExpectedAnswer, "ExpectedAnswer", LogTagType_String ), +DeclareTag(LogTag_APM_ExpectedScenario, "ExpectedScenario", LogTagType_String ), +DeclareTag(LogTag_APM_Environment, "Environment", LogTagType_String ), +DeclareTag(LogTag_APM_HostName, "HostName", LogTagType_String ), +DeclareTag(LogTag_APM_Port, "Port", LogTagType_String ), +DeclareTag(LogTag_APM_APMQueryResult, "APMQueryResult", LogTagType_String ), +DeclareTag(LogTag_APM_AnswersResponseCode, "AnswersResponseCode", LogTagType_String ), +DeclareTag(LogTag_APM_Latency, "Latency", LogTagType_UInt32 ), +DeclareTag(LogTag_APM_ActualAnswer, "ActualAnswer", LogTagType_String ), +DeclareTag(LogTag_APM_ProductionID, "ProductionID", LogTagType_String ), +DeclareTag(LogTag_APM_GrammarID, "GrammarID", LogTagType_String ), +DeclareTag(LogTag_APM_DataSet, "DataSet", LogTagType_String ), +DeclareTag(LogTag_APM_DataSetVersion, "DataSetVersion", LogTagType_String ), +DeclareTag(LogTag_APM_ActualScenario, "ActualScenario", LogTagType_String ), +DeclareTag(LogTag_APM_OutputVersion, "Version", LogTagType_UInt32 ), +DeclareTag(LogTag_APM_AlertStatus, "AlertStatus", LogTagType_UInt32 ), 
+DeclareTag(LogTag_APM_SuccessRate, "SuccessRate", LogTagType_UInt32 ), +// APM v2 XML blob is all one string. +DeclareTag(LogTag_APM_TestResultXML, "TestResultXML", LogTagType_String ), + +// Speller Debugging Tags +DeclareTag(LogTag_Speller_Query, "Query", LogTagType_String), +DeclareTag(LogTag_Speller_Status, "PostWebSpellerStatus", LogTagType_UInt32), +DeclareTag(LogTag_Speller_QueryType, "QueryType", LogTagType_UInt32), +DeclareTag(LogTag_Speller_Flags, "SpellerFlags", LogTagType_UInt32), +DeclareTag(LogTag_Speller_NumSuggestions, "NumSuggestions", LogTagType_UInt32), +DeclareTag(LogTag_Speller_Suggestion, "Suggestion", LogTagType_String), +DeclareTag(LogTag_Speller_PPLatency, "PreprocessLatency", LogTagType_Float), +DeclareTag(LogTag_Speller_URLDetectLatency, "URLDetectLatency", LogTagType_Float), +DeclareTag(LogTag_Speller_CandGenLatency, "CandGenLatency", LogTagType_Float), +DeclareTag(LogTag_Speller_ViterbiLatency, "ViterbiLatency", LogTagType_Float), +DeclareTag(LogTag_Speller_LMLatency, "LMLatency", LogTagType_Float), +DeclareTag(LogTag_Speller_URLCheckLatency, "URLCheckLatency", LogTagType_Float), +DeclareTag(LogTag_Speller_ConfLatency, "ConfLatency", LogTagType_Float), +DeclareTag(LogTag_Speller_PreWebLatency, "PreWebSpellerLatency", LogTagType_Float), +DeclareTag(LogTag_Speller_PostWebLatency, "PostWebSpellerLatency", LogTagType_Float), + +// More Tags for TLA Query Log: request info from users +DeclareTag(LogTag_RequestIP, "RequestIP", LogTagType_String), +DeclareTag(LogTag_RequestMethod, "RequestMethod", LogTagType_String), +DeclareTag(LogTag_RequestDomain, "RequestDomain", LogTagType_String), +DeclareTag(LogTag_RequestUrl, "RequestUrl", LogTagType_String), +DeclareTag(LogTag_RequestReferer, "RequestReferer", LogTagType_String), +DeclareTag(LogTag_RequestAppID, "RequestAppID", LogTagType_String), +DeclareTag(LogTag_RequestPort, "RequestPort", LogTagType_UInt32), + +// Log Service +DeclareTag(LogTag_IG, "IG", LogTagType_String), + +// Tags for 
TLAQueryStatJoin Log +DeclareTag(LogTag_TLATraceID, "TLATraceID", LogTagType_String), +DeclareTag(LogTag_CacheTraceID, "CacheTraceID", LogTagType_String), +// tag for cache service +DeclareTag(LogTag_CacheStatus, "CacheStatus", LogTagType_Hex64), +DeclareTag(LogTag_CacheID, "CacheID", LogTagType_String), + +// Vox Populi (answers) service tags +DeclareTag(LogTag_VoxPopuli_Rank, "Rank", LogTagType_Float), +DeclareTag(LogTag_VoxPopuli_Pos, "Pos", LogTagType_UInt32), + +DeclareTag(LogTag_VarQueryString, "Q", LogTagType_VarString), +DeclareTag(LogTag_VarMarket, "Market", LogTagType_VarString), +DeclareTag(LogTag_AnswerService, "Service", LogTagType_VarString), +DeclareTag(LogTag_AnswerScenario, "Scenario", LogTagType_VarString), +DeclareTag(LogTag_AnswerFeed, "Feed", LogTagType_VarString), +DeclareTag(LogTag_AnswerEffectiveConstraint,"AnswerEffectiveConstraint",LogTagType_VarString), +DeclareTag(LogTag_EffectiveConstraint, "EffectiveConstraint", LogTagType_VarString), +DeclareTag(LogTag_AnswerEffectiveFlight, "AnswerEffectiveFlight", LogTagType_VarString), +DeclareTag(LogTag_EffectiveFlight, "EffectiveFlight", LogTagType_VarString), +DeclareTag(LogTag_Time1, "Time1", LogTagType_Time), +DeclareTag(LogTag_Time2, "Time2", LogTagType_Time), + +// TLA Query Log Tags +DeclareTag(LogTag_RemoteFcsLatency, "RemoteFcsLatency", LogTagType_Float), +DeclareTag(LogTag_RemoteFcsNetworkLatency, "RemoteFcsNetworkLatency", LogTagType_Float), + +// tag for tiers queried within request +DeclareTag(LogTag_TiersQueried, "TiersQueried", LogTagType_String), + +// IDs for alterations +DeclareTag(LogTag_AlterationName, "AlterationName", LogTagType_String), + +DeclareTag(LogTag_Mean, "Mean", LogTagType_Float), +DeclareTag(LogTag_WeightedMean, "WeightedMean", LogTagType_Float), +DeclareTag(LogTag_StdDev, "StdDev", LogTagType_Float), +DeclareTag(LogTag_Samples, "Samples", LogTagType_Int32), +DeclareTag(LogTag_Metric, "Metric", LogTagType_String), +DeclareTag(LogTag_MetricParameters, 
"MetricParameters", LogTagType_String), + +DeclareTag(LogTag_VarFlight, "Flight", LogTagType_VarString), + +DeclareTag(LogTag_FcsResultsNumber, "FcsResultsNumber", LogTagType_UInt64), +DeclareTag(LogTag_QueryResults, "QueryResults", LogTagType_String), +DeclareTag(LogTag_DuplicatedQueryType, "DuplicatedQueryType", LogTagType_UInt32), + +// +// Additional speller latencies for experimentation +// +DeclareTag(LogTag_Speller_RemoveDupCandLat, "RemoveDupCandLat", LogTagType_Float), +DeclareTag(LogTag_Speller_PopCandLat, "PopCandLat", LogTagType_Float), +DeclareTag(LogTag_Speller_FiltCandLat, "FiltCandLat", LogTagType_Float), +DeclareTag(LogTag_Speller_BldLatticeLat, "BldLatticeLat", LogTagType_Float), +DeclareTag(LogTag_Speller_CalcBestPathLat, "CalcBestPathLat", LogTagType_Float), +DeclareTag(LogTag_Speller_PrnLatticeLat, "PrnLatticeLat", LogTagType_Float), +DeclareTag(LogTag_Speller_GetTopPathsLat, "GenTopPathsLat", LogTagType_Float), +DeclareTag(LogTag_Speller_ScorePathLat, "ScorePathLat", LogTagType_Float), +DeclareTag(LogTag_Speller_TokenizeLat, "TokenizeLat", LogTagType_Float), +DeclareTag(LogTag_Speller_FcsParseLat, "FcsParseLat", LogTagType_Float), +DeclareTag(LogTag_Speller_GenCandLat, "GenCandLat", LogTagType_Float), +DeclareTag(LogTag_Speller_PreCheckLat, "PreCheckLat", LogTagType_Float), +DeclareTag(LogTag_Speller_JaJpSpellCheckLat, "JaJpSpellCheckLat", LogTagType_Float), +DeclareTag(LogTag_Speller_PreSpellCheckLat, "PreSpellCheckLat", LogTagType_Float), +DeclareTag(LogTag_Speller_SpellCheckLat, "SpellCheckLat", LogTagType_Float), +DeclareTag(LogTag_Speller_PostSpellCheckLat, "PostSpellCheckLat", LogTagType_Float), +DeclareTag(LogTag_Speller_PhraseCheckLat1, "PhraseCheckLat1", LogTagType_Float), +DeclareTag(LogTag_Speller_PhraseCheckLat2, "PhraseCheckLat2", LogTagType_Float), +DeclareTag(LogTag_Speller_PhraseCheckLat3, "PhraseCheckLat3", LogTagType_Float), +DeclareTag(LogTag_Speller_PhraseCheckLat4, "PhraseCheckLat4", LogTagType_Float), 
+DeclareTag(LogTag_Speller_PhraseCheckLat5, "PhraseCheckLat5", LogTagType_Float), +DeclareTag(LogTag_Speller_PhraseCheckLat6, "PhraseCheckLat6", LogTagType_Float), + +// +// Tag ids for Commerce Answer Service (CAS) metadata logging +// +DeclareTag(LogTag_CAS_ProductID, "ProductID", LogTagType_VarString), +DeclareTag(LogTag_CAS_VendID, "VendorID", LogTagType_VarString), +DeclareTag(LogTag_CAS_Shingle, "Shingle", LogTagType_VarString), +DeclareTag(LogTag_CAS_MSNShopItemID, "MSNShopItemID", LogTagType_VarString), +DeclareTag(LogTag_CAS_TraceID, "TraceID", LogTagType_VarString), +DeclareTag(LogTag_CAS_PTraceID, "PTraceID", LogTagType_String), +DeclareTag(LogTag_CAS_RawQuery, "RawQuery", LogTagType_VarString), +DeclareTag(LogTag_CAS_Metadata, "CASMetadata", LogTagType_String), +DeclareTag(LogTag_CAS_CommDocType, "CommDocType", LogTagType_String), +DeclareTag(LogTag_CAS_DpID, "DpID", LogTagType_VarString), +DeclareTag(LogTag_CAS_Category, "Category", LogTagType_VarString), +DeclareTag(LogTag_CAS_MCATId, "MCATID", LogTagType_VarString), +DeclareTag(LogTag_CAS_Scenario, "Scenario", LogTagType_String), +DeclareTag(LogTag_CAS_Brand, "Brand", LogTagType_VarString), +DeclareTag(LogTag_CAS_ProductName, "ProductName", LogTagType_VarString), +DeclareTag(LogTag_CAS_ProductLine, "ProductLine", LogTagType_VarString), +DeclareTag(LogTag_CAS_ReviewRating, "ReviewRating", LogTagType_VarString), +DeclareTag(LogTag_CAS_Title, "Title", LogTagType_VarString), +DeclareTag(LogTag_CAS_ResultPosition, "ResultPosition", LogTagType_Int32), +DeclareTag(LogTag_CAS_FeatureName, "FeatureName", LogTagType_VarString), +DeclareTag(LogTag_CAS_CRFUsedToMatch, "CRFUsedToMatch", LogTagType_Int32), +DeclareTag(LogTag_CAS_CRFConfidenceLevel, "CRFConfidenceLevel", LogTagType_Float), +DeclareTag(LogTag_CAS_QDRConfidenceLevel, "QDRConfidenceLevel", LogTagType_Float), +DeclareTag(LogTag_CAS_CRFAlteredRawQuery, "CRFAlteredRawQuery", LogTagType_String), +DeclareTag(LogTag_CAS_CRFLabeledQuery, "CRFLabeledQuery", 
LogTagType_String), + +// WebPM Performance Monitor Results +DeclareTag(LogTag_WebPM_CurrentEnvironment, "CurrentEnvironment", LogTagType_String ), +DeclareTag(LogTag_WebPM_TargetEnvironment, "TargetEnvironment", LogTagType_String ), +DeclareTag(LogTag_WebPM_TargetEnvironmentName,"TargetEnvironmentName",LogTagType_String ), +DeclareTag(LogTag_WebPM_TargetPort, "Port", LogTagType_UInt32 ), +DeclareTag(LogTag_WebPM_RAWQuery, "RAWQuery", LogTagType_String ), +DeclareTag(LogTag_WebPM_QueryResult, "QueryResult", LogTagType_String ), +DeclareTag(LogTag_WebPM_ResponseStatus, "ResponseStatus", LogTagType_UInt64 ), +DeclareTag(LogTag_WebPM_HttpResponseCode, "HttpResponseCode", LogTagType_String ), +DeclareTag(LogTag_WebPM_E2ELatency, "E2ELatency", LogTagType_UInt32 ), +DeclareTag(LogTag_WebPM_FcsCacheFindLatency, "FcsCacheFindLatency", LogTagType_Float ), +DeclareTag(LogTag_WebPM_FcsISNLatency, "FcsISNLatency", LogTagType_Float ), +DeclareTag(LogTag_WebPM_FcsCDGLatency, "FcsCDGLatency", LogTagType_Float ), +DeclareTag(LogTag_WebPM_FcsTotalLatency, "FcsTotalLatency", LogTagType_Float ), +DeclareTag(LogTag_WebPM_ProductionID, "ProductionID", LogTagType_UInt32 ), +DeclareTag(LogTag_WebPM_FDRSourceISNEnv, "FDRSourceISNEnv", LogTagType_String ), +DeclareTag(LogTag_WebPM_FDRSourceCDGEnv, "FDRSourceCDGEnv", LogTagType_String ), +DeclareTag(LogTag_WebPM_MaxTierQueried, "MaxTierQueried", LogTagType_UInt32 ), +DeclareTag(LogTag_WebPM_ISNResponseFoundInCache, "ISNResponseFoundInCache", LogTagType_String ), +DeclareTag(LogTag_WebPM_CDGResponseFoundInCache, "CDGResponseFoundInCache", LogTagType_String ), +DeclareTag(LogTag_WebPM_OutputVersion, "Version", LogTagType_UInt32 ), +DeclareTag(LogTag_WebPM_AlertStatus, "AlertStatus", LogTagType_UInt32 ), +DeclareTag(LogTag_WebPM_FcsEstimatedMatches, "FcsEstimatedMatches", LogTagType_UInt64 ), +DeclareTag(LogTag_WebPM_FcsNumberofResults, "FcsNumberofResults", LogTagType_UInt64 ), +DeclareTag(LogTag_WebPM_ExecutionTime, "ExecutionTime", 
LogTagType_String ), +DeclareTag(LogTag_WebPM_TestNumber, "TestNumber", LogTagType_UInt32 ), +DeclareTag(LogTag_WebPM_RemoteEnvironmentName,"RemoteEnvironmentName",LogTagType_String ), +DeclareTag(LogTag_WebPM_NumberOfHops, "NumberOfHops", LogTagType_UInt32 ), + +// Moonshot information +DeclareTag(LogTag_Docs, "DODR", LogTagType_String), + +DeclareTag(LogTag_FEXAnswerServiceName, "AnswerServiceName", LogTagType_String), +DeclareTag(LogTag_FEXAnswerScenario, "AnswerScenario", LogTagType_String), +DeclareTag(LogTag_FEXAnswerUXDisplayHint, "AnswerUXDisplayHint", LogTagType_String), + +DeclareTag(LogTag_VarConstraint, "Constraint", LogTagType_VarString), +DeclareTag(LogTag_Constraint, "Constraint", LogTagType_String), + +// Count of adult documents rendered +DeclareTag(LogTag_AdultDocumentCount, "AdultDocumentCount", LogTagType_UInt32), + +// User state +DeclareTag(LogTag_ULS, "ULS", LogTagType_UInt32), + +DeclareTag(LogTag_IndexName, "IndexName", LogTagType_String), +DeclareTag(LogTag_ResultSource, "ResultSource", LogTagType_String), +DeclareTag(LogTag_MUID, "MUID", LogTagType_String), + +// ApSDk tags +DeclareTag(LogTag_ApCommMsgVersion, "Msg", LogTagType_String), +DeclareTag(LogTag_ApCommSeqNo, "SeqNo", LogTagType_UInt64), + +// Performance Tracking +DeclareTag(LogTag_LoadBalanceId, "LoadBalanceId", LogTagType_String), +DeclareTag(LogTag_LoadBalanceTS, "LoadBalanceTS", LogTagType_String), + +// FEX / FrontDoor +DeclareTag(LogTag_FrontDoorAction, "FDAction", LogTagType_String), +DeclareTag(LogTag_FullUrl, "FullUrl", LogTagType_String), + +DeclareTag(LogTag_JSON, "JSON", LogTagType_String), + +//webmaster +DeclareTag(LogTag_Email, "Email", LogTagType_String), +DeclareTag(LogTag_PUID, "PUID", LogTagType_String), +DeclareTag(LogTag_UserProfileID, "UserProfileID", LogTagType_String), + +// Query Diagnostic +DeclareTag(LogTag_QueryProcessLogVersion, "LogVersion", LogTagType_UInt32), +DeclareTag(LogTag_QueryProcessTracking, "QPTracking", LogTagType_String), + +// This 
must be the final tag, and it must have type None +DeclareTag(LogTag_End, "End", LogTagType_None), diff --git a/DryadVertex/VertexHost/system/classlib/include/MSMutex.h b/DryadVertex/VertexHost/system/classlib/include/MSMutex.h new file mode 100644 index 0000000..bdba16b --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/MSMutex.h @@ -0,0 +1,101 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include "RefCount.h" + +//JCnamespace apsdk +//JC{ + +class MSMutex : public RefCount { +public: + CRITICAL_SECTION m_Section; + + MSMutex() { + InitializeCriticalSection(&m_Section); + } + ~MSMutex() { + DeleteCriticalSection(&m_Section); + } + + void Acquire() { + EnterCriticalSection(&m_Section); + } + + void Release() { + LeaveCriticalSection(&m_Section); + } + + BOOL TryAcquire() { + return TryEnterCriticalSection(&m_Section); + } + +}; + +struct MutexLock { + Ptr m_Lock; + MutexLock(MSMutex *am) : m_Lock(am) { + LogAssert (m_Lock); + m_Lock->Acquire(); + } + ~MutexLock() { + Release(); + } + void Release() { + if (m_Lock) { + m_Lock->Release(); + m_Lock = NULL; + } + } +}; + + + +struct MutexTryLock { + Ptr m_Lock; + MutexTryLock(MSMutex * am) { + BOOL lockAcquired = am->TryAcquire(); + if (lockAcquired) + m_Lock = am; + } + + ~MutexTryLock() { + Release(); + } + void Release() { + if (m_Lock) { + m_Lock->Release(); + m_Lock = NULL; + } + } + bool operator!() { return (m_Lock == NULL); } + operator BOOL () { return (m_Lock != NULL); } +}; + +#define MUTEX_LOCK(_l, _m) MutexLock _l(_m) +#define MUTEX_TRY_LOCK(_l,_m) MutexTryLock _l(_m) +#define MUTEX_RELEASE(_l) (_l).Release() + +//JC} // namespace apsdk + +#ifdef USING_APSDK_NAMESPACE +using namespace apsdk; +#endif diff --git a/DryadVertex/VertexHost/system/classlib/include/PropertyIds.h b/DryadVertex/VertexHost/system/classlib/include/PropertyIds.h new file mode 100644 index 0000000..e5d14fc --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/PropertyIds.h @@ -0,0 +1,49 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include "basic_types.h" + +// property type flag +const UInt16 PropTypeMask = 0xc000; + +// PropType_Atom is a leaf property which is an element of a list or a +// set. There may be nested properties within the leaf. +const UInt16 PropType_Atom = 0x0000; + +// length type flag +const UInt16 PropLengthMask = 0x2000; + +// A property with PropLength_Short has a 1-byte length field +const UInt16 PropLength_Short = 0x0000; +// A property with PropLength_Long has a 4-byte length field +const UInt16 PropLength_Long = 0x2000; + +// mask for the remaining 13-bit namespace + +const UInt16 PropValueMask = 0x1fff; + +#define PROP_SHORTATOM(x_) ((x_) | PropType_Atom | PropLength_Short) +#define PROP_LONGATOM(x_) ((x_) | PropType_Atom | PropLength_Long) + +// Propries for Dryad +const UInt16 Prop_Stream_BeginTag = PROP_SHORTATOM(0x1200); +const UInt16 Prop_Stream_EndTag = PROP_SHORTATOM(0x1201); diff --git a/DryadVertex/VertexHost/system/classlib/include/RefCount.h b/DryadVertex/VertexHost/system/classlib/include/RefCount.h new file mode 100644 index 0000000..1ac7630 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/RefCount.h @@ -0,0 +1,119 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once +#include + +class RefCount +{ + mutable volatile long m_Crefs; + public: + + RefCount(void) { m_Crefs = 0; } + + virtual ~RefCount() {} + + int UpCount(void) const + { + return InterlockedIncrement(&m_Crefs); + } + + int DownCount(void) const + { + int val = InterlockedDecrement(&m_Crefs); + if (!val) + { + delete this; + } + return val; + } +}; + +template class Ptr +{ + T* m_Ptr; + public: + + Ptr(const Ptr& other) : m_Ptr(other.m_Ptr) + { + if (m_Ptr) + { + m_Ptr->UpCount(); + } + } + + Ptr(T* ptr = 0) : m_Ptr(ptr) + { + if (m_Ptr) + { + m_Ptr->UpCount(); + } + } + + ~Ptr(void) + { + if (m_Ptr) + { + m_Ptr->DownCount(); + } + } + + operator T*(void) const { return m_Ptr; } + + operator T*(void) { return m_Ptr; } + + T& operator*(void) const { return *m_Ptr; } + + T& operator*(void) { return *m_Ptr; } + + T* operator->(void) const { return m_Ptr; } + + T* operator->(void) { return m_Ptr; } + + bool operator == (const T* ptr) const { return (m_Ptr == ptr); } + + bool operator == (const Ptr &ptr) const { return (m_Ptr == ptr.m_Ptr); } + + bool operator != (const T * ptr) const { return (m_Ptr != ptr); } + + bool operator != (const Ptr &ptr) const { return (m_Ptr != ptr.m_Ptr); } + + bool operator !() const { return (m_Ptr == 0); } + + Ptr& operator=(Ptr &ptr) {return operator=((T *) ptr);} + + Ptr& operator=(T* ptr) + { + if (m_Ptr != ptr) + { + if (m_Ptr) + { + m_Ptr->DownCount(); + } + m_Ptr = ptr; + if (m_Ptr) + { + m_Ptr->UpCount(); + } + } + return *this; + } +}; + diff --git 
a/DryadVertex/VertexHost/system/classlib/include/XCompute.h b/DryadVertex/VertexHost/system/classlib/include/XCompute.h new file mode 100644 index 0000000..43a1090 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/XCompute.h @@ -0,0 +1,1783 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#pragma warning( push ) +/* 'X' bytes padding added after member 'Y' */ +#pragma warning( disable: 4820 ) + + +#pragma pack( push, 8 ) + + +#if !defined(_PCVOID_DEFINED) +typedef const void* PCVOID; +#define _PCVOID_DEFINED +#endif + + +#include + + +#if defined(__cplusplus) +extern "C" { +#endif + + + +/*++ + +XcOpenSession API + +Description: + +Opens an XCompute session for a given cluster. Each session +is associated with a cluster and is independent of other sessions. +The session (apart from other things) is associated with +user credentials. + +It is possible to create multiple sessions for the same cluster and +these multiple sessions will behave independently of each other. This +is particularly useful for applications like WebServer which will run +multiple sessions, one per user. + +Use the XcCloseSession to close the handle returned as a result of +XcOpenSession call. + +Arguments: + + pOpenSessionParams + The Open Session Parameters. Passes info about cluster to + establish session with, clientId, etc.
+ See XC_OPEN_SESSION_PARAMS for details. + + Pass NULL for defaults - Default cluster and a default client id. + + pSessionHandle + Handle to session + + pAsyncInfo + The async info structure. It's an alias to the + DR_ASYNC_INFO defined in Dryad.h. If this + parameter is NULL, then function completes in + synchronous manner and error code is returned as + return value. + + If parameter is not NULL then operation is carried + on in asynchronous manner. If asynchronous + operation has been successfully started then + function terminates immediately with + HRESULT_FROM_WIN32(ERROR_IO_PENDING) return value. + Any other return value indicates that it was + impossible to start asynchronous operation. + + + Return Value: + + if pAsyncInfo is NULL + XCERROR_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcOpenSession( + IN PCXC_OPEN_SESSION_PARAMS pOpenSessionParams, + OUT PXDRESSIONHANDLE pSessionHandle, + IN PCXC_ASYNC_INFO pAsyncInfo +); + + + +/*++ + +XcCloseSession API + +Description: + +Closes the session. + +Arguments: + + SessionHandle + Handle to session to close + +Return Value: + + XCERROR_OK + Call succeeded. + +--*/ + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcCloseSession( + IN XDRESSIONHANDLE SessionHandle +); + + + +/*++ + +XcInitialize API + +Description: + +Call this function at the start to initialize the various internal +data structures of the XCompute SDK library. + +Arguments: + + ConfigFileName + Name of the config file + + ComponentName + The name of the component + +Return Value: + + XCERROR_OK + Call succeeded. + NOTE: + S_FALSE will be returned if the initialize + has already been called.
+ +--*/ + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcInitialize( + IN PCSTR ConfigFileName, + IN PCSTR ComponentName +); + + + +/*++ + +XcFreeMemory API + +Description: + +Frees the memory allocated by the XCompute API. +All the memory returned as a result of call to +the XCompute API should use the XcFreeMemory to +deallocate the memory + +Arguments: + + pMem + Pointer to the memory + +Return Value: + + XCERROR_OK + Memory was successfully deallocated + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcFreeMemory( + IN PCVOID pMem +); + + + +/*++ + +XcCreateNewProcessHandle API + +Description: + +Creates a new process handle for a new process in the +given Job. + +This call is synchronous and does not cross +machine boundaries/process boundaries. + +Note: +1. + This method just creates the handle to + the XCompute process. It does not schedule the process itself. + Use the XcScheduleProcess API to schedule the XCompute process. + +2. + Use the XcCloseProcessHandle() to free the handle + +3. + Do not copy handle using the simple assignment operator. Use the + XcDupProcessHandle() API. Each handle variable needs to be + freed using the XcCloseProcessHandle(). + +Arguments: + + SessionHandle + Handle to a session associated with this call + + pJobId + The Id of the job under which the process will + be created. A NULL value will cause the current + process's JobId to be automatically picked up. + NOTE: + This parameter is only interesting to the Task Scheduler. + For all other cases, it should be assigned to NULL + + pProcessHandle + The handle to the process.
+ + +Return Value: + + XCERROR_OK + The call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcCreateNewProcessHandle( + IN XDRESSIONHANDLE SessionHandle, + IN const GUID* pJobId, + OUT PXCPROCESSHANDLE pProcessHandle +); + + + +/*++ + +XcOpenCurrentProcessHandle API + +Description: + +Opens the current process's handle + +This call is synchronous and does not cross +machine boundaries/process boundaries. + +Note: +1. + This method creates the handle to the current process and assigns + it to the session on that process. + +2. + Use the XcCloseProcessHandle() to free the handle + +3. + Do not copy handle using the simple assignment operator. Use the + XcDupProcessHandle() API. Each handle variable needs to be + freed using the XcCloseProcessHandle(). + +Arguments: + + SessionHandle + Handle to a session associated with this call. + + pProcessHandle + The handle to the process. This must be closed using the + XcCloseProcessHandle() + +Return Value: + + XCERROR_OK + The call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcOpenCurrentProcessHandle( + IN XDRESSIONHANDLE SessionHandle, + OUT PXCPROCESSHANDLE pProcessHandle +); + + + +/*++ + +XcCloseProcessHandle API + +Description: + +Closes a process handle created either by a call to +XcCreateNewProcessHandle() or XcDupProcessHandle(). + +This call is synchronous and does not cross +machine boundaries/process boundaries. + + +NOTE: +Every call to the XcCreateNewProcessHandle() or +XcDupProcessHandle() should ultimately +result in a call to XcCloseProcessHandle() to deallocate the handle. + +Arguments: + + ProcessHandle + Process handle to be closed + +Return Value: + + XCERROR_OK + The call succeeded + + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcCloseProcessHandle ( + IN XCPROCESSHANDLE ProcessHandle +); + + + +/*++ + +XcDupProcessHandle API + +Description: + +Duplicates a process handle. Use this api, if a copy of the +process handle is needed.
+ +This call is synchronous and does not cross +machine boundaries/process boundaries. + +NOTE: + a. Every call to XcDupProcessHandle() should ultimately result + in a call to XcCloseProcessHandle() to deallocate the handle. + +Arguments: + + ProcessHandle + Process handle to be duplicated + + pDupProcessHandle + The duplicated process handle + +Return Value: + + XCERROR_OK + The call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcDupProcessHandle( + IN XCPROCESSHANDLE ProcessHandle, + OUT PXCPROCESSHANDLE pDupProcessHandle +); + + + +/*++ + +XcSerializeProcessHandle API + +Description: + +Creates a serialized process handle. An XCompute process can serialize a +process handle, and pass it to another XCompute process where the other +XCompute process can use the XcUnSerializeProcessHandle() API, to +recreate the process handle. Then it can use that process handle to +communicate with the process, e.g. by using the XcSetAndGetProcessInfo() API + +Arguments: + + ProcessHandle + The handle to the process to serialize. + + ppXcSerializedHandleBlock + The serialized process handle. Use the XcFreeMemory() API + to de-allocate the pXcSerializedHandleBlock. + + pBlockLength + The length in bytes of the serialized process handle block + +NOTE: + The UserContext associated with the process handle will *NOT* + be serialized. + +Return Value: + + XCERROR_OK + The call succeeded + +--*/ + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcSerializeProcessHandle ( + IN XCPROCESSHANDLE ProcessHandle, + OUT PCVOID* ppXcSerializedHandleBlock, + OUT PSIZE_T pBlockLength +); + + + +/*++ + +XcUnSerializeProcessHandle API + +Description: + +Un-serializes a serialized process handle. See XcSerializeProcessHandle() API +for more details + +Arguments: + + SessionHandle + The session to associate the un-serialized process handle with. + + pXcSerializedHandleBlock + The serialized process handle.
+ + BlockLength + The length in bytes of the serialized process handle block + + pProcessHandle + The un-serialized process handle. + +Return Value: + + XCERROR_OK + The call succeeded + +--*/ + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcUnSerializeProcessHandle ( + IN XDRESSIONHANDLE SessionHandle, + IN PCVOID pXcSerializedHandleBlock, + IN SIZE_T BlockLength, + OUT PXCPROCESSHANDLE pProcessHandle +); + + + +/*++ + +XcSetProcessUserContext API + +Description: + +Associates API user related data with process +identified by the Process handle. + +The API user can associate any data with the XCompute process +and get back the data, using the XcGetProcessUserContext API. + +This call is synchronous and does not cross +machine boundaries/process boundaries. + +NOTE: +a. The XcCloseProcessHandle() will not deallocate the user + context data. It is the API user's responsibility to + deallocate data associated with UserContext. + +b. The user context is associated with a process and not with a + ProcessHandle. So if multiple handles identify the same + XCompute process, they will return the same user context. + +Arguments: + + ProcessHandle + Process handle to identify process to which user context + is being associated + + pUserContext + The user context data + + pPreviousUserContext + If there was a previously associated user context + with the XCompute process, then returns that context data. + Otherwise NULL is returned. If the caller supplies NULL input, + the previous value is not returned + + +Return Value: + + XCERROR_OK + The call succeeded + + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcSetProcessUserContext( + IN XCPROCESSHANDLE ProcessHandle, + IN ULONG_PTR pUserContext, + OUT ULONG_PTR* pPreviousUserContext +); + + + +/*++ + +XcGetProcessUserContext API + +Description: + +Gets the API user related data associated with the XCompute Process. +The API user can associate any data with the XCompute process via the +XcSetProcessUserContext API.
+ +This call is synchronous and does not cross +machine boundaries/process boundaries. + +Arguments: + + ProcessHandle + Process handle + + pUserContext + The user context data associated with the XCompute Process. + If no user context is associated, then + NULL is returned. + +Return Value: + + XCERROR_OK + The call succeeded + + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetProcessUserContext( + IN XCPROCESSHANDLE ProcessHandle, + OUT ULONG_PTR* pUserContext +); + + + +/*++ + +XcGetProcessState API + +Description: + +Gets the process state information. If Schedule process has not +yet been called, the API will return an error. + +This call is synchronous and does not cross +machine boundaries/process boundaries. + +Arguments: + + ProcessHandle + Process handle + + pProcessState + Describes the process state. The different states + are described in XComputeTypes.h + + pProcessSchedulingError + if process state is XCPROCESSSTATE_COMPLETED + then the error code indicates reason. + S_OK means process completed without errors. + Other error codes indicate reasons for failed completion. + +Return Value: + + XCERROR_OK + The call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetProcessState( + IN XCPROCESSHANDLE ProcessHandle, + OUT PXCPROCESSSTATE pProcessState, + OUT XCERROR* pProcessSchedulingError +); + + + +/*++ + +XcGetProcessId API + +Description: + +Gets the process Id of the process associated with the process handle. +If the process state is anything less than XCPROCESSSTATE_ASSIGNEDTOPN +an error is returned. + +Arguments: + + ProcessHandle + Process handle + + + pProcessId + The id of the process + +Return Value: + + XCERROR_OK + The call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetProcessId( + IN XCPROCESSHANDLE ProcessHandle, + OUT GUID* pProcessId +); + + + +/*++ + +XcGetProcessNodeId API + +Description: + +Gets the process node to which the process has been assigned.
+If the process state is anything other than XCPROCESSSTATE_ASSIGNEDTOPN +an error is returned. + +Arguments: + + ProcessHandle + Process handle + + pProcessNodeId + Pointer to process node Id + +Return Value: + + XCERROR_OK + The call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetProcessNodeId( + IN XCPROCESSHANDLE ProcessHandle, + OUT PXCPROCESSNODEID pProcessNodeId +); + + + +/*++ + +ProcessScheduler API + +--*/ + + + +/*++ + +XcScheduleProcess API + +Description: + +Contacts the Process Scheduler to schedule an XCompute Process. +Any XCompute Process in a Job may schedule additional +XCompute Processes in the same Job by requesting their creation +through the XCompute Process Scheduler, using this API. + +NOTE: +This call always returns immediately. +A successful return code from the API indicates that the +XcScheduleProcess request was added to the local scheduleProcess queue. +The user should use the XcWaitForStateChange(XCPROCESSSTATE_ASSIGNEDTOPN) +API to see when the process actually gets scheduled to the Process Scheduler + +Arguments: + + ProcessHandle + Handle to the process. + Use the XcCreateNewProcessHandle () API + to obtain the handle to the process + + pScheduleProcessDescriptor + See PCXC_SCHEDULEPROCESS_DESCRIPTOR in + XComputeTypes.h. This data structure is + copied before the function returns and + so it is not necessary for the caller + to preserve the contents during an + async call + + Return Value: + + S_OK indicating the operation was successfully started. + + Any other return value indicates the scheduleprocess request + could not be started + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcScheduleProcess( + IN XCPROCESSHANDLE ProcessHandle, + IN PCXC_SCHEDULEPROCESS_DESCRIPTOR pScheduleProcessDescriptor +); + + + + +/*++ + +XcCancelScheduleProcess API + +Description: + +Contacts the Process Scheduler to cancel the scheduled +XCompute Process.
This API is used by the Parent XCompute process +that originally scheduled the XCompute process to cancel its +creation. +NOTE: The XCompute process will get cancelled, only if it has not +already been created on a process node. The returned error code +indicates whether the process was successfully cancelled or not. + +Arguments: + + ProcessHandle + Handle to the process. + + pAsyncInfo + The async info structure. It's an alias to the + DR_ASYNC_INFO defined in Dryad.h. If this + parameter is NULL, then function completes in + synchronous manner and error code is returned as + return value. + + If parameter is not NULL then operation is carried + on in asynchronous manner. If asynchronous + operation has been successfully started then + function terminates immediately with + HRESULT_FROM_WIN32(ERROR_IO_PENDING) return value. + Any other return value indicates that it was + impossible to start asynchronous operation. + + + Return Value: + + if pAsyncInfo is NULL + XCERROR_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcCancelScheduleProcess( + IN XCPROCESSHANDLE ProcessHandle, + IN PCXC_ASYNC_INFO pAsyncInfo +); + + + +/*++ + +PN API + +--*/ + +/*++ + +XcSetAndGetProcessInfo API + +Description: + +Gets the process related information from the Process Node. +JobManager (e.g. Dryad Job manager), will use this API to get +information about a given XCompute process, of a job. +Various bit flags (explained below) control the amount of data +retrieved for a given process. +It also provides the user with the ability to block on a +particular property, for maxBlockTime amount of time, before the +API finishes (synchronously or asynchronously).
Dryad uses this +to extend the lease period for a given process + +Arguments: + + ProcessHandle + Handle to the process. + Use the XcCreateNewProcessHandle () API + to obtain the handle to the process + + pXcRequestInputs + Pointer to the + XC_SETANDGETPROCESSINFO_REQINPUT struct. + It contains the various inputs to the API + clubbed together. This structure needs to + be preserved by the user till the Async + call is completed + + ppXcRequestResults + The results structure. The user should use + the XcFreeMemory(ppXcPnProcessInfo) to free + the memory after the results have been + consumed. + See PXC_SETANDGETPROCESSINFO_REQRESULTS for + more info. + + pAsyncInfo + The async info structure. It's an alias to + the DR_ASYNC_INFO defined in Dryad.h. If + this parameter is NULL, then function + completes in synchronous manner and error + code is returned as return value. + + If parameter is not NULL then operation is + carried on in asynchronous manner. If + asynchronous operation has been successfully + started then function terminates + immediately with + HRESULT_FROM_WIN32(ERROR_IO_PENDING) return + value. + + Any other return value indicates that it was + impossible to start asynchronous operation. + + + Return Value: + + if pAsyncInfo is NULL + XCERROR_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcSetAndGetProcessInfo( + IN XCPROCESSHANDLE ProcessHandle, + IN PXC_SETANDGETPROCESSINFO_REQINPUT pXcRequestInputs, + OUT PXC_SETANDGETPROCESSINFO_REQRESULTS* ppXcRequestResults, + IN PCXC_ASYNC_INFO pAsyncInfo +); + + + +/*++ + +XcGetNetworkLocalityPathOfProcessNode + +Description: + +This API translates a set of process node IDs into +network locality paths.
+ +Arguments: + + SessionHandle + Handle to a session associated with + this call + + ProcessNodeId + The Process Node for which the + path is required + + ppNetworkLocalityPath + Returned network locality path for the ProcessNode. + The pNetworkLocalityPath vector should be freed with + XcFreeMemory(ppNetworkLocalityPath) + + pNetworkLocalityParam + The affinity param to be used to get the locality path. + The affinity param lets the user identify the affinity + level relative to the given ProcessNodeId, + which is reflected in the returned ppNetworkLocalityPath. + Thus given a ProcessNodeId, the user might say, + L2Switch as the NetworkLocalityParam, which means + the affinity is to all process nodes under that L2Switch. + + Different affinity params are defined in the + XComputeTypes.h. See Network Locality Params for + more details. + + Return Value: + + if pAsyncInfo is NULL + XCERROR_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetNetworkLocalityPathOfProcessNode( + IN XDRESSIONHANDLE SessionHandle, + IN XCPROCESSNODEID ProcessNodeId, + IN PSTR pNetworkLocalityParam, + OUT PCSTR* ppNetworkLocalityPath +); + + + +/*++ + +XcEnumerateProcessNodes + + +Description: + +This API enumerates all the process nodes that are controlled +by the Process scheduler and returns an array of processNodeIds + +Arguments: + + SessionHandle + Handle to a session associated with + this call + + pNumNodeIds + Pointer to a int which gets filled with the + number of process Node Ids in the + ppProcessNodeIds array + + ppProcessNodeIds + Pointer to array of processNode Ids. Use the + XcFreeMemory() API to deallocate. + + pAsyncInfo + The async info structure. 
Its an alias to + the DR_ASYNC_INFO defined in Dryad.h. If + this parameter is NULL, then function + completes in synchronous manner and error + code is returned as return value. + + If parameter is not NULL then operation is + carried on in asynchronous manner. If + asynchronous operation has been successfully + started then function terminates + immediately with + HRESULT_FROM_WIN32(ERROR_IO_PENDING) return + value. + + Any other return value indicates that it was + impossible to start asynchronous operation. + + Return Value: + + if pAsyncInfo is NULL + XCERROR_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcEnumerateProcessNodes( + IN XDRESSIONHANDLE SessionHandle, + OUT UINT32* pNumNodeIds, + OUT PXCPROCESSNODEID* ppProcessNodeIds, + IN PCXC_ASYNC_INFO pAsyncInfo +); + + + +/*++ + +XcFetchProcessNodeMetaData + +Description: + +This API fetches the process node related metadata. This +call can result in a call to the Process Scheduler, if the +metadata for a given process node is missing. + +Arguments: + + SessionHandle + Handle to a session associated with + this call + + pProcessNodeIds + Array of IDs of the nodes for which the + metadata is required + + NumNodeIds + Number of node ids in the + pProcessNodeIds array + + pAsyncInfo + The async info structure. Its an alias to + the DR_ASYNC_INFO defined in Dryad.h. If + this parameter is NULL, then function + completes in synchronous manner and error + code is returned as return value. + + If parameter is not NULL then operation is + carried on in asynchronous manner. 
If + asynchronous operation has been successfully + started then function terminates + immediately with + HRESULT_FROM_WIN32(ERROR_IO_PENDING) return + value. + + Any other return value indicates that it was + impossible to start asynchronous operation. + + Return Value: + + if pAsyncInfo is NULL + XCERROR_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcFetchProcessNodeMetaData( + IN XDRESSIONHANDLE SessionHandle, + IN UINT32 NumNodeIds, + IN PXCPROCESSNODEID pProcessNodeIds, + IN PCXC_ASYNC_INFO pAsyncInfo +); + + + +/*++ + + Notification/Sync API + +--*/ + + + +/*++ + +XcWaitForStateChange API + +Description: + +The API allows users to get async completion status for +XCompute process when it reaches a desired state. (see XCPROCESSSTATE) +When the desired state is reached the async completion is dispatched. + +NOTE: +1. If the process gets cancelled, then completion is dispatched immediately +2. The pOperationState of the AsyncInfo will have the error code. + +Arguments: + + ProcessHandle + Handle to an XCompute process for which the + state change event is needed + + WaitForState + The state to wait for the XCompute to be in, so + that completion can be dispatched + + MaxWaitInterval + The maximum amount of time (not including network + request latencies) that the API should wait for a + change in the process list before completing. If + XCTIMEINTERVAL_ZERO, the API will return changes + that can be immediately determined without + communication with the process scheduler. If + XC_TIMEINTERVAL_INFINITE, the API will wait until a + change occurs or the process is cancelled. + + pAsyncInfo + The async info structure. 
It's an alias to the + DR_ASYNC_INFO defined in Dryad.h. If this + parameter is NULL, then the function completes in + synchronous manner and error code is returned as + return value. + + If parameter is not NULL then the operation is carried + on in asynchronous manner. If an asynchronous + operation has been successfully started then + this function terminates immediately with + an HRESULT_FROM_WIN32(ERROR_IO_PENDING) return value. + Any other return value indicates that it was + impossible to start the asynchronous operation. + + Return Value: + + DrError_OK indicates call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcWaitForStateChange( + IN XCPROCESSHANDLE ProcessHandle, + IN XCPROCESSSTATE WaitForState, + IN XCTIMEINTERVAL MaxWaitInterval, + IN PCXC_ASYNC_INFO pAsyncInfo +); + + + +/*++ + + XCompute File access API. + +--*/ + + + +/*++ + +XcGetWorkingDirectoryProcessUri API + +Description: + +Gets a URI to a file or directory within an XCompute process's initial +working directory. +The returned Uri is fully qualified and can be used to +create paths for file URI's in the process's working directory, +by appending file names to the WorkingDirectory Uri. + +Arguments: + + ProcessHandle + The process handle for which to get the + Process Working directory Uri + + pRelativePath + The path relative to process's working directory + that will be appended to the working directory. + NOTE: + If relative path is NULL, or '.' or '/', then the working directory + path is returned. + + ppUri + The fully qualified working directory Uri. + Use the XcFreeMemory() API to free this buffer + + Return Value: + + DrError_OK indicates call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetWorkingDirectoryProcessUri( + IN XCPROCESSHANDLE ProcessHandle, + IN PCSTR pRelativePath, + OUT PSTR* ppUri +); + + + +/*++ + +XcGetProcessUri API + +Description: + +Gets the Uri to a file or directory local to an XCompute process.
+The returned Uri is fully qualified and can be used to +create paths for file URI's in the process's root directory or +another directory under the root by appending path/s relative +to the initial working directory. + +NOTE: + +The Job does not have access to directories above the +Process's Root Directory. +All directories e.g. Process Working Directory, Data directory +are sub directories under the Process's Root directory + + +Arguments: + + ProcessHandle + The process handle for which to get the + Process File Uri. + + pRelativePath + The path relative to process's working directory + that will be appended to the working directory. + NOTE: + If relative path is NULL, or '.' or '/', then the working directory + path is returned. + + ppProcessRootDirUri + The Process's Root directory Uri. + Use the XcFreeMemory() API to free this buffer. + + Return Value: + + DrError_OK indicates call succeeded + +--*/ + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetProcessUri( + IN XCPROCESSHANDLE ProcessHandle, + IN PCSTR pRelativePath, + OUT PSTR* ppProcessRootDirUri +); + + + +/*++ + +XcTranslateLocalPathToProcessUri API + +Description: + +Translates the local path to a process Uri. This translation is necessary +in various scenarios. The local process will interact with files using the +standard file system API's. Once it is done it will convert the local +file names to process Uri's to be passed to other XCompute processes, which +will use the XCompute file SDK API's to access those files across machines. + + +Arguments: + + ProcessHandle + The process handle. At present it has to be the current process handle. + In future we will allow this handle to be for any process, that belongs + to the same job and is running on the same Process Node. + + pLocalPath + The local path to be translated to the process Uri format + + + ppTranslatedUri + The Translated Uri of the given local path.
Use the XcFreeMemory() + API to free this buffer + + Return Value: + + DrError_OK + indicates call succeeded + + DrError_UnknownProcess + if not under PN. + + HRESULT_FROM_WIN32(ERROR_INVALID_PATH) + if path is not in a process root directory on the current PN. + + +--*/ + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcTranslateLocalPathToProcessUri( + IN XCPROCESSHANDLE ProcessHandle, + IN PCSTR pLocalPath, + OUT PSTR* ppTranslatedUri +); + + + +/*++ + +XcTranslateProcessUriToLocalPath API + +Description: + +Translates the File Uri to a local path. This translation is necessary +in various scenarios. The local process will interact with files using the +standard file system APIs. The local process will use the XCompute APIs +to get the process Uris and then would want to convert them to local paths so +as to be able to use the standard File System APIs to interact with files locally. + +Arguments: + + ProcessHandle + The process handle. + + pUri + The Uri to translate to local path + + + + ppLocalFilePath + Receives the translated local path for the above Uri. + Use the XcFreeMemory() API to free this buffer + + Return Value: + + DrError_OK + indicates call succeeded + + DrError_UnknownProcess + if not under a PN + + DrError_InvalidPathname + if the provided URI is not within a process root directory on the current PN. + +--*/ + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcTranslateProcessUriToLocalPath( + IN XCPROCESSHANDLE ProcessHandle, + IN PCSTR pUri, + OUT PSTR* ppLocalFilePath +); + + + +/*++ + +XcOpenProcessFile API + +Description: + +Opens a handle to a remote XCompute processes working File. +Using this handle, an application can read remote files +written by a XCompute process on a given Node. +Writing of files is not supported. Local files can be written +using Ordinary Windows file I/O (restricted to the working +directory and Its children). + +Arguments: + + SessionHandle + Handle to an XCompute session associated with + this call.
+ + pFileUri + the fully qualified file Uri (UTF-8) obtained by calling the + XcGetProcessFileUri API. + + Flags + Reserved. Must be 0. + + pFileHandle + The returned handle to the opened file. + Set to NULL if error + + pAsyncInfo + The async info structure. Its an alias to the + DR_ASYNC_INFO defined in Dryad.h. If this + parameter is NULL, then the function completes in + synchronous manner and error code is returned as + return value. + + If parameter is not NULL then the operation is carried + on in asynchronous manner. If an asynchronous + operation has been successfully started then + this function terminates immediately with + an HRESULT_FROM_WIN32(ERROR_IO_PENDING) return value. + Any other return value indicates that it was + impossible to start the asynchronous operation. + + + Return Value: + + if pAsyncInfo is NULL + DrError_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation (a SUCCESS HRESULT will never + be returned if pAsyncInfo is not NULL). + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcOpenProcessFile( + IN XDRESSIONHANDLE SessionHandle, + IN PCSTR pFileUri, + IN DWORD Flags, + OUT PXCPROCESSFILEHANDLE pFileHandle, + IN PCXC_ASYNC_INFO pAsyncInfo +); + + + +/*++ + +XcGetProcessFileSize API + +Description: + +Gets the fileSize of the given process file handle. + +Arguments: + + FileHandle + The handle to the opened file. + + Flags + The options for fetching the size. 
These option flags are mutually exclusive. + One of the following is permissible: + XC_REFRESH_AGGRESSIVE (default) + - visit server to find out latest known length + XC_REFRESH_PASSIVE + - return length from local cache if available otherwise + visit server to find out latest known length + XC_REFRESH_FROM_CACHE + - return length from local cache. + Fail if not available. This is a non-blocking call. + + pSize + Pointer to the output size variable. + Must not be NULL. + The memory pointed to by this variable must remain valid and writable + for the duration of the asynchronous operation. + + pAsyncInfo + The async info structure. It's an alias to the + DR_ASYNC_INFO defined in Dryad.h. If this + parameter is NULL, then the function completes in + synchronous manner and error code is returned as + return value. + + If parameter is not NULL then the operation is carried + on in asynchronous manner. If an asynchronous + operation has been successfully started then + this function terminates immediately with + an HRESULT_FROM_WIN32(ERROR_IO_PENDING) return value. + Any other return value indicates that it was + impossible to start the asynchronous operation. + + + Return Value: + + E_NOTIMPL is returned if the underlying file does not support GetFileSize. + + if pAsyncInfo is NULL + DrError_OK indicates call succeeded + + Any other error code indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation (a SUCCESS HRESULT will never + be returned if pAsyncInfo is not NULL).
+ +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetProcessFileSize( + IN XCPROCESSFILEHANDLE FileHandle, + IN UINT Flags, + OUT PUINT64 pSize, + IN PCXC_ASYNC_INFO pAsyncInfo +); + + + +/*++ + +XcCloseProcessFile API + +Description: + +Closes the file opened by the XcOpenProcessFile + +Arguments: + + FileHandle + The handle to the opened file. + + Return Value: + + DrError_OK indicates call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcCloseProcessFile( + IN XCPROCESSFILEHANDLE FileHandle +); + + + +/*++ + +XcReadProcessFile API + +Description: + +Reads the content of the file opened by the XcOpenProcessFile + +Arguments: + + + FileHandle + The handle to the opened file. + + pBuffer + Pointer to the buffer that receives the data read. + + pBytesRead + Pointer to variable containing size of the buffer + on input. On return this variable receives number + of bytes read. + + pReadPosition + The offset from the beginning of the file at + which to read. + + pAsyncInfo + The async info structure. Its an alias to the + DR_ASYNC_INFO defined in Dryad.h. If this + parameter is NULL, then the function completes in + synchronous manner and error code is returned as + return value. + + If parameter is not NULL then the operation is + carried on in asynchronous manner. If an asynchronous + operation has been successfully started then + this function terminates immediately with + an HRESULT_FROM_WIN32(ERROR_IO_PENDING) return value. + Any other return value indicates that it was + impossible to start the asynchronous operation. + + + Return Value: + + if pAsyncInfo is NULL + DrError_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation (a SUCCESS HRESULT will never + be returned if pAsyncInfo is not NULL). 
+ +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcReadProcessFile( + IN XCPROCESSFILEHANDLE FileHandle, + OUT PVOID pBuffer, + IN OUT PSIZE_T pBytesRead, + IN OUT XCPROCESSFILEPOSITION* pReadPosition, + IN PCXC_ASYNC_INFO pAsyncInfo +); + + + +/*++ + +XcGetCurrentProcessNodeId API + +Description: + +Gets the current Process Node Id. The Process Node Id to +the node name map is maintained internally. + +Arguments: + + SessionHandle + Handle to an XCompute session associated with + this call. + + pProcessNodeId + Pointer to the XCPROCESSNODEID that receives the Id of the node + + Return Value: + + DrError_OK + indicates call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetCurrentProcessNodeId( + IN XDRESSIONHANDLE SessionHandle, + OUT PXCPROCESSNODEID pProcessNodeId +); + + + +/*++ + +XcProcessNodeIdFromName API + +Description: + +Gets the Process Node Id for a node given the node name. The +Process Node Id to the node name map is maintained internally. +If a node name is not found in the internal map, then a new +entry is created and the corresponding id is returned. + +Arguments: + + SessionHandle + Handle to an XCompute session associated with this + call. Reserved for future use. Must be NULL. + + + pProcessNodeName + Name of the process node for which Id is needed + + pProcessNodeId + Pointer to the XCPROCESSNODEID that receives the Id of the node + + Return Value: + + DrError_OK indicates call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcProcessNodeIdFromName( + IN XDRESSIONHANDLE SessionHandle, + IN PCSTR pProcessNodeName, + OUT PXCPROCESSNODEID pProcessNodeId +); + + + +/*++ + +XcProcessNodeNameFromId API + +Description: + +Gets the Process Node name from the given Process Node Id.The +Process Node Id to the node name map is maintained internally. + +Arguments: + + SessionHandle + Handle to an XCompute session associated with this + call.
+ + + ProcessNodeId + The process Node Id for which the node name + is needed + + ppProcessNodeName + Name of the process node corrosponding to Id + Note: the returned process node name string + is permanently allocated and will remain + valid for the life of the process. There + is no need to make a copy of this string. + + Return Value: + + DrError_OK indicates call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcProcessNodeNameFromId( + IN XDRESSIONHANDLE SessionHandle, + IN XCPROCESSNODEID ProcessNodeId, + OUT PCSTR* ppProcessNodeName +); + + + +#pragma pack( pop ) + +#pragma warning( pop ) + +#if defined(__cplusplus) +} +#endif diff --git a/DryadVertex/VertexHost/system/classlib/include/XComputeTypes.h b/DryadVertex/VertexHost/system/classlib/include/XComputeTypes.h new file mode 100644 index 0000000..a6f06d4 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/XComputeTypes.h @@ -0,0 +1,1236 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#pragma warning( push ) +/* 'X' bytes padding added after member 'Y' */ +#pragma warning( disable: 4820 ) + + + +#if !defined(_PCVOID_DEFINED) +typedef const void* PCVOID; +#define _PCVOID_DEFINED +#endif + + +#if defined(__cplusplus) +extern "C" { +#endif + + +#define XCOMPUTEAPI_EXT +#define XCOMPUTEAPI __stdcall + + +/*++ + +Error codes and exit code typedefs + +--*/ +typedef DWORD XCEXITCODE; +typedef HRESULT XCERROR; + + + +/* + +Process state related typedefs. Used by the sync +API, to depict process state + +*/ +typedef DWORD XCPROCESSSTATE; +typedef XCPROCESSSTATE* PXCPROCESSSTATE; + + + +/* + +Various process states. + +State transictions are shown below +Each state is explained in detail in the section below this + + XCPROCESSSTATE_INVALID -------------> XCPROCESSSTATE_COMPLETED + | + | + \/ + XCPROCESSSTATE_UNSCHEDULED -------------> XCPROCESSSTATE_COMPLETED + | + | + \/ + + XCPROCESSSTATE_SCHEDULING-------------> XCPROCESSSTATE_COMPLETED + | + | + \/ + + XCPROCESSSTATE_SCHEDULED-------------> XCPROCESSSTATE_COMPLETED + | + | + \/ + + XCPROCESSSTATE_ASSIGNEDTONODE-------------> XCPROCESSSTATE_COMPLETED + | + | + \/ + + XCPROCESSSTATE_BINDING-------------> XCPROCESSSTATE_COMPLETED + /\ + | + | + \/ + + XCPROCESSSTATE_BINDCOMPLETED-------------> XCPROCESSSTATE_COMPLETED + | + | + \/ + + XCPROCESSSTATE_LAUNCHING-------------> XCPROCESSSTATE_COMPLETED + | + | + \/ + + XCPROCESSSTATE_RUNNING-------------> XCPROCESSSTATE_COMPLETED + | + | + \/ + XCPROCESSSTATE_TERMINATING + | + | + \/ + XCPROCESSSTATE_COMPLETED + + | + | + \/ + XCPROCESSSTATE_STATEDELETED + | + | + \/ + XCPROCESSSTATE_DELETED + + + + +XCPROCESSSTATE_INVALID + The process state is invalid. This will be returned, if a call to + the XcGetProcessState() api is made, even before + XcScheduleProcess() has been called. + +XCPROCESSSTATE_UNSCHEDULED + The process has NOT been scheduled on the Process Scheduler. 
+ It is possible to XcCancelScheduleProcess() in this state + +XCPROCESSSTATE_SCHEDULING + The scheduling is in flight to the Process Scheduler. But it is + not yet assigned to a process node. + It is possible to XcCancelScheduleProcess() in this state + +XCPROCESSSTATE_SCHEDULED + The process has been scheduled on the Process Scheduler. + It is possible to XcCancelScheduleProcess() in this state + +XCPROCESSSTATE_ASSIGNEDTONODE + The process has been assigned to a process Node. The user can + use XcGetProcessNode() api to get PN related info for the process. + After this point it is possible to interact with process node state for + the process (update process constraints, request resource bindings, + launch the process, get and set properties, open and read process + files, etc.) + +XCPROCESSSTATE_BINDING + The process is assigned to a PN and resource binding (copying) is + in progress. All required resources need to be copied before a + process can be launched. While in this state, it is possible to interact + with process node state for the process (update process constraints, + request additional resource bindings, get and set properties, open + and read process files, etc.) + +XCPROCESSSTATE_BINDCOMPLETED + The resources copying is completed. NOTE: This might signal + completion of binding for a batch of resources. The state can + jump back to XCPROCESSSTATE_BINDING if further bindings are + requested at this point. If the process was originally scheduled + without the XCCREATEPROCESSDESCRIPTOR_LATEBOUNDRESOURCES + option set, or if process flag + XCPROCESS_FLAG_LAUNCH_AFTER_RESOURCE_BIND (TBD) has been set + on the process, then the state will automatically proceed to + XCPROCESSSTATE_LAUNCHING. + +XCPROCESSSTATE_LAUNCHING + All resources binding finished and the process is being launched.
+ +XCPROCESSSTATE_RUNNING + The corresponding win32 process has been created on the PN + node and is currently running + +XCPROCESSSTATE_TERMINATING + The jobobject for the corresponding win32 process is being terminated/ + GC'ed. + +XCPROCESSSTATE_COMPLETED + The XComputeProcess completed. + If the process successfully reached XCPROCESSSTATE_ASSIGNEDTONODE + before completing, then it is still possible to interact with process node state + for the process (e.g., you can still open and read process files). + + NOTE: The process can complete for various reasons. The error + code associated with the state explains the exact reason. + +XCPROCESSSTATE_STATEDELETED + The XCompute process is completed. Its state like directories etc have + been garbage collected. But the ProcessNode still hold information about + the XCompute process statistics. + +XCPROCESSSTATE_DELETEDFROMNODE + The XComputeProcess has been garbage collected and no info + about it exists at the process node. Only locally cached status + information is available. 
+ +*/ +#define XCPROCESSSTATE_ZERO ((XCPROCESSSTATE)0x00000000) +#define XCPROCESSSTATE_INVALID XCPROCESSSTATE_ZERO + +// The folowing states are managed on the Process Scheduler before the process has been allocated on the Process Node +#define XCPROCESSSTATE_UNSCHEDULED ((XCPROCESSSTATE)0x40000000) +#define XCPROCESSSTATE_SCHEDULING ((XCPROCESSSTATE)0x40100000) +#define XCPROCESSSTATE_SCHEDULED ((XCPROCESSSTATE)0x40200000) +#define XCPROCESSSTATE_SCHEDULINGFAILED ((XCPROCESSSTATE)0x40300000) + +// The following states are managed on the Process Node +#define XCPROCESSSTATE_NODE_UNINITIALIZED ((XCPROCESSSTATE)0x80000000) +#define XCPROCESSSTATE_NODE_PREINITIALIZE ((XCPROCESSSTATE)0x80100000) +#define XCPROCESSSTATE_NODE_ALLOCATED ((XCPROCESSSTATE)0x80200000) +#define XCPROCESSSTATE_ASSIGNEDTONODE XCPROCESSSTATE_NODE_ALLOCATED +#define XCPROCESSSTATE_NODE_READYTOBINDRESOURCES ((XCPROCESSSTATE)0x80300000) +#define XCPROCESSSTATE_NODE_BINDINGRESOURCES ((XCPROCESSSTATE)0x80400000) +#define XCPROCESSSTATE_BINDING XCPROCESSSTATE_NODE_BINDINGRESOURCES +#define XCPROCESSSTATE_NODE_RESOURCEBINDINGCOMPLETE ((XCPROCESSSTATE)0x80500000) +#define XCPROCESSSTATE_BINDCOMPLETED XCPROCESSSTATE_NODE_RESOURCEBINDINGCOMPLETE +#define XCPROCESSSTATE_NODE_LOADPENDING ((XCPROCESSSTATE)0x80600000) +#define XCPROCESSSTATE_LAUNCHING XCPROCESSSTATE_NODE_LOADPENDING +#define XCPROCESSSTATE_NODE_LOADING ((XCPROCESSSTATE)0x80700000) +#define XCPROCESSSTATE_NODE_LOADED ((XCPROCESSSTATE)0x80800000) +#define XCPROCESSSTATE_NODE_APPINITIALIZATION ((XCPROCESSSTATE)0x80900000) +#define XCPROCESSSTATE_NODE_APPRUNPENDING ((XCPROCESSSTATE)0x80A00000) +#define XCPROCESSSTATE_NODE_APPRUNNING ((XCPROCESSSTATE)0x80B00000) +#define XCPROCESSSTATE_RUNNING XCPROCESSSTATE_NODE_APPRUNNING +#define XCPROCESSSTATE_NODE_TERMINATING ((XCPROCESSSTATE)0x80c00000) +#define XCPROCESSSTATE_TERMINATING XCPROCESSSTATE_NODE_TERMINATING +#define XCPROCESSSTATE_NODE_COMPLETE ((XCPROCESSSTATE)0x80d00000) +#define 
XCPROCESSSTATE_COMPLETED XCPROCESSSTATE_NODE_COMPLETE +#define XCPROCESSSTATE_NODE_DELETINGSTATE ((XCPROCESSSTATE)0x80e00000) +#define XCPROCESSSTATE_NODE_STATEDELETED ((XCPROCESSSTATE)0x80f00000) +#define XCPROCESSSTATE_STATEDELETED XCPROCESSSTATE_NODE_STATEDELETED +#define XCPROCESSSTATE_NODE_ZOMBIE ((XCPROCESSSTATE)0x8fffffff) + +// The following states are maintained by the client SDK after the Process Node has forgotten about the process + +#define XCPROCESSSTATE_DELETEDFROMNODE ((XCPROCESSSTATE)0xc0000000) + +// End Process States + +#define XCPROCESSSTATE_NEVER ((XCPROCESSSTATE)0xffffffff) + + + + +/*++ + +NOTE: All byte strings in this .h file are UTF8 strings +unless otherwise noted. + +--*/ + + + +/*++ + +XCPROCESSHANDLE +Handle to a XCompute Process. Used in the API's that require a +process identifier as its input. +Some e.g. API's are: XcScheduleProcess(), XcSetAndGetProcessInfo(). + +This is used to assist API users in following ways: + a. Pass this identifier instead of generating their own unique + process GUIDs + + b. If using Async API, then there is no need to keep track of / + or lookup GUID to call back function/object map. The user + context associated with the handle can be used instead. + + c. All operations related to a process, e.g. Process states, etc + will ultimately get tied to the handle, thus making it easy + for users to collect/modify process related information + +--*/ +typedef struct tagXCPROCESSHANDLE +{ + ULONG_PTR Unused; +} *XCPROCESSHANDLE, **PXCPROCESSHANDLE; + +const XCPROCESSHANDLE INVALID_XCPROCESSHANDLE = NULL; + + + +/*++ + +XCPROCESSNODEID +Identifies a XCompute Node. Used in the ProcessNodeIdFromName() +and the ProcessNodeNameFromId() API's. All XCompute API's take +the XCPROCESSNODEID as input where ever a node related entry +is needed. This is used to assist users to be able to pass +this identifier instead of passing strings around and having to +go through allocating/deallocating/Copying/Comparing them. 
+ +Fields: +--*/ +typedef struct tagXCPROCESSNODEID +{ + ULONG_PTR Unused; +} *XCPROCESSNODEID, **PXCPROCESSNODEID; + +const XCPROCESSNODEID INVALID_XCPROCESSNODEID = NULL; + + +/*++ + +XC_RESOURCE_SOURCE_TYPE enum + +Determines the type of resource passed in the XC_RESOURCEFILE_DESCRIPTOR + +Fields: + + XC_RESOURCE_SOURCE_UTF8_PATHNAME + Indicates XC_RESOURCEFILE_DESCRIPTOR's pResourceSource + points to a UTF-8 share path. + The pathname may be an xstream URI or a path to a + working file in another XCompute Process in the same Job. + The resource will be copied as a binary image from this + location + + XC_RESOURCE_SOURCE_EMBEDDED_CONTENT + Indicates XC_RESOURCEFILE_DESCRIPTOR's pResourceSource + points to an embedded resource. The contents of the buffer + pointed to by pResourceSource are written to the resource file + as a binary image. + + +--*/ +typedef enum tagXC_RESOURCE_SOURCE_TYPE { + XC_RESOURCE_SOURCE_UTF8_PATHNAME = 0, + XC_RESOURCE_SOURCE_EMBEDDED_CONTENT +} XC_RESOURCE_SOURCE_TYPE; + + + + +/*++ + +XC_RESOURCEFILE_DESCRIPTOR structure + +Holds information about a single resource file. + +Fields: + + Size + sizeof(XC_RESOURCEFILE_DESCRIPTOR) + + Flags + Reserved for future use + + pFileName + Name of the file on the destination. + The name is a relative path to the + target processl's working directory + + ResourceSourceType + See XC_RESOURCE_SOURCE_TYPE above + + NumberOfResourceSourceBytes + Length of the pResourceSource buffer.
+ + pResourceSource + Depending on resourceSourceType, either + points to an embedded resource or to a share + path from where the resource will be copied + +--*/ +typedef struct tagXC_RESOURCEFILE_DESCRIPTOR{ + SIZE_T Size; + DWORD Flags; + PCSTR pFileName; + XC_RESOURCE_SOURCE_TYPE ResourceSourceType; + SIZE_T NumberOfResourceSourceBytes; + PVOID pResourceSource; +} XC_RESOURCEFILE_DESCRIPTOR, *PXC_RESOURCEFILE_DESCRIPTOR; +typedef const XC_RESOURCEFILE_DESCRIPTOR* PCXC_RESOURCEFILE_DESCRIPTOR; + + + +/*++ + +This bit flag indicates that the resources for a process will +be late bound. This is used in the Flags parameter in the +XC_CREATEPROCESS_DESCRIPTOR. See below + +--*/ +#define XCCREATEPROCESSDESCRIPTOR_LATEBOUNDRESOURCES 0x00000002 + + + +/*++ + +Typedefs/various aliases +--*/ +typedef DR_ASYNC_INFO XC_ASYNC_INFO; +typedef XC_ASYNC_INFO* PXC_ASYNC_INFO; +typedef const XC_ASYNC_INFO* PCXC_ASYNC_INFO; + + +/*++ + +A session handle will contain user related information that +can be used to determine what part of information the user +can access. + +--*/ +typedef struct tagXDRESSIONHANDLE +{ + ULONG_PTR Unused; +} *XDRESSIONHANDLE, **PXDRESSIONHANDLE; +const XDRESSIONHANDLE INVALID_XDRESSIONHANDLE = NULL; + + + +/*++ + +XC_OPEN_SESSION_PARAMS structure + +Session related information + +Fields: + + Size + sizeof(XC_OPEN_SESSION_PARAMS) + + Flags + Reserved for future use. Must be 0. + + pCluster + Name of the cluster to connect to. + If this filed is NULL then default cluster + will be used. + + ClientId + The unique ID to use for syncing process + related information with the Process Scheduler. + If a NULL GUID, a default ID will be generated. + + Use this field, when implementing failover. + For e.g if JobManager wants to provide failover, + it can use a well known clientId, between various + redundant JobManagers. When the failover happens, + the new Job Manager can use the same client Id to + sync process states with the Process Scheduler. 
+ +--*/ + +typedef struct tagXC_OPEN_SESSION_PARAMS{ + SIZE_T Size; + DWORD Flags; + PCSTR pCluster; + GUID ClientId; +}XC_OPEN_SESSION_PARAMS, *PXC_OPEN_SESSION_PARAMS; +typedef const XC_OPEN_SESSION_PARAMS* PCXC_OPEN_SESSION_PARAMS; + + + +/*++ + +An XCDATETIME is defined as the number of 100-nanosecond intervals +that have elapsed since 12:00 A.M. January 1, 1601 (UTC). It is +the representation of choice whenever an absolute date/time must +be used. + +XCDATETIME is equivalent to a windows FILETIME value without local +time zone adjustment. + +--*/ +typedef UINT64 XCDATETIME; + +#define XCDATETIME_NEVER _UI64_MAX +#define XCDATETIME_LONGAGO 0 + + + +/*++ + +XCTIMEINTERVAL represents a measurement of elapsed time in 100ns +Intervals. It is a signed entity (elapsed time may be negative). +It is the natural type for the result of subtracting two XCDATETIME +Values. +--*/ +typedef INT64 XCTIMEINTERVAL; + + + +/*++ + +The below #defines help define commonly used time intervals + +--*/ +#define XCTIMEINTERVAL_INFINITE _I64_MAX +#define XCTIMEINTERVAL_NEGATIVEINFINITE _I64_MIN +#define XCTIMEINTERVAL_ZERO 0 +#define XCTIMEINTERVAL_QUANTUM 1 +#define XCTIMEINTERVAL_100NS ( (XCTIMEINTERVAL) (XCTIMEINTERVAL_QUANTUM) ) +#define XCTIMEINTERVAL_MICROSECOND ( (XCTIMEINTERVAL) ( XCTIMEINTERVAL_100NS * 10 ) ) +#define XCTIMEINTERVAL_MILLISECOND ( (XCTIMEINTERVAL) ( XCTIMEINTERVAL_MICROSECOND * 1000 ) ) +#define XCTIMEINTERVAL_SECOND ( (XCTIMEINTERVAL) ( XCTIMEINTERVAL_MILLISECOND * 1000 ) ) +#define XCTIMEINTERVAL_MINUTE ( (XCTIMEINTERVAL) ( XCTIMEINTERVAL_SECOND * 60 ) ) +#define XCTIMEINTERVAL_HOUR ( (XCTIMEINTERVAL) ( XCTIMEINTERVAL_MINUTE * 60 ) ) +#define XCTIMEINTERVAL_DAY ( (XCTIMEINTERVAL) ( XCTIMEINTERVAL_HOUR * 24 ) ) +#define XCTIMEINTERVAL_WEEK ( (XCTIMEINTERVAL) ( XCTIMEINTERVAL_DAY * 7 ) ) + + + +/*++ + +Flags for the various options set in the XC_PROCESS_CONSTRAINTS structure + +--*/ +#define XCPROCESSCONSTRAINTOPTION_SETMAXREMAININGELAPSEDEXECUTIONTIME 0x1 
+#define XCPROCESSCONSTRAINTOPTION_SETMAXREMAININGRETAINAFTERTERMINATETIME 0x2 +#define XCPROCESSCONSTRAINTOPTION_SETMAXPERWIN32PROCESSUSERMODETIME 0x4 +#define XCPROCESSCONSTRAINTOPTION_SETMAXREMAININGUSERMODETIME 0x8 +#define XCPROCESSCONSTRAINTOPTION_SETMAXWORKINGSETSIZE 0x10 +#define XCPROCESSCONSTRAINTOPTION_SETMAXNUMWIN32PROCESSES 0x20 +#define XCPROCESSCONSTRAINTOPTION_SETMAXPERWIN32PROCESSMEMORYSIZE 0x40 +#define XCPROCESSCONSTRAINTOPTION_SETMAXMEMORYSIZE 0x80 + + + +/*++ + +Default process priority + +--*/ +#define XCPROCESSPRIORITY_DEFAULT 0x80000000 + + + +/*++ + +XC_PROCESS_CONSTRAINTS structure + +Constraints that will be applied to the process that gets started +on a given node + +Fields: + + Size + sizeof(XC_PROCESS_CONSTRAINTS) + + ProcessConstraintOptions + Bit flag indicating what options have + been set + + MaxRemainingElapsedExecutionTime + Maximum amount of time process can + continue to run without terminating. + + MaxRemainingRetainAfterTerminateTime + Amount of time after process + termination before the process + persistent state is discarded + + MaxPerWin32ProcessUserModeTime + Max amount of user-mode CPU time for + each Windows process associated with + the XCompute process + + MaxRemainingUserModeTime + Max amount of total user-mode CPU + time for XCompute process + + MaxWorkingSetSize + Maximum working set size for + Windows processes + + MaxNumWin32Processes + Maximum number of Windows processes + that can be running + + MaxPerWin32ProcessMemorySize + Maximum amount of memory per win32 + process + + MaxMemorySize + Max total memory for the XCompute + process + +--*/ +typedef struct tagXC_PROCESS_CONSTRAINTS{ + SIZE_T Size; + DWORD ProcessConstraintOptions; + XCTIMEINTERVAL MaxRemainingElapsedExecutionTime; + XCTIMEINTERVAL MaxRemainingRetainAfterTerminateTime; + XCTIMEINTERVAL MaxPerWin32ProcessUserModeTime; + XCTIMEINTERVAL MaxRemainingUserModeTime; + UINT64 MaxWorkingSetSize; + UINT32 MaxNumWin32Processes; + UINT64 
MaxPerWin32ProcessMemorySize; + UINT64 MaxMemorySize; +} XC_PROCESS_CONSTRAINTS, *PXC_PROCESS_CONSTRAINTS; +typedef const XC_PROCESS_CONSTRAINTS* PCXC_PROCESS_CONSTRAINTS; + + + +/*++ + +XC_CREATEPROCESS_DESCRIPTOR structure + +Used in ScheduleProcess API. Has all the information needed to +launch a process on a particular node + +Fields: + + Size + Sizeof(XC_CREATEPROCESS_DESCRIPTOR) + + Flags + Option bit flags + + XC_CREATEPROCESS_DESCRIPTOR_LATEBOUNDRESOURCES + indicates that the resources are late + bound. When the process gets created + on a node, the process is set to + UnInitialized state, and waits for the + parent process to contact the PN to bind + the resources + + pCommandLine + Command line that will launch the process + + pProcessClass + Process class name. User-defined. + + pProcessFriendlyName + The process friendly name.User-defined + + pEnvironmentStrings + The environment strings that will be + set before launching the process on a node + The environment strings are represented as + a series of null-terminated UTF8 strings + with an extra NULL at the end + + pAppProcessConstraints + See PCXC_PROCESS_CONSTRAINTS above + + NumberOfResourceFileDescriptors + The number of resource file descriptors in + the pResourceFileDescriptors array + + pResourceFileDescriptors + Pointer to array of + PCXC_RESOURCEFILE_DESCRIPTOR's. 
These + resources will be copied to the process + working directory before launching the + process using the commandline + +--*/ +typedef struct tagXC_CREATEPROCESS_DESCRIPTOR{ + SIZE_T Size; + DWORD Flags; + PCSTR pCommandLine; + PCSTR pProcessClass; + PCSTR pProcessFriendlyName; + PCSTR pEnvironmentStrings; + PCXC_PROCESS_CONSTRAINTS pAppProcessConstraints; + SIZE_T NumberOfResourceFileDescriptors; + PCXC_RESOURCEFILE_DESCRIPTOR pResourceFileDescriptors; +} XC_CREATEPROCESS_DESCRIPTOR, *PXC_CREATEPROCESS_DESCRIPTOR; +typedef const XC_CREATEPROCESS_DESCRIPTOR* PCXC_CREATEPROCESS_DESCRIPTOR; + + + +/*++ + +Defines the Network Locality Params used in the +XcGetNetworkLocalityPathOfProcessNode() API. +These params are passed to the API to identify the +Affinity level. The resulting NetworkLocalityParam returned +from the API can then be passed via the XC_AFFINITY struct +(defined below) to the Process Scheduler, to help the +Process Scheduler in making decisions about the choice of +Process Node to pick to run a given XCompute process. + +NOTE: The special Network Locality Param ".." can be combined + with other locality params to represent one level up from + the current level. + E.g. XCLOCALITYPARAM_POD/.. (NOTE the forward slash) +--*/ +#define XCLOCALITYPARAM_ONELEVELUP ".." +#define XCLOCALITYPARAM_POD "POD" +#define XCLOCALITYPARAM_L2SWITCH "L2" +#define XCLOCALITYPARAM_L3SWITCH "L3" +#define XCLOCALITYPARAM_VLAN "VLAN" +#define XCLOCALITYPARAM_CLUSTER "CLUSTER" +#define XCLOCALITYPARAM_DATACENTER "DC" + + + +/*++ + +Defines the bit flag used in the XC_AFFINITY structure. +If the XCAFFINITY_HARD bit of the Flags in the XC_AFFINITY structure is +set, then the affinity is considered to have hard affinity +to the NetworkNodePath/s. See below for details. + +--*/ +#define XCAFFINITY_HARD 0x01 + + + +/*++ + +XC_AFFINITY structure + +Each Affinity is comprised of list of network locality paths, +an associated weight and a flag for hard affinity.
+A network locality can refer to a data center, a top/middle +level switch, POD, or a specific host machine. + +Fields: + + Size + Sizeof(XC_AFFINITY) + + Flags + Bit flags. XC_AFFINITY_HARD indicates that + affinity is hard affinity. + + Weight + The Process Scheduler will give preference to + the Affinity (list of Nodes) that have higher + weight, while picking up the Node on which to + run the XCompute Process. + The intended units for Weight are + "estimated bytes of I/O" + + NumberOfNetworkLocalityPaths + Number of Nodes in + pNetworkLocalityPaths array. + + pNetworkLocalityPaths + Pointer to the network locality paths array. + A network locality path is represented as a + string and is an opaque format. The caller + gets the locality path information by calling the + XcGetNetworkLocalityPath API + +--*/ +typedef struct tagXC_AFFINITY{ + SIZE_T Size; + DWORD Flags; + UINT64 Weight; + SIZE_T NumberOfNetworkLocalityPaths; + PCSTR* pNetworkLocalityPaths; +} XC_AFFINITY, *PXC_AFFINITY; +typedef const XC_AFFINITY* PCXC_AFFINITY; + + + +/*++ + +XC_LOCALITY_DESCRIPTOR structure + +Locality is represented as a collection of Affinities. + +Fields: + + Size + sizeof(XC_LOCALITY_DESCRIPTOR) + + Flags + Reserved. Must be 0. + + NumberOfAffinities + Number of XC_AFFINITY'es + + pAffinities + Pointer to Array of Affinities + +--*/ +typedef struct tagXC_LOCALITY_DESCRIPTOR{ + SIZE_T Size; + DWORD Flags; + SIZE_T NumberOfAffinities; + PXC_AFFINITY pAffinities; +} XC_LOCALITY_DESCRIPTOR, *PXC_LOCALITY_DESCRIPTOR; +typedef const XC_LOCALITY_DESCRIPTOR* PCXC_LOCALITY_DESCRIPTOR; + + + +/*++ + +XC_SCHEDULEPROCESS_DESCRIPTOR + +The descriptor that has all the information about the process +to be scheduled + +Fields: + + Size + sizeof(XC_SCHEDULEPROCESS_DESCRIPTOR) + + Flags Reserved for later use. Must be 0. + + ProcessPriority + The priority of the process. The priority + is process priority, within all the + processes for a given job. This is + different from job priority. 
+ + pLocalityDescriptor + See XC_LOCALITY_DESCRIPTOR above + + pCreateProcessDescriptor + See XC_CREATEPROCESS_DESCRIPTOR above + +--*/ +typedef struct tagXC_SCHEDULEPROCESS_DESCRIPTOR{ + SIZE_T Size; + DWORD Flags; + UINT32 ProcessPriority; + PCXC_LOCALITY_DESCRIPTOR pLocalityDescriptor; + PCXC_CREATEPROCESS_DESCRIPTOR pCreateProcessDescriptor; +} XC_SCHEDULEPROCESS_DESCRIPTOR, *PXC_SCHEDULEPROCESS_DESCRIPTOR; + +typedef +const XC_SCHEDULEPROCESS_DESCRIPTOR* PCXC_SCHEDULEPROCESS_DESCRIPTOR; + + + +/*++ + +XC_PROCESSPROPERTY_INFO + +The structure is embedded in the XC_PROCESS_INFO struct explained +below. It has all the information related to a particular property + +Fields: + + Size + sizeof(XC_PROCESSPROPERTY_INFO) + + pPropertyLabel + The property label + + PropertyVersion + The property version + + pPropertyString + The property string value + + PropertyBlockSize + Memory block size of property + + pPropertyBlock + Pointer to memory block related to property + +--*/ +typedef struct tagXC_PROCESSPROPERTY_INFO{ + SIZE_T Size; + PCSTR pPropertyLabel; + UINT64 PropertyVersion; + PCSTR pPropertyString; + SIZE_T PropertyBlockSize; + PVOID pPropertyBlock; +} XC_PROCESSPROPERTY_INFO, *PXC_PROCESSPROPERTY_INFO; +typedef const XC_PROCESSPROPERTY_INFO* PCXC_PROCESSPROPERTY_INFO; + + + +/*++ + +XC_PROCESS_STATISTIDR + +Contains all the statistics related to a given process/job + +Fields: + + Size + sizeof(XC_PROCESS_STATISTIDR) + + Flags + Reserved for later use + + ProcessUserTime + Total user time the whole process + consumed, in 100 nanosec units + + ProcessKernelTime + Total kernel time the whole process + consumed, in 100 nanosec units + + PageFaults + Total #page faults for the whole process + + TotalProcessesCreated + Total #win32 processes the process ever + created + + PeakVMUsage + The peak Virtual memory usage + + PeakMemUsage + The peak working set memory usage + + MemUsageSeconds + Working set memory usage * time used + + TotalIo + Total IO transferred + +--*/ +typedef
struct tagXC_PROCESS_STATISTIDR{ + SIZE_T Size; + DWORD Flags; + XCTIMEINTERVAL ProcessUserTime; + XCTIMEINTERVAL ProcessKernelTime; + INT32 PageFaults; + INT32 TotalProcessesCreated; + UINT64 PeakVMUsage; + UINT64 PeakMemUsage; + UINT64 MemUsageSeconds; + UINT64 TotalIo; +} XC_PROCESS_STATISTIDR, *PXC_PROCESS_STATISTIDR; +typedef const XC_PROCESS_STATISTIDR* PCXC_PROCESS_STATISTIDR; + + + +/*++ + +Bit flag definitions for XC_PROCESS_INFO structure that is used +in the GetProcessProperty API + +--*/ +#define XCPROCESSINFOOPTION_STATICINFO (0x01) +#define XCPROCESSINFOOPTION_TIMINGINFO (0x02) +#define XCPROCESSINFOOPTION_EFFECTIVECONSTRAINTS (0x04) +#define XCPROCESSINFOOPTION_EXTENDEDPROCESSDESCRIPTOR (0x08) +#define XCPROCESSINFOOPTION_EXTENDEDJOBDESCRIPTOR (0x10) +#define XCPROCESSINFOOPTION_PROCESSSTAT (0x20) +#define XCPROCESSINFOOPTION_APPCONSTRAINTS (0x40) +#define XCPROCESSINFOOPTION_SYSTEMCONSTRAINTS (0x80) + +#define XCPROCESSINFOOPTION_All \ + (XCPROCESSINFOOPTION_STATICINFO | \ + XCPROCESSINFOOPTION_TIMINGINFO | \ + XCPROCESSINFOOPTION_EFFECTIVECONSTRAINTS | \ + XCPROCESSINFOOPTION_EXTENDEDPROCESSDESCRIPTOR | \ + XCPROCESSINFOOPTION_EXTENDEDJOBDESCRIPTOR | \ + XCPROCESSINFOOPTION_PROCESSSTAT | \ + XCPROCESSINFOOPTION_APPCONSTRAINTS | \ + XCPROCESSINFOOPTION_SYSTEMCONSTRAINTS) + + + +/*++ + +XC_SETANDGETPROCESSINFO_REQINPUT + +The structure is used to make the XcPnSetAndGetProcessInfo call. +It contains the various inputs to the API clubbed together. + +Fields: + + Size + sizeof(XC_SETANDGETPROCESSINFO_REQINPUT) + + pAppProcessConstraints + The process constraints to be set for the + process. The user will need to preserve this + structure till the async call is completed + + NumberOfProcessPropertiesToSet + The number of properties to set in the + ppPropertiesToSet array + + ppPropertiesToSet Pointer to property info array. These are the + properties that will be set in this call.
+ + pBlockOnPropertyLabel + Name of the property on which to block. The + request finishes, when either the process + terminates, or the property is changed or + after timeout amount of time. + + BlockOnPropertyversionLastSeen + The latest known version number of property + on which to block + + MaxBlockTime Time to wait for property to change or process + to terminate before returning with unchanged + property version. If 0, API returns + immediately with current values. + + pPropertyFetchTemplate + The property fetch template. It supports the + * wild card. A set of properties, whose + labels match the propertyFetchTemplate are + returned. If NULL, no properties are returned + + ProcessInfoFetchOptions + bit flag indicating the different + processInfo fields to fetch. + +--*/ +typedef struct tagXC_SETANDGETPROCESSINFO_REQINPUT{ + DWORD Size; + PXC_PROCESS_CONSTRAINTS pAppProcessConstraints; + SIZE_T NumberOfProcessPropertiesToSet; + PXC_PROCESSPROPERTY_INFO* ppPropertiesToSet; + PCSTR pBlockOnPropertyLabel; + UINT64 BlockOnPropertyversionLastSeen; + XCTIMEINTERVAL MaxBlockTime; + PCSTR pPropertyFetchTemplate; + DWORD ProcessInfoFetchOptions; +} XC_SETANDGETPROCESSINFO_REQINPUT, + *PXC_SETANDGETPROCESSINFO_REQINPUT; + +typedef +const XC_SETANDGETPROCESSINFO_REQINPUT* PCXC_SETANDGETPROCESSINFO_REQINPUT; + + + +/*++ + +XC_PROCESS_INFO + +The structure gets returned as a result of the XcPnGetProcessProperty +call. Use the XcFreeMemory API to release memory for this structure + +Fields: + + Size + sizeof(XC_PROCESS_INFO) + + Flags + Bit flag that indicates which fields in the + data structure have valid information. The + bit flags are defined above + + ProcessState + The current state of the process from the PN's point of view. + This field is always sent. + + ProcessStatus + The process status. Indicates whether the + process is running or exited, and the reason. + This field is always sent. + + ExitCode + The process exit code.
+ This field is always sent. + + Win32Pid The Windows processId of the process + This field is always sent. + + NumberofProcessProperties + Number of XC_PROCESSPROPERTY_INFO's returned + + ppProperties + Array of XC_PROCESSPROPERTY_INFO structs + + CurrentPnTime + Always sent. This is the time on PN + + CreatedTime + Time when Win32 CreateProcess was + initiated (XCDATETIME_NEVER if not yet created) + Bit flag:XCPROCESSINFOOPTION_TIMINGINFO + + BeginExecutionTime + Time when Win32 process was first resumed + (XCDATETIME_NEVER if not yet resumed) + Bit flag:XCPROCESSINFOOPTION_TIMINGINFO + + TerminatedTime + Time when Win32 Process terminated + (XCDATETIME_NEVER if not yet terminated) + Bit flag:XCPROCESSINFOOPTION_TIMINGINFO + + LastPropertyUpdateTime + Most recent time when any property was set + + pEffectiveProcessConstraints + Effective constraints for the process (combined constraints + from application and system) + pointer to XC_PROCESS_CONSTRAINTS struct + Bit flag:XCPROCESSINFOOPTION_EFFECTIVECONSTRAINTS + + pAppProcessConstraints + Application constraints for the process + pointer to XC_PROCESS_CONSTRAINTS struct + Bit flag:XCPROCESSINFOOPTION_APPCONSTRAINTS + + pSystemProcessConstraints + System constraints for the process + pointer to XC_PROCESS_CONSTRAINTS struct + Bit flag:XCPROCESSINFOOPTION_SYSTEMCONSTRAINTS + + pCommandLine + The command line for the process + + pProcessStatistics + Pointer to the XC_PROCESS_STATISTIDR struct + Bit flag:XCPROCESSINFOOPTION_PROCESSSTAT + +--*/ +typedef struct tagXC_PROCESS_INFO{ + SIZE_T Size; + DWORD Flags; + XCPROCESSSTATE ProcessState; + XCERROR ProcessStatus; + XCEXITCODE ExitCode; + UINT32 Win32Pid; + UINT32 NumberofProcessProperties; + PXC_PROCESSPROPERTY_INFO *ppProperties; + XCDATETIME CurrentPnTime; + XCDATETIME CreatedTime; + XCDATETIME BeginExecutionTime; + XCDATETIME TerminatedTime; + XCDATETIME LastPropertyUpdateTime; + PXC_PROCESS_CONSTRAINTS pEffectiveProcessConstraints; + PXC_PROCESS_CONSTRAINTS
pAppProcessConstraints; + PXC_PROCESS_CONSTRAINTS pSystemProcessConstraints; + PCSTR pCommandLine; + PXC_PROCESS_STATISTIDR pProcessStatistics; +} XC_PROCESS_INFO, *PXC_PROCESS_INFO; +typedef const XC_PROCESS_INFO* PCXC_PROCESS_INFO; + + + +/*++ + +XC_SETANDGETPROCESSINFO_REQRESULTS + +The structure gets returned as a result of a call to the +XcPnSetAndGetProcessInfo API. +It contains the results that match the ProcessInfoFetchOptions +and the PropertyFetchTemplate passed to the API via the +XC_SETANDGETPROCESSINFO_REQINPUT struct + +Fields: + + Size + sizeof(XC_SETANDGETPROCESSINFO_REQRESULTS) + + pProcessInfo + The process info that has information about + all the properties for which information + was asked to be retrieved (using the + PropertyFetchTemplate). It also has all + the information that was asked to be + retrieved using the ProcessInfoFetchOptions. + + NumberOfPropertyVersions + The number of property versions in the + pPropertyVersions array + + pPropertyVersions + Pointer to array of property versions. + Note: The indexes of version numbers in the + pPropertyVersions array correspond 1:1 with the + ppPropertiesToSet array in the + XC_SETANDGETPROCESSINFO_REQINPUT that gets + passed to the XcPnSetAndGetProcessInfo() API. +--*/ +typedef struct tagXC_SETANDGETPROCESSINFO_REQRESULTS{ + DWORD Size; + PXC_PROCESS_INFO pProcessInfo; + UINT32 NumberOfPropertyVersions; + UINT64* pPropertyVersions; +} XC_SETANDGETPROCESSINFO_REQRESULTS, + *PXC_SETANDGETPROCESSINFO_REQRESULTS; + +typedef +const XC_SETANDGETPROCESSINFO_REQRESULTS* + PCXC_SETANDGETPROCESSINFO_REQRESULTS; + + + +/*++ + +XCPROCESSFILEHANDLE +A handle to represent an open XCompute Process File.
+This is used in the XCompute Process File API, which gives +the ability to read remote files written by a XComputeProcess +into its working directory + +Fields: +--*/ +typedef struct tagXCPROCESSFILEHANDLE +{ + ULONG_PTR Unused; +} *XCPROCESSFILEHANDLE, **PXCPROCESSFILEHANDLE; + +const XCPROCESSFILEHANDLE INVALID_XCPROCESSFILEHANDLE = NULL; + + + +/* File offset value for XCompute files */ +typedef UINT64 XCPROCESSFILEPOSITION, *PXCPROCESSFILEPOSITION; + + + +/* + +Various XcGetProcessFileSize options + + XCREFRESH_AGGRESSIVE (default) + - visit server to find out latest known length + + XCREFRESH_PASSIVE + - return length from local cache if available otherwise + visit server to find out latest known length + + XCREFRESH_FROM_CACHE + - return length from local cache + fail if not available + +*/ +#define XCREFRESH_AGGRESSIVE 0x10000000u +#define XCREFRESH_PASSIVE 0x20000000u +#define XCREFRESH_FROM_CACHE 0x30000000u + + + +#pragma warning( pop ) + +#if defined(__cplusplus) +} +#endif diff --git a/DryadVertex/VertexHost/system/classlib/include/basic_types.h b/DryadVertex/VertexHost/system/classlib/include/basic_types.h new file mode 100644 index 0000000..dc1f75e --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/basic_types.h @@ -0,0 +1,321 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#pragma warning (disable: 4619) // #pragma warning : there is no warning number 'NNNN' + +#pragma warning (push) + +// +// Disabled warnings +// + +#pragma warning (disable: 4255) +#pragma warning (disable: 4668) // XXX is not defined as preprocessor macro, replacing with 0 +#pragma warning (disable: 4820) // 'N' bytes padding added after data member 'XXX' +#pragma warning (disable: 4365) // conversion from 'LONG64' to 'DWORD64', signed/unsigned mismat +#pragma warning (disable: 4548) // malloc.h(245) & STL: expression before comma has no effect; expected expression with side-effect +#pragma warning (disable: 4995) // deprecated functions + +#include +#include +#include +#include +#include +// include for DNS hostname support +//#include + +#pragma warning (pop) + +#if defined (ELEMENTCOUNT) +#undef ELEMENTCOUNT +#endif +#define ELEMENTCOUNT(x) (sizeof(x)/sizeof(x[0])) + +typedef __int8 Int8; +typedef unsigned __int8 UInt8; +typedef __int16 Int16; +typedef unsigned __int16 UInt16; +typedef __int32 Int32; +typedef unsigned __int32 UInt32; +typedef __int64 Int64; +typedef unsigned __int64 UInt64; +typedef size_t Size_t; + +#ifdef _M_IX86 +typedef int xint; +typedef unsigned int uxint; +#else +typedef __int64 xint; +typedef unsigned __int64 uxint; +#endif + +#define MAX_UINT8 ((UInt8)-1) +#define MAX_UINT16 ((UInt16)-1) +#define MAX_UINT32 ((UInt32)-1) +#define MIN_INT32 ((Int32)0x80000000) +#define MAX_INT32 ((Int32)0x7FFFFFFF) // 2147483647 +#define MAX_UINT64 ((UInt64)-1) +#define MAX_INT64 0x7FFFFFFFFFFFFFFFi64 +#define MAX_FLOAT (3.402823466e+38F) +#define MAX_SIZE_T ((size_t)((size_t)0 - (size_t)1)) + +#define PF_I64D "%I64d" +#define PF_I64X "%016I64x" +#define PF_I64O "%022I64o" + +// A structure to represent a fixed-size chunk of +// memory +struct SIZED_STRING +{ + union + { + const UInt8 *pbData; + const char *pcData; + }; + size_t cbData; +}; + +// A helper macro for defining a SIZED_STRING as part of a constant +#define 
INLINE_SIZED_STRING(str) { (const UInt8 *)str, sizeof(str) - 1 } + +// Structure to store a wchar_t version of a user dictionary word in memory +struct WCHAR_SIZED_STRING +{ + wchar_t* pData; + size_t cchData; +}; + +// A helper macro for defining a SIZED_STRING as part of a constant +#define INLINE_WCHAR_SIZED_STRING(wstr) { wstr, sizeof(wstr)/sizeof(wstr[0]) - 1 } + +// Structure to store a wchar_t version of a user dictionary word in memory +struct WCHAR_SIZED_STRING_CONST +{ + const wchar_t* pData; + size_t cchData; +}; + +// A utility class for creating temporary SIZED_STRINGs +class CStackSizedString : public SIZED_STRING +{ +public: + // Null-terminated input + CStackSizedString(const char *szValue) + { + pbData = (const UInt8 *)szValue; + cbData = strlen(szValue); + } + + // Name/size pair + CStackSizedString( + const UInt8 *pbValue, + size_t cbValue) + { + pbData = pbValue; + cbData = cbValue; + } + + // Name/size pair + CStackSizedString( + const char *pcValue, + size_t cbValue) + { + pcData = pcValue; + cbData = cbValue; + } + +private: + // prevent heap allocation + void *operator new(size_t); +}; + +// A utility class for creating temporary WCHAR_SIZED_STRINGs +class CStackSizedWString : public WCHAR_SIZED_STRING_CONST +{ +public: + // Null-terminated input + CStackSizedWString(const wchar_t *wzValue) + { + pData = wzValue; + cchData = wcslen(wzValue); + } + + // Name/size pair + CStackSizedWString( + const wchar_t *wzValue, + size_t cchValue) + { + pData = wzValue; + cchData = cchValue; + } + +private: + // prevent heap allocation + void *operator new(size_t); +}; + +//This is an interface to disallow heap construction. 
+class INoHeapInstance +{ +private: + void* operator new(size_t); +}; + +// DDWORD is used to easily access a 64bit number as both a 64bit +// and as two 32bit numbers +typedef union +{ + struct + { + DWORD low; + DWORD high; + } dw; + DWORD64 ddw; +} DDWORD; + +// critical section wrapper +// +class CRITSEC +{ +private: + CRITICAL_SECTION m_critsec; + +public: + CRITSEC() + { + InitializeCriticalSection(&m_critsec); + } + + ~CRITSEC() + { + DeleteCriticalSection(&m_critsec); + } + + void Enter() + { + EnterCriticalSection(&m_critsec); + } + + BOOL TryEnter() + { + return TryEnterCriticalSection(&m_critsec); + } + + void Leave() + { + LeaveCriticalSection(&m_critsec); + } + + DWORD SetSpinCount(DWORD spinCount = 4000) + { + return SetCriticalSectionSpinCount(&m_critsec, spinCount); + } +}; + +//Smart wrapper around CRITSEC so that we dont need to call enter/leave +class AutoCriticalSection : public INoHeapInstance +{ +private: + CRITSEC* m_pCritSec; +public: + AutoCriticalSection(CRITSEC* pCritSec) : m_pCritSec(pCritSec) + { + this->m_pCritSec->Enter(); + } + ~AutoCriticalSection() + { + this->m_pCritSec->Leave(); + } +}; + +/////////////////////////////////// +#pragma warning(push) +#pragma warning (disable: 4201) // nonstandard extension used : nameless struct/union + +typedef union tagFAInt64 +{ + UInt64 n64; + LONG64 i64; + UInt8 nBytes[8]; + FILETIME ft; + + struct + { + UInt32 nData; + UInt32 nCount; + }; + struct + { + Int32 i32_low; + Int32 i32_hi; + }; + struct + { + UInt32 n32_low; + UInt32 n32_hi; + }; + + tagFAInt64():n64(0){}; + tagFAInt64( UInt64 nVal ):n64(nVal){}; + tagFAInt64( Int64 iVal ):i64(iVal){}; + tagFAInt64( UInt32 nHigh, UInt32 nLow ): n32_low(nLow), n32_hi(nHigh){}; + tagFAInt64( Int32 iHigh, Int32 iLow ): i32_low(iLow), i32_hi(iHigh){}; + tagFAInt64( DWORD nHigh, DWORD nLow ): n32_low(nLow), n32_hi(nHigh){}; + tagFAInt64( const FILETIME& rftSrc ): ft(rftSrc){}; + +} FAInt64; // flexible access 64 bit integer + +typedef union tagFAInt32 
+{ + UInt32 n32; + Int32 i32; + UInt8 nBytes[4]; + struct + { + Int16 i16_low; + Int16 i16_hi; + }; + struct + { + UInt16 n16_low; + UInt16 n16_hi; + }; + tagFAInt32():n32(0){}; + tagFAInt32( UInt32 nVal ): n32(nVal){}; + tagFAInt32( Int32 iVal ): i32(iVal){}; + +} FAInt32; // flexible access 32 bit integer + +#define NUMELEM(p) (sizeof(p)/sizeof((p)[0])) + +#pragma warning(pop) + +// +// disabling a couple of warnings that show up with /Wall and are pretty much useless -- they come mostly from macros +// +#pragma warning (disable: 4514) // unreferenced inline function has been removed +#pragma warning (disable: 4820) // 'N' bytes padding added after data member 'XXX' +#pragma warning (disable: 4265) // class has virtual functions, but destructor is not virtual +#pragma warning (disable: 4668) // XXX is not defined as preprocessor macro, replacing with 0 +#pragma warning (disable: 4711) // function XXX selected for automatic inline expansion +#pragma warning (disable: 4548) // malloc.h(245) & STL: expression before comma has no effect; expected expression with side-effect +#pragma warning (disable: 4127) // conditional expression is constant diff --git a/DryadVertex/VertexHost/system/classlib/include/fingerprint.h b/DryadVertex/VertexHost/system/classlib/include/fingerprint.h new file mode 100644 index 0000000..30717eb --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/fingerprint.h @@ -0,0 +1,54 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include "basic_types.h" +#include "ms_fprint.h" + +class FingerPrint64 +{ +public: + static FingerPrint64* GetInstance(); + + UInt64 GetFingerPrint( + const void* data, + const size_t length + ); + + class FingerPrint64Init + { + public: + FingerPrint64Init(); + ~FingerPrint64Init(); + private: + static UInt64 count; + }; + static void Init(); + static void Dispose(); + FingerPrint64(UInt64 poly); + FingerPrint64(); + ~FingerPrint64(); +private: + ms_fprint_data_t fp; + static FingerPrint64* instance; +}; + +static FingerPrint64::FingerPrint64Init fpInit; diff --git a/DryadVertex/VertexHost/system/classlib/include/ms_fprint.h b/DryadVertex/VertexHost/system/classlib/include/ms_fprint.h new file mode 100644 index 0000000..e7e6529 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/include/ms_fprint.h @@ -0,0 +1,57 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +/* (c) Microsoft Corporation. All rights reserved. 
*/ + +#if !defined(_MS_FPRINT_H_) +#define _MS_FPRINT_H_ + +#if defined(__GNUC__) +typedef unsigned long long ms_fprint_t; +#endif + +#if defined(_MSC_VER) +typedef unsigned __int64 ms_fprint_t; +#pragma warning(disable:4127) +#include +#endif + +typedef struct ms_fprint_data_s *ms_fprint_data_t; +/* an opaque type used to keep the data structures need to compute + fingerprints. */ + +ms_fprint_data_t ms_fprint_new (); +/* Computes the tables needed for fingerprint manipulations. */ + +ms_fprint_data_t ms_fprint_new (ms_fprint_t poly); +/* Computes the tables needed for fingerprint manipulations. */ + +ms_fprint_t ms_fprint_of (ms_fprint_data_t fp, + void *data, size_t len); +/* if fp was generated with polynomial P, and bytes + "data[0, ..., len-1]" contain string B, + return the fingerprint under P of the concatenation of B. + Strings are treated as polynomials. The low-order bit in the first + byte is the highest degree coefficient in the polynomial.*/ + +void ms_fprint_destroy (ms_fprint_data_t fp); +/* discard the data associated with "fp" */ + +#endif diff --git a/DryadVertex/VertexHost/system/classlib/src/DrCriticalSection.cpp b/DryadVertex/VertexHost/system/classlib/src/DrCriticalSection.cpp new file mode 100644 index 0000000..2a6f9bc --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/src/DrCriticalSection.cpp @@ -0,0 +1,196 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "DrCommon.h" + +#pragma unmanaged + +void DrCriticalSectionBase::Init( + __in PCSTR name, + DWORD spinCount, + bool logUsage, + DrTimeInterval logHeldTooLongTimeout + ) +{ + _name = (name != NULL) ? name : "UnknownCritSec"; + _logUsage = logUsage; + _logHeldTooLongTimeoutMs = DrGetTimerMsFromInterval(logHeldTooLongTimeout); + _lastFunctionName = NULL; + _lastFileName = NULL; + _lastLineNumber = 0; + _enterTimeMs = 0; + LogAssert( InitializeCriticalSectionAndSpinCount( this, spinCount ) ); +} + +void DrCriticalSectionBase::SetCriticalSectionLoggingParameters( + bool logUsage, + DrTimeInterval logHeldTooLongTimeout +) +{ + _logHeldTooLongTimeoutMs = DrGetTimerMsFromInterval(logHeldTooLongTimeout); + _logUsage = logUsage; +} + +void DrCriticalSectionBase::SetCriticalSectionLogging( + bool logUsage +) +{ + _logUsage = logUsage; +} + +void DrCriticalSectionBase::SetCriticalSectionLogHeldTooLongTimeout( + DrTimeInterval logHeldTooLongTimeout +) +{ + _logHeldTooLongTimeoutMs = DrGetTimerMsFromInterval(logHeldTooLongTimeout); +} + +void DrCriticalSectionBase::Enter( PCSTR functionName, PCSTR fileName, UINT lineNumber ) +{ + UInt32 tms = GetTickCount(); + + EnterCriticalSection( this ); + + _enterTimeMs = GetTickCount(); + + tms = _enterTimeMs - tms; + + Int32 entryCount = -1; + Int32 contentionCount = -1; + if ( DebugInfo != NULL ) + { + entryCount = (Int32)DebugInfo->EntryCount; + contentionCount = (Int32)DebugInfo->ContentionCount; + } + + _lastFunctionName = functionName; + _lastFileName = fileName; + _lastLineNumber = lineNumber; + + if ( tms > _logHeldTooLongTimeoutMs ) + { + if (_lastFileName != NULL) { + DrLogW( "CritSect WAITED TO ENTER TOO LONG %s at %s %s(%u), entryCount=%d, contentionCount=%d, waited for %ums, addr=%08Ix", + _name, _lastFunctionName, _lastFileName, _lastLineNumber, entryCount, contentionCount, tms, this ); + } 
+ else + { + DrLogW( "CritSect WAITED TO ENTER TOO LONG %s, entryCount=%d, contentionCount=%d, waited for %ums, addr=%08Ix", + _name, entryCount, contentionCount, tms, this ); + } + } + + if ( _logUsage ) + { + if (_lastFileName != NULL) { + DrLogD( "CritSect ENTER %s at %s %s(%u), entryCount=%d, contentionCount=%d, waited for %ums, addr=%08Ix", + _name, _lastFunctionName, _lastFileName, _lastLineNumber, entryCount, contentionCount, tms, this ); + } + else + { + DrLogD( "CritSect ENTER %s at %s, entryCount=%d, contentionCount=%d, waited for %ums, addr=%08Ix", + _name, _lastFunctionName, entryCount, contentionCount, tms, this ); + } + } +} + +void DrCriticalSectionBase::Leave( PCSTR functionName, PCSTR fileName, UINT lineNumber ) +{ + DebugLogAssert( Aquired() ); + + UInt32 tms = 0; + if ( _enterTimeMs != 0 ) // technically wrong, but it will fail to complain on 1 out of 4 billion slow locks + { + tms = GetTickCount() - _enterTimeMs; + _enterTimeMs = 0; + } + + Int32 entryCount = -1; + Int32 contentionCount = -1; + if ( DebugInfo != NULL ) + { + entryCount = (Int32)DebugInfo->EntryCount; + contentionCount = (Int32)DebugInfo->ContentionCount; + } + + if ( tms > _logHeldTooLongTimeoutMs ) + { + if (fileName != NULL) + { + if (_lastFileName != NULL) + { + DrLogW( "CritSect LEAVE, HELD TOO LONG %s at %s %s(%u) entered at %s %s(%u), entryCount=%d, contentionCount=%d, time held=%ums, addr=%08Ix", + _name, functionName, fileName, lineNumber, _lastFunctionName, _lastFileName, _lastLineNumber, entryCount, contentionCount, tms, this ); + } + else + { + DrLogW( "CritSect LEAVE, HELD TOO LONG %s at %s %s(%u), entryCount=%d, contentionCount=%d, time held=%ums, addr=%08Ix", + _name, functionName, fileName, lineNumber, entryCount, contentionCount, tms, this ); + } + } + else + { + if (_lastFileName != NULL) + { + DrLogW( "CritSect LEAVE, HELD TOO LONG %s entered at %s %s(%u), entryCount=%d, contentionCount=%d, time held=%ums, addr=%08Ix", + _name, _lastFunctionName, _lastFileName, 
_lastLineNumber, entryCount, contentionCount, tms, this ); + } + else + { + DrLogW( "CritSect LEAVE, HELD TOO LONG %s, entryCount=%d, contentionCount=%d, time held=%ums, addr=%08Ix", + _name, entryCount, contentionCount, tms, this ); + } + } + } + + if ( _logUsage ) + { + if (fileName != NULL) + { + if (_lastFileName != NULL) + { + DrLogD( "CritSect LEAVE %s at %s %s(%u) entered at %s %s(%u), entryCount=%d, contentionCount=%d, time held=%ums, addr=%08Ix", + _name, functionName, fileName, lineNumber, _lastFunctionName, _lastFileName, _lastLineNumber, entryCount, contentionCount, tms, this ); + } + else + { + DrLogD( "CritSect LEAVE %s at %s %s(%u), entryCount=%d, contentionCount=%d, time held=%ums, addr=%08Ix", + _name, functionName, fileName, lineNumber, entryCount, contentionCount, tms, this ); + } + } + else + { + if (_lastFileName != NULL) + { + DrLogD( "CritSect LEAVE %s entered at %s %s(%u), entryCount=%d, contentionCount=%d, time held=%ums, addr=%08Ix", + _name, _lastFunctionName, _lastFileName, _lastLineNumber, entryCount, contentionCount, tms, this ); + } + else + { + DrLogD( "CritSect LEAVE %s, entryCount=%d, contentionCount=%d, time held=%ums, addr=%08Ix", + _name, entryCount, contentionCount, tms, this ); + } + } + } + + LeaveCriticalSection( this ); +} + + diff --git a/DryadVertex/VertexHost/system/classlib/src/DrError.cpp b/DryadVertex/VertexHost/system/classlib/src/DrError.cpp new file mode 100644 index 0000000..5302df8 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/src/DrError.cpp @@ -0,0 +1,480 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "DrCommon.h" +#include +#include + +#pragma unmanaged + +typedef struct { + DrError value; + const char *pszDescription; +} ErrEntry; + + +typedef struct _ErrHashEntry { + DrError value; + const char *pszDescription; + struct _ErrHashEntry *pNext; +} ErrHashEntry; + + +static ErrEntry g_DryadErrorMap[] = { + +#undef DEFINE_DR_ERROR +#undef COMMON_DR_ERRORS_DEFINED + +#define DEFINE_DR_ERROR(name, number, description) {name, description}, + +#include "DrError.h" + +#undef DEFINE_DR_ERROR +}; + + +class DrErrorTable : public DrTempStringPool +{ +public: + static const DWORD k_hashTableSize = 111; + +public: + DrErrorTable() + { + m_fInitialized = false; + for (int i = 0; i < k_hashTableSize; i++) { + m_pBucket[i] = NULL; + } + m_hModuleNetMsg = NULL; + m_hModuleWinHttp = NULL; + } + + ~DrErrorTable() + { + ErrHashEntry *p; + + for (int i = 0; i < k_hashTableSize; i++) { + while ((p = m_pBucket[i]) != NULL) { + m_pBucket[i] = p->pNext; + delete p; + } + + } + + if (m_hModuleNetMsg != NULL) { + FreeLibrary(m_hModuleNetMsg); + } + + if (m_hModuleWinHttp != NULL) { + FreeLibrary(m_hModuleWinHttp); + } + } + + void Lock() + { + m_lock.Enter(); + } + + void Unlock() + { + m_lock.Leave(); + } + + // The returned error string should be freed with free(); + // + // If the error code is unknown, a generic error description is returned. 
+ char *GetSystemErrorText(DWORD dwError) + { + LPSTR MessageBuffer; + DWORD dwBufferLength; + HANDLE hModule = NULL; + + DWORD dwUse = dwError; + DWORD dwNormalized = dwError; + if ((dwNormalized & 0xFFFF0000) == ((FACILITY_WIN32 << 16) | 0x80000000)) { + dwNormalized = dwNormalized & 0xFFFF; + } + + DWORD dwFormatFlags = FORMAT_MESSAGE_ALLOCATE_BUFFER | + FORMAT_MESSAGE_IGNORE_INSERTS | + FORMAT_MESSAGE_FROM_SYSTEM ; + + // + // If dwLastError is in the network range, + // load the message source. + // + + if (dwNormalized >= NERR_BASE && dwNormalized <= MAX_NERR) { + Lock(); + if (m_hModuleNetMsg == NULL) { + m_hModuleNetMsg = LoadLibraryEx( + TEXT("netmsg.dll"), + NULL, + LOAD_LIBRARY_AS_DATAFILE + ); + } + Unlock(); + + if(m_hModuleNetMsg != NULL) { + dwFormatFlags |= FORMAT_MESSAGE_FROM_HMODULE; + hModule = m_hModuleNetMsg; + } + } else if (dwNormalized >= WINHTTP_ERROR_BASE && dwNormalized <= WINHTTP_ERROR_LAST) { + Lock(); + if (m_hModuleWinHttp == NULL) { + m_hModuleWinHttp = LoadLibraryEx( + TEXT("winhttp.dll"), + NULL, + LOAD_LIBRARY_AS_DATAFILE + ); + } + Unlock(); + + if(m_hModuleWinHttp != NULL) { + dwFormatFlags |= FORMAT_MESSAGE_FROM_HMODULE; + hModule = m_hModuleWinHttp; + dwUse = dwNormalized; + } + } + + // + // Call FormatMessage() to allow for message + // text to be acquired from the system + // or from the supplied module handle. + // + // For perf, we assume here that all ANSI error messages are also valid UTF-8. 
If + // this turns out not to be true, we'll have to get the unicode message and convert to UTF-8 + if ((dwBufferLength = FormatMessageA( + dwFormatFlags, + hModule, + dwUse, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), // default language + (LPSTR) &MessageBuffer, + 0, + NULL + )) != 0) { + + char *pszBuffer = (char *)malloc(dwBufferLength + 1); + LogAssert(pszBuffer != NULL); + memcpy(pszBuffer, MessageBuffer, dwBufferLength+1); + LocalFree(MessageBuffer); + + DWORD i; + for (i = 0; i < dwBufferLength; ++i) + if (pszBuffer[i] == '\r' || pszBuffer[i] == '\n') + pszBuffer[i] = ' '; + + return pszBuffer; + } + + + char *pszBuffer = (char *)malloc(64); + LogAssert(pszBuffer != NULL); + _snprintf(pszBuffer, 64, "Error code %u (0x%08x)", dwError, dwError); + return pszBuffer; + } + + // The returned error string should be freed with free(); + // + // If the error code is unknown, a generic error description is returned. +/*JC + WCHAR *GetSystemErrorTextW(DWORD dwError) + { + LPWSTR MessageBuffer; + DWORD dwBufferLength; + + DWORD dwFormatFlags = FORMAT_MESSAGE_ALLOCATE_BUFFER | + FORMAT_MESSAGE_IGNORE_INSERTS | + FORMAT_MESSAGE_FROM_SYSTEM ; + + // + // If dwLastError is in the network range, + // load the message source. + // + + if (dwError >= NERR_BASE && dwError <= MAX_NERR) { + Lock(); + if (m_hModuleNetMsg == NULL) { + m_hModuleNetMsg = LoadLibraryEx( + TEXT("netmsg.dll"), + NULL, + LOAD_LIBRARY_AS_DATAFILE + ); + } + Unlock(); + + if(m_hModuleNetMsg != NULL) { + dwFormatFlags |= FORMAT_MESSAGE_FROM_HMODULE; + } + } + + // + // Call FormatMessage() to allow for message + // text to be acquired from the system + // or from the supplied module handle. 
+ // + + if ((dwBufferLength = FormatMessageW( + dwFormatFlags, + m_hModuleNetMsg, + dwError, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), // default language + (LPWSTR) &MessageBuffer, + 0, + NULL + )) != 0) { + + WCHAR *pszBuffer = (WCHAR *)malloc((dwBufferLength + 1) * sizeof(WCHAR)); + LogAssert(pszBuffer != NULL); + memcpy(pszBuffer, MessageBuffer, (dwBufferLength+1) * sizeof(WCHAR)); + LocalFree(MessageBuffer); + return pszBuffer; + } + + + WCHAR*pszBuffer = (WCHAR *)malloc(64 * sizeof(WCHAR)); + LogAssert(pszBuffer != NULL); + _snwprintf(pszBuffer, 64, L"Error code %u (0x%08x)", dwError, dwError); + return pszBuffer; + } +*/ + + static DWORD IndexOf(DrError err) { + return ((DWORD)err % k_hashTableSize); + } + + // Returns NULL if not found + // Must be called under lock + ErrHashEntry *FindEntry(DrError err) + { + DWORD dw = IndexOf(err); + ErrHashEntry *p = m_pBucket[dw]; + while (p != NULL) { + if (p->value == err) { + return p; + } + p = p->pNext; + } + return NULL; + } + + // returns ERROR_ALREADY_ASSIGNED if the error has already been defined + DrError AddEntry(DrError code, const char *pszDescription) + { + DrError err = DrError_Fail; + Lock(); + + if (FindEntry(code) != NULL) { + err = HRESULT_FROM_WIN32( ERROR_ALREADY_ASSIGNED ); + goto done; + } + DWORD dw = IndexOf(code); + ErrHashEntry *p = new ErrHashEntry; + LogAssert(p != NULL); + p->value = code; + p->pszDescription = dupStr(pszDescription); + p->pNext = m_pBucket[dw]; + m_pBucket[dw] = p; + err = DrError_OK; + + done: + Unlock(); + return err; + } + + // The returned error string should be freed with free(); + // + // If the error code is unknown, a generic error description is returned. 
+ char *GetErrorText(DrError err) + { + char *pszText = NULL; + Lock(); + Initialize(); + ErrHashEntry *p = FindEntry(err); + if (p != NULL) { + const char *pszOrig = p->pszDescription; + if (pszOrig != NULL) { + Size_t len = strlen(pszOrig)+1; + pszText = (char *)malloc(len); + memcpy(pszText, pszOrig, len); + } + } + Unlock(); + if (pszText == NULL) { + pszText = GetSystemErrorText((DWORD) err); + } + LogAssert(pszText != NULL); + return pszText; + } + +/*JC + WCHAR *GetErrorTextW(DrError err) + { + WCHAR *pszText = NULL; + Lock(); + Initialize(); + ErrHashEntry *p = FindEntry(err); + if (p != NULL) { + const char *pszOrig = p->pszDescription; + if (pszOrig != NULL) { + DrWStr128 wstrText(pszOrig); + Size_t len = wstrText.GetLength()+1; + pszText = (WCHAR *)malloc(len * sizeof(WCHAR)); + memcpy(pszText, wstrText.GetString(), len * sizeof(WCHAR)); + } + } + Unlock(); + if (pszText == NULL) { + pszText = GetSystemErrorTextW((DWORD) err); + } + LogAssert(pszText != NULL); + return pszText; + } +*/ + inline void Initialize() + { + if (!m_fInitialized) { + Lock(); + AddStaticCodes(); + m_fInitialized = true; + Unlock(); + } + } + + void AddStaticCodes() + { + int n = sizeof(g_DryadErrorMap) / sizeof(g_DryadErrorMap[0]); + for (int i = 0; i < n; i++) { + DrError err = AddEntry(g_DryadErrorMap[i].value, g_DryadErrorMap[i].pszDescription); + if (err != DrError_OK) { + DrLogE( "DrErrorTable::AddStaticCodes - Could not add error code 0x%08x: [%s] to table, error: 0x%08x", g_DryadErrorMap[i].value, g_DryadErrorMap[i].pszDescription, err); + } + LogAssert(err == DrError_OK); + } +} + +private: + ErrHashEntry *m_pBucket[k_hashTableSize]; + HMODULE m_hModuleNetMsg; + HMODULE m_hModuleWinHttp; + DrCriticalSection m_lock; + bool m_fInitialized; +}; + + +DrErrorTable g_csErrorTable; + +// The returned error string should be freed with free(); +// +// If the error code is unknown, a generic error description is returned. 
+char *DrGetErrorText(DrError err) +{ + return g_csErrorTable.GetErrorText(err); +} + +// The returned error string should be freed with free(); +// +// If the error code is unknown, a generic error description is returned. +/*JC +WCHAR *DrGetErrorTextW(DrError err) +{ + return g_csErrorTable.GetErrorTextW(err); +} +*/ + +void DrInitErrorTable(void) +{ + g_csErrorTable.Initialize(); +} + +extern DrError DrAddErrorDescription(DrError code, const char *pszDescription) +{ + return g_csErrorTable.AddEntry(code, pszDescription); +} + +// The buffer must be at least 64 bytes long to guarantee a result. If the result won't fit in the buffer, a generic +// error message is generated. +const char *DrGetErrorDescription(DrError err, char *pBuffer, int buffLen) +{ + if (buffLen < 64) { + return "Error Description Buffer too short!"; + } + char *pszText = DrGetErrorText(err); + LogAssert(pszText != NULL); + Size_t len = strlen(pszText)+1; + if ((int) len <= buffLen) { // AKadatch: don't use "int". // TODO: fix this hack. + memcpy(pBuffer, pszText, len); + } else { + _snprintf(pBuffer, (size_t) (buffLen-1), "Error code %u (0x%08x)", err, err); + } + + free(pszText); + + return pBuffer; +} + +// The buffer must be at least 64 chars long to guarantee a result. If the result won't fit in the buffer, a generic +// error message is generated. +/* JC +const WCHAR *DrGetErrorDescription(DrError err, WCHAR *pBuffer, int buffLen) +{ + if (buffLen < 64) { + return L"Error Description Buffer too short!"; + } + WCHAR *pszText = DrGetErrorTextW(err); + LogAssert(pszText != NULL); + Size_t len = wcslen(pszText)+1; + if ((int) len <= buffLen) { // AKadatch: don't use "int". // TODO: fix this hack. + memcpy(pBuffer, pszText, len * sizeof(WCHAR)); + } else { + _snwprintf(pBuffer, (size_t) (buffLen-1), L"Error code %u (0x%08x)", err, err); + } + + free(pszText); + + return pBuffer; +} +*/ + +// The buffer must be at least 64 bytes long to guarantee a result. 
// If the result won't fit in the buffer, a generic
// error message is generated.
// NOTE(review): the buffer-size sentences above were copied from DrGetErrorDescription. This
// function takes no caller buffer: it appends the error text for err to strOut and returns strOut.
DrStr& DrAppendErrorDescription(DrStr& strOut, DrError err)
{
    char *pszText = DrGetErrorText(err);   // caller-owned copy, released with free() below
    LogAssert(pszText != NULL);
    strOut.Append(pszText);
    free(pszText);
    return strOut;
}

// The buffer must be at least 64 chars long to guarantee a result. If the result won't fit in the buffer, a generic
// error message is generated.
/* JC
DrWStr& DrAppendErrorDescription(DrWStr& strOut, DrError err)
{
    WCHAR *pszText = DrGetErrorTextW(err);
    LogAssert(pszText != NULL);
    strOut.Append(pszText);
    free(pszText);
    return strOut;
}
*/


diff --git a/DryadVertex/VertexHost/system/classlib/src/DrExecution.cpp b/DryadVertex/VertexHost/system/classlib/src/DrExecution.cpp
new file mode 100644
index 0000000..e2cc1aa
--- /dev/null
+++ b/DryadVertex/VertexHost/system/classlib/src/DrExecution.cpp
@@ -0,0 +1,50 @@
/*
Copyright (c) Microsoft Corporation

All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License.  You may obtain a copy of the License
at http://www.apache.org/licenses/LICENSE-2.0


THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.


See the Apache Version 2.0 License for specific language governing permissions and
limitations under the License.
*/

#include "DrExecution.h"

// Global pointer to the execution environment; NULL until something assigns it.
DrExecutionEnvironment *g_pDryadExecution = NULL;

//JC
// NOTE(review): everything between #if 0 and #endif is compiled out. It refers to
// g_pDrExecution, while the live global above is named g_pDryadExecution; reconcile the two
// names before re-enabling this code.
#if 0
//JC#define SECURITY_WIN32
//JC#define SEC_SUCCESS(Status) ((Status) >= 0)
//JC#include


// Creates the global execution environment and initialises it; on failure the object is
// deleted again and the pointer reset to NULL.
DrError DrInitExecution()
{
    LogAssert(g_pDrExecution == NULL);
    g_pDrExecution = new DrExecutionEnvironment();
    LogAssert(g_pDrExecution != NULL);
    DrError err = g_pDrExecution->Initialize();
    if (err != DrError_OK) {
        delete g_pDrExecution;
        g_pDrExecution = NULL;
    }
    return err;
}

// Currently a no-op that reports success (the pipe-manager call is commented out).
DrError DrExecutionEnvironment::Initialize()
{
//JC    return DrInitPipeManager();
    return DrError_OK;
}
#endif
diff --git a/DryadVertex/VertexHost/system/classlib/src/DrExitCodes.cpp b/DryadVertex/VertexHost/system/classlib/src/DrExitCodes.cpp
new file mode 100644
index 0000000..93f54d0
--- /dev/null
+++ b/DryadVertex/VertexHost/system/classlib/src/DrExitCodes.cpp
@@ -0,0 +1,209 @@
/*
Copyright (c) Microsoft Corporation

All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License.  You may obtain a copy of the License
at http://www.apache.org/licenses/LICENSE-2.0


THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.


See the Apache Version 2.0 License for specific language governing permissions and
limitations under the License.
+ +*/ + +#include "DrCommon.h" + +#pragma unmanaged + +struct ExitCodeEntry { + DrExitCode value; + const char *pszDescription; +}; + + +struct ExitCodeHashEntry { + DrExitCode value; + const char *pszDescription; + struct ExitCodeHashEntry *pNext; +}; + + +static ExitCodeEntry g_DryadExitCodeMap[] = { + +#ifdef DEFINE_DREXITCODE +#undef DEFINE_DREXITCODE +#endif +#ifdef DECLARE_DREXITCODE +#undef DECLARE_DREXITCODE +#endif +#define DEFINE_DREXITCODE(name, value, description) {(DrExitCode)(value), description}, +#define DECLARE_DREXITCODE(valname, description) {(DrExitCode)(valname), description}, + +#include "DrExitCodes.h" + +#undef DEFINE_DREXITCODE +#undef DECLARE_DREXITCODE +}; + + +class DrExitCodeTable : public DrTempStringPool +{ +public: + static const int k_hashTableSize = 111; + +public: + DrExitCodeTable() + { + m_fInitialized = false; + for (int i = 0; i < k_hashTableSize; i++) { + m_pBucket[i] = NULL; + } + } + + ~DrExitCodeTable() + { + ExitCodeHashEntry *p; + + for (int i = 0; i < k_hashTableSize; i++) { + while ((p = m_pBucket[i]) != NULL) { + m_pBucket[i] = p->pNext; + delete p; + } + + } + } + + void Lock() + { + m_lock.Enter(); + } + + void Unlock() + { + m_lock.Leave(); + } + + static int IndexOf(DrExitCode code) { + return (int)((DWORD)code % (DWORD)k_hashTableSize); + } + + // Returns NULL if not found + // Must be called under lock + ExitCodeHashEntry *FindEntry(DrExitCode code) + { + int i = IndexOf(code); + ExitCodeHashEntry *p = m_pBucket[i]; + while (p != NULL) { + if (p->value == code) { + return p; + } + p = p->pNext; + } + return NULL; + } + + // returns HRESULT(ERROR_ALREADY_ASSIGNED) if the error has already been defined + DrError AddEntry(DrExitCode code, const char *pszDescription) + { + DrError err = DrError_Fail; + Lock(); + + if (FindEntry(code) != NULL) { + err = HRESULT_FROM_WIN32( ERROR_ALREADY_ASSIGNED ); + goto done; + } + + int i = IndexOf(code); + ExitCodeHashEntry *p = new ExitCodeHashEntry; + LogAssert(p != NULL); + 
p->value = code; + p->pszDescription = dupStr(pszDescription); + p->pNext = m_pBucket[i]; + m_pBucket[i] = p; + err = DrError_OK; + + done: + Unlock(); + return err; + } + + inline void Initialize() + { + if (!m_fInitialized) { + Lock(); + AddStaticExitCodes(); + m_fInitialized = true; + Unlock(); + } + } + + void AddStaticExitCodes() + { + int n = sizeof(g_DryadExitCodeMap) / sizeof(g_DryadExitCodeMap[0]); + for (int i = 0; i < n; i++) { + DrError err = AddEntry(g_DryadExitCodeMap[i].value, g_DryadExitCodeMap[i].pszDescription); + LogAssert(err == DrError_OK); + } + } + + DrStr& AppendExitCodeDescription(DrStr& strOut, DrExitCode code) + { + Lock(); + Initialize(); + ExitCodeHashEntry *p = FindEntry(code); + if (p != NULL) { + strOut.Append(p->pszDescription); + } + Unlock(); + if (p == NULL) { + if (code < 256) { + strOut.AppendF("%u", code); + } else { + strOut.AppendF("0x%08x (%u)", code, code); + } + } + return strOut; + } + + DrStr& GetExitCodeDescription(DrStr& strOut, DrExitCode code) + { + strOut = ""; + return AppendExitCodeDescription(strOut, code); + } + +private: + ExitCodeHashEntry *m_pBucket[k_hashTableSize]; + DrCriticalSection m_lock; + bool m_fInitialized; +}; + + +DrExitCodeTable g_csExitCodeTable; + +void DrInitExitCodeTable() +{ + g_csExitCodeTable.Initialize(); +} + +DrError DrAddExitCodeDescription(DrExitCode code, const char *pszDescription) +{ + return g_csExitCodeTable.AddEntry(code, pszDescription); +} + +DrStr& DrAppendExitCodeDescription(DrStr& strOut, DrExitCode code) +{ + return g_csExitCodeTable.AppendExitCodeDescription(strOut, code); +} + +DrStr& DrGetExitCodeDescription(DrStr& strOut, DrExitCode code) +{ + return g_csExitCodeTable.GetExitCodeDescription(strOut, code); +} + diff --git a/DryadVertex/VertexHost/system/classlib/src/DrFPrint.cpp b/DryadVertex/VertexHost/system/classlib/src/DrFPrint.cpp new file mode 100644 index 0000000..5c13a96 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/src/DrFPrint.cpp @@ -0,0 +1,283 
@@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +/* (c) Microsoft Corporation. All rights reserved. */ +#include +#include "DrFPrint.h" +#include "DrFPrint_polynomials.h" + +#pragma unmanaged + +static void initbybyte (Dryad_dupelim_fprint_data_t fp, + Dryad_dupelim_fprint_t bybyte[][256], + Dryad_dupelim_fprint_t f) { + unsigned b; + for (b = 0; b < 8; b++) { + unsigned i, i2; + bybyte[b][0] = 0; + for (i = 0x80, i2 = 0x100; i > 0; i >>= 1) { + unsigned j; + bybyte[b][i] = f; + for (j = i2; i + j < 256; j += i2) + bybyte[b][i+j] = f ^ bybyte[b][j]; + i2 = i; + f = fp->poly[f & 1] ^ (f >> 1); + } + } +} + +void Dryad_dupelim_fprint_init (Dryad_dupelim_fprint_data_t fp, + Dryad_dupelim_fprint_t poly, unsigned span, + int degree) { + Dryad_dupelim_fprint_t l; + int i; + fp->poly[0] = 0; + fp->poly[1] = poly; /*This must be initialized early on */ + fp->empty = poly; + fp->span = span; + initbybyte (fp, fp->bybyte, poly); + memset (&fp->zeroes, 0, sizeof (fp->zeroes)); + /* The initialization of powers[] must happen after bybyte[][] + and zeroes are initialized because concat uses all of + bybyte[][], zeroes and the prefix of powers[] internally. 
*/ + if (degree > 8) + fp->powers[0] = ((Dryad_dupelim_fprint_t) 1) << (degree - 9); + else + fp->powers[0] = fp->bybyte[0][((size_t)0x1) << (8-degree)]; + for (i = 1, l = 1; + i != sizeof (fp->powers) / sizeof (fp->powers[0]); + i++, l <<= 1) { + fp->powers[i] = + Dryad_dupelim_fprint_concat (fp, fp->powers[i-1] ^ poly, 0, l); + } + if (span != 0) { + initbybyte (fp, fp->bybyte_out, + Dryad_dupelim_fprint_concat (fp, 0, 0, + (span-1) * 8)); + } +} + +Dryad_dupelim_fprint_data_t Dryad_dupelim_fprint_new (Dryad_dupelim_fprint_t poly, + unsigned span ) { + Dryad_dupelim_fprint_data_t fp = (Dryad_dupelim_fprint_data_t) malloc (sizeof (*fp)); + Dryad_dupelim_fprint_init(fp, poly, span, 64); + return fp; +} + +Dryad_dupelim_fprint_data_t Dryad_dupelim_fprint_new2 (Dryad_dupelim_fprint_t poly, + unsigned span, int degree) { + Dryad_dupelim_fprint_data_t fp; + if ((degree > 64) || (degree < 1)) + return 0; /* bad choice for degree */ + fp = (Dryad_dupelim_fprint_data_t) malloc (sizeof (*fp)); + Dryad_dupelim_fprint_init(fp, poly, span, degree); + return fp; +} + +Dryad_dupelim_fprint_t Dryad_dupelim_fprint_empty (Dryad_dupelim_fprint_data_tc fp) { + return (fp->empty); +} + +Dryad_dupelim_fprint_t +Dryad_dupelim_fprint_slideword (Dryad_dupelim_fprint_data_tc fp, + Dryad_dupelim_fprint_t f, + Dryad_dupelim_fprint_uint64_t a, + Dryad_dupelim_fprint_uint64_t b ) { + a ^= fp->poly[1] ^ (((Dryad_dupelim_fprint_t) 1) << 63); + /* a now also gets rid of the old leading 1, and adds a new + one */ + f ^=fp->bybyte_out[7][a & 0xff] ^ + fp->bybyte_out[6][(a >> 8) & 0xff] ^ + fp->bybyte_out[5][(a >> 16) & 0xff] ^ + fp->bybyte_out[4][(a >> 24) & 0xff] ^ + fp->bybyte_out[3][(a >> 32) & 0xff] ^ + fp->bybyte_out[2][(a >> 40) & 0xff] ^ + fp->bybyte_out[1][(a >> 48) & 0xff] ^ + fp->bybyte_out[0][a >> 56]; + f ^= b; + f = fp->bybyte[7][f & 0xff] ^ + fp->bybyte[6][(f >> 8) & 0xff] ^ + fp->bybyte[5][(f >> 16) & 0xff] ^ + fp->bybyte[4][(f >> 24) & 0xff] ^ + fp->bybyte[3][(f >> 32) & 0xff] ^ + 
fp->bybyte[2][(f >> 40) & 0xff] ^ + fp->bybyte[1][(f >> 48) & 0xff] ^ + fp->bybyte[0][f >> 56]; + return (f); +} + +Dryad_dupelim_fprint_t +Dryad_dupelim_fprint_extend_word (Dryad_dupelim_fprint_data_tc fp, + Dryad_dupelim_fprint_t init, + const Dryad_dupelim_fprint_uint64_t *data, + unsigned len ) { + unsigned i; + for (i = 0; i != len; i++) { + init ^= data[i]; + init = fp->bybyte[7][init & 0xff] ^ + fp->bybyte[6][(init >> 8) & 0xff] ^ + fp->bybyte[5][(init >> 16) & 0xff] ^ + fp->bybyte[4][(init >> 24) & 0xff] ^ + fp->bybyte[3][(init >> 32) & 0xff] ^ + fp->bybyte[2][(init >> 40) & 0xff] ^ + fp->bybyte[1][(init >> 48) & 0xff] ^ + fp->bybyte[0][init >> 56]; + } + return (init); +} + +Dryad_dupelim_fprint_t +Dryad_dupelim_fprint_extend (Dryad_dupelim_fprint_data_tc fp, + Dryad_dupelim_fprint_t init, const unsigned char *data, + unsigned len ) { + unsigned char *p = (unsigned char*) data; + unsigned char *e = p+len; + while (p != e && (((Dryad_dupelim_fprint_uint64_t) p) & 7L) != 0) { + init = (init >> 8) ^ fp->bybyte[0][(init & 0xff) ^ *p++]; + } + while (p+8 <= e) { + init ^= *(Dryad_dupelim_fprint_t *)p; + init = fp->bybyte[7][init & 0xff] ^ + fp->bybyte[6][(init >> 8) & 0xff] ^ + fp->bybyte[5][(init >> 16) & 0xff] ^ + fp->bybyte[4][(init >> 24) & 0xff] ^ + fp->bybyte[3][(init >> 32) & 0xff] ^ + fp->bybyte[2][(init >> 40) & 0xff] ^ + fp->bybyte[1][(init >> 48) & 0xff] ^ + fp->bybyte[0][init >> 56]; + p += 8; + } + + while (p != e) { + init = (init >> 8) ^ fp->bybyte[0][(init & 0xff) ^ *p++]; + } + return (init); +} + + +Dryad_dupelim_fprint_t +Dryad_dupelim_fprint_concat (Dryad_dupelim_fprint_data_tc fp, + Dryad_dupelim_fprint_t a, + Dryad_dupelim_fprint_t b, + Dryad_dupelim_fprint_t blen) { + int i; + Dryad_dupelim_fprint_t x = blen; + unsigned low = (unsigned)x & ((1 << fp->LOGZEROBLOCK)-1); + a ^= fp->poly[1]; + if (low != 0) { + a = Dryad_dupelim_fprint_extend (fp, a, fp->zeroes.zeroes, low); + } + x >>= fp->LOGZEROBLOCK; + i = fp->LOGZEROBLOCK; + while (x != 
0) { + if (x & 1) { + Dryad_dupelim_fprint_t m = 0; + Dryad_dupelim_fprint_t bit; + Dryad_dupelim_fprint_t e = fp->powers[i]; + for (bit = ((Dryad_dupelim_fprint_t) 1) << 63; bit != 0; bit >>= 1) { + if (e & bit) { + m ^= a; + } + a = (a >> 1) ^ fp->poly[a & 1]; + } + a = m; + } + x >>= 1; + i++; + } + return (a ^ b); +} + + +void Dryad_dupelim_fprint_toascii (Dryad_dupelim_fprint_t f, char *buf) { + int i; + for (i = 60; i != -4; i -= 4) { + *buf++ = "0123456789abcdef"[(f >> i) & 0xf]; + } +} + + +void Dryad_dupelim_fprint_close (Dryad_dupelim_fprint_data_t fp) { + free (fp); +} + + +// rabin hash functions +bool Dryad_dupelim_rabinhash_init (Dryad_dupelim_fprint_data_s* pHashData, + HashPolyLength hashLen, + UInt32 seed) +{ + UInt64 polySize = 0; + int degree = 0; + const Dryad_dupelim_fprint_t* pPolys = NULL; + switch(hashLen) + { + + case Poly8bit: + polySize = cbPolys8; + pPolys = polys8; + degree = 8; + break; + case Poly16bit: + polySize = cbPolys16; + pPolys = polys16; + degree = 16; + break; + case Poly32bit: + polySize = cbPolys32; + pPolys = polys32; + degree = 32; + break; + case Poly64bit: + polySize = cbPolys64; + pPolys = polys64; + degree = 64; + break; + default: + return false; + } + + if (pHashData == NULL) + { + return false; + } + + if (polySize<=seed) + { + return false; + } + + Dryad_dupelim_fprint_init(pHashData,pPolys[seed], 0, degree); + return true; +} + +Dryad_dupelim_fprint_t Dryad_dupelim_rabinhash_process(Dryad_dupelim_fprint_data_s* pHashFunction, + const unsigned char *data, unsigned len) +{ + return Dryad_dupelim_fprint_extend( pHashFunction, Dryad_dupelim_fprint_empty(pHashFunction), data, len); +} + +Dryad_dupelim_fprint_t Dryad_dupelim_rabinhash_add(Dryad_dupelim_fprint_data_s* pHashFunction, Dryad_dupelim_fprint_t initialHash, + const unsigned char *data, unsigned len) +{ + return Dryad_dupelim_fprint_extend( pHashFunction, initialHash, data, len); +} + + diff --git a/DryadVertex/VertexHost/system/classlib/src/DrFunctions.cpp 
b/DryadVertex/VertexHost/system/classlib/src/DrFunctions.cpp new file mode 100644 index 0000000..b8a0e0e --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/src/DrFunctions.cpp @@ -0,0 +1,1680 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include +#include +#include + +#pragma unmanaged + +#pragma warning (disable: 4995) // '_snprintf': name was marked as #pragma deprecated + +DrError DrStringToSignedOrUnsignedInt64(const char *psz, UInt64 *pResult, bool fSigned) +{ + UInt64 v = 0; + UInt64 vnew; + bool neg = false; + bool gotDig = false; + int base = 10; + + while (ISSPACE(*psz)) { + psz++; + } + + if (*psz == '+') { + psz++; + } else if (fSigned && *psz == '-') { + neg = true; + psz++; + } + + if (*psz == '0' && (*(psz+1) == 'x' ||*(psz+1) == 'X') ) { + psz += 2; + base = 16; + } + + if (base == 16 && neg == false) { + // we allow hex constants to set the sign bit. 
+ fSigned = false; + } + + while (*psz != '\0' && !ISSPACE(*psz)) { + int dig = -1; + if (*psz >= '0' && *psz <= '9') { + dig = *psz - '0'; + } else if (*psz >= 'a' && *psz <= 'f') { + dig = *psz - 'a' + 10; + } else if (*psz >= 'A' && *psz <= 'F') { + dig = *psz - 'A' + 10; + } + if (dig < 0 || dig >= base) { + return DrError_InvalidParameter; + } + vnew = v * base + dig; + if ((fSigned && (Int64)vnew < 0) || ((vnew - dig)/base != v)) { + // overflow + return DrError_InvalidParameter; + } + v = vnew; + gotDig = true; + psz++; + } + + if (!gotDig) { + return DrError_InvalidParameter; + } + + while (ISSPACE(*psz)) { + psz++; + } + + if (*psz != '\0') { + return DrError_InvalidParameter; + } + + if (neg) { + v= (UInt64)(-(Int64)v); + } + + *pResult = v; + return DrError_OK; +} + +DrError DrStringToFloat(const char *psz, float *pResult) +{ + double v; + DrError err = DrStringToDouble(psz, &v); + if (err == DrError_OK) { + *pResult = (float)v; + } + return err; +} + +DrError DrStringToUInt64(const char *psz, UInt64 *pResult) +{ + return DrStringToSignedOrUnsignedInt64(psz, pResult, false); +} + +DrError DrStringToInt64(const char *psz, Int64 *pResult) +{ + return DrStringToSignedOrUnsignedInt64(psz, (UInt64 *)(void *)pResult, true); +} + +DrError DrStringToUInt16(const char *psz, UInt16 *pResult) +{ + UInt64 v; + DrError err = DrStringToSignedOrUnsignedInt64(psz, &v, false); + if (err != DrError_OK) { + return err; + } + UInt16 v16 = (UInt16)v; + if (v != (UInt64)v16) { + return DrError_InvalidParameter; + } + *pResult = v16; + return DrError_OK; +} + +DrError DrStringToUInt32(const char *psz, UInt32 *pResult) +{ + UInt64 v; + DrError err = DrStringToSignedOrUnsignedInt64(psz, &v, false); + if (err != DrError_OK) { + return err; + } + UInt32 v32 = (UInt32)v; + if (v != (UInt64)v32) { + return DrError_InvalidParameter; + } + *pResult = v32; + return DrError_OK; +} + +DrError DrStringToInt32(const char *psz, Int32 *pResult) +{ + Int64 v; + DrError err = 
DrStringToSignedOrUnsignedInt64(psz, (UInt64 *)(void *)&v, true); + if (err != DrError_OK) { + return err; + } + Int32 v32 = (Int32)v; + if (v != (Int64)v32) { + return DrError_InvalidParameter; + } + *pResult = v32; + return DrError_OK; +} + +DrError DrStringToUInt(const char *psz, unsigned int *pResult) +{ + UInt64 v; + DrError err = DrStringToSignedOrUnsignedInt64(psz, &v, false); + if (err != DrError_OK) { + return err; + } + unsigned int v32 = (unsigned int)v; + if (v != (UInt64)v32) { + return DrError_InvalidParameter; + } + *pResult = v32; + return DrError_OK; +} + +DrError DrStringToInt(const char *psz, int *pResult) +{ + Int64 v; + DrError err = DrStringToSignedOrUnsignedInt64(psz, (UInt64 *)(void *)&v, true); + if (err != DrError_OK) { + return err; + } + int v32 = (int)v; + if (v != (Int64)v32) { + return DrError_InvalidParameter; + } + *pResult = v32; + return DrError_OK; +} + +DrError DrStringToDouble(const char *psz, double *pResult) +{ + double v = 0; + int exponent = 0; + bool neg = false; + bool gotDig = false; + bool gotPoint = false; + + while (ISSPACE(*psz)) { + psz++; + } + + if (*psz == '+') { + psz++; + } else if (*psz == '-') { + neg = true; + psz++; + } + + while (*psz != '\0' && !ISSPACE(*psz) && *psz != 'e' && *psz != 'E') { + int dig = -1; + if (!gotPoint && *psz == '.') { + gotPoint = true; + psz++; + continue; + } else if (*psz >= '0' && *psz <= '9') { + dig = *psz - '0'; + } + if (dig < 0 || dig >= 10) { + return DrError_InvalidParameter; + } + v = v * 10.0 + dig; + if (gotPoint) { + exponent--; + } + gotDig = true; + psz++; + } + + if (!gotDig) { + return DrError_InvalidParameter; + } + + if (*psz == 'e' || *psz == 'E') { + psz++; + Int32 exp2; + DrError err = DrStringToInt32(psz, &exp2); + if (err != DrError_OK) { + return err; + } + exponent += exp2; + } else { + while (ISSPACE(*psz)) { + psz++; + } + + if (*psz != '\0') { + return DrError_InvalidParameter; + } + } + + if (exponent != 0) { + v = v * pow((double) 10, exponent); + } 
+ + if (neg) { + v= -v; + } + + *pResult = v; + return DrError_OK; +} + +DrError DrStringToBool(const char *psz, bool *pResult) +{ + bool ret = false; + char tmp[16]; + size_t length = 0; + + while (ISSPACE(*psz)) { + psz++; + } + + while (length < 15 && *psz != '\0' && !ISSPACE(*psz)) { + tmp[length++] = *(psz++); + } + + tmp[length] = '\0'; + + while (ISSPACE(*psz)) { + psz++; + } + + if (*psz != '\0') { + return DrError_InvalidParameter; + } + + _strlwr(tmp); + + length++; + if (strncmp(tmp, "true", length) == 0 || + strncmp(tmp, "yes", length) == 0 || + strncmp(tmp, "on", length) == 0 || + strncmp(tmp, "1", length) == 0) { + ret = true; + } else if ( + strncmp(tmp, "false", length) == 0 || + strncmp(tmp, "no", length) == 0 || + strncmp(tmp, "off", length) == 0 || + strncmp(tmp, "0", length) == 0) { + ret = false; + } else { + return DrError_InvalidParameter; + } + + *pResult = ret; + return DrError_OK; +} + +DrError DrStringToSizeEx(PCSTR psz, UInt64* result, bool allowNegative) +{ + if ( psz == NULL || result == NULL ) + { + return DrError_InvalidParameter; + } + + char buf[48]; + DrError err = StringCbCopyA( buf, sizeof( buf ), psz ); + if ( FAILED( err ) ) + { + return DrError_InvalidParameter; + } + + UInt32 shift = 0; + Size_t len = strlen(psz); + if ( len > 2 ) + { + PCSTR tail = buf + len - 2; + if ( _stricmp( tail, "KB" ) == 0 ) + { + shift = 10; + } + else if ( _stricmp( tail, "MB" ) == 0 ) + { + shift = 20; + } + else if ( _stricmp( tail, "GB" ) == 0 ) + { + shift = 30; + } + else if ( _stricmp( tail, "TB" ) == 0 ) + { + shift = 40; + } + else if ( _stricmp( tail, "PB" ) == 0 ) + { + shift = 50; + } + } + + if ( shift != 0 ) + { + buf[len - 2] = 0; + } + + //check if dot is present + PSTR dot = strchr( buf, '.' 
); + + if ( dot != NULL ) + { + // can't have fractions if size specifier is not present + if ( shift == 0 ) + { + return DrError_InvalidParameter; + } + + *dot = 0; + } + + bool negative = false; + UInt64 r1; + err = DrStringToSignedOrUnsignedInt64( buf, &r1, allowNegative ); + if ( err != DrError_OK ) + { + return err; + } + if ( allowNegative && ((Int64)r1 < 0ui64) ) + { + negative = true; + r1 = (UInt64)-(Int64)r1; + } + + UInt64 maxVal = allowNegative ? MAX_INT64 : MAX_UINT64; + if ( dot != NULL ) + { + maxVal >>= 4; // reserve space for fraction + } + maxVal >>= shift; + + if ( r1 > maxVal ) + { + return DrError_InvalidParameter; + } + + r1 <<= shift; + + if ( dot != NULL ) + { + char buf2[48]; + buf2[0] = '0'; + buf2[1] = '.'; + buf2[2] = 0; + + err = StringCbCatA( buf2, sizeof(buf2), dot + 1 ); + if ( FAILED( err ) ) + { + return DrError_InvalidParameter; + } + + double rd; + err = DrStringToDouble( buf2, &rd ); + if ( err != DrError_OK ) + { + return err; + } + + rd *= 1ui64 << shift; + r1 += (UInt64)rd; + } + + if ( negative ) + { + r1 = (UInt64)-(Int64)r1; + } + + *result = r1; + + return DrError_OK; +} + +DrError DrStringToSize(PCSTR psz, UInt64* result) +{ + return DrStringToSizeEx( psz, result, false ); +} + +DrError DrStringToIntegerSize(PCSTR psz, Int64* result) +{ + return DrStringToSizeEx( psz, (UInt64*)result, true ); +} + + +// +// Close xcompute session after completing outstanding requests +// +DrError DryadShutdownXCompute(); + +// +// Same as ExitProcess, but flushes logging and stdout/stderr first... 
+// +void DrExitProcess(UInt32 exitCode) +{ + // + // Close the xcompute connection + // + DrError e = DryadShutdownXCompute(); + if (e == DrError_OK) + { + DrLogI("Completed uninitialise xcompute"); + } + else + { + DrLogE("Couldn't uninitialise xcompute"); + } + + // + // Flush output logs + // + fflush(stdout); + DrLogging::FlushLog(); + + // + // Exit the current process + // + ExitProcess((UINT)exitCode); +} + +DrError DrSystemTimeToTimeStamp(const SYSTEMTIME *pSystemTime, DrTimeStamp *pTimeStamp, bool fFromLocalTimeZone) +{ + union { + FILETIME ft; + DrTimeStamp ts; + }; + + union { + FILETIME ft2; + DrTimeStamp ts2; + }; + + if (!SystemTimeToFileTime(pSystemTime, &ft)) { + return DrGetLastError(); + } + + if (fFromLocalTimeZone && ts != DrTimeStamp_LongAgo && ts != DrTimeStamp_Never) { + if (!LocalFileTimeToFileTime(&ft, &ft2)) { + return DrGetLastError(); + } + *pTimeStamp = ts2; + } else { + *pTimeStamp = ts; + } + + return DrError_OK; +} + +// Returns UTC time +DrTimeStamp DrGetCurrentTimeStamp() +{ + union { + FILETIME ft; + DrTimeStamp ts; + }; + GetSystemTimeAsFileTime(&ft); + return ts; +} + +DrTimeInterval DrGetCurrentLocalTimeZoneBias() +{ + TIME_ZONE_INFORMATION tzi; + LONG biasMinutes; + + DWORD dwRet = GetTimeZoneInformation(&tzi); + switch(dwRet) { + case TIME_ZONE_ID_UNKNOWN: + biasMinutes = tzi.Bias; + break; + + case TIME_ZONE_ID_STANDARD: + biasMinutes = tzi.Bias + tzi.StandardBias; + break; + + case TIME_ZONE_ID_DAYLIGHT: + biasMinutes = tzi.Bias + tzi.DaylightBias; + break; + + default: + LogAssert(dwRet != dwRet); + return DrTimeInterval_Zero; + } + + DrTimeInterval bias = DrTimeInterval_Minute * biasMinutes; + return bias; +} + +DrError DrGenerateTimeZoneBiasSuffix(DrTimeInterval bias, char *szBuff, size_t nbBuff) +{ + + if (bias == DrTimeInterval_Zero) { + if (nbBuff < 2) { + return HRESULT_FROM_WIN32( ERROR_INSUFFICIENT_BUFFER ); + } + + szBuff[0]= 'Z'; + szBuff[1] = '\0'; + return DrError_OK; + } else { + if (nbBuff < 4) { + return 
HRESULT_FROM_WIN32( ERROR_INSUFFICIENT_BUFFER ); + } + *(szBuff++)= 'L'; + nbBuff--; + char s = '+'; + if (bias < DrTimeInterval_Zero) { + s = '-'; + bias = -bias; + } + *(szBuff++)= s; + nbBuff--; + *szBuff = '\0'; + return DrTimeIntervalToString(bias, szBuff, nbBuff); + } +} + +__inline DrError DrMultiplyTimeInterval(DrTimeInterval *pValue, Int64 n) +{ + DrTimeInterval v2 = *pValue * n; + if (n != 0 && v2 / n != *pValue) { + return DrError_InvalidTimeInterval; + } + *pValue = v2; + return DrError_OK; +} + +static DrError DrAddDeltaUnitTimeInterval(DrTimeInterval *pValue, Int64 n, DrTimeInterval units) +{ + DrError err = DrMultiplyTimeInterval(&units, n); + if (err != DrError_OK) { + return err; + } + + // if the signs of the two addends are the same, we have to check for overflow in the result + bool checkSign = ((*pValue >= DrTimeInterval_Zero) == (units >= DrTimeInterval_Zero)); + + *pValue += units; + + // if the signs of the two addends were the same, they still should be + if (checkSign && (*pValue >= DrTimeInterval_Zero) != (units >= DrTimeInterval_Zero)) { + return DrError_InvalidTimeInterval; + } + + return DrError_OK; +} + +// Converts a time interval string to a DrTimeInterval. +// If len is -1, it is computed with strlen. +// Strings must include units; e.g., "105.42s" or "12d5h10m". 
+DrError DrStringToTimeInterval(const char *pszString, DrTimeInterval *pTimeInterval, int len) +{ + DrError err = DrError_OK; + DrTimeInterval val = 0; + bool neg = false; + char szBuff[32]; + + if (pszString == NULL) { + return DrError_InvalidTimeInterval; + } + + if (pszString[0] == '+') { + pszString++; + } else if (pszString[0] == '-') { + pszString++; + neg = true; + } + + if (len < 0) { + len = (int)strlen(pszString); + } + + if (len == 0 || *pszString == '\0') { + return DrError_InvalidTimeInterval; + } + + if (len == 8 && _strnicmp(pszString, "infinite", 8) == 0) { + *pTimeInterval = DrTimeInterval_Infinite; + return DrError_OK; + } + + if (len == 16 && _strnicmp(pszString, "negativeinfinite", 16) == 0) { + *pTimeInterval = DrTimeInterval_NegativeInfinite; + return DrError_OK; + } + + if (len == 1 && pszString[0] == '0') { + *pTimeInterval = DrTimeInterval_Zero; + return DrError_OK; + } + + int i = 0; + while (i < len && pszString[i] != '\0') { + UInt64 n = 0; + UInt64 frac = 0; + int nFrac = 0; + int nDig = 0; + + while (nDig < 31 && i < len && pszString[i] >= '0' && pszString[i] <= '9') { + szBuff[nDig++] = pszString[i]; + i++; + } + + if (nDig > 0) { + szBuff[nDig] = '\0'; + err = DrStringToUInt64(szBuff, &n); + if (err != DrError_OK || (Int64)n < 0) { + return DrError_InvalidTimeInterval; + } + } + + if (i < len && pszString[i] == '.') { + i++; + while (nFrac < 31 && i < len && pszString[i] >= '0' && pszString[i] <= '9') { + szBuff[nFrac++] = pszString[i]; + i++; + } + if (nFrac > 0) { + szBuff[nFrac] = '\0'; + err = DrStringToUInt64(szBuff, &frac); + if (err != DrError_OK || (Int64)frac < 0) { + return DrError_InvalidTimeInterval; + } + } + } + + if (nDig == 0 && nFrac == 0) { + return DrError_InvalidTimeInterval; + } + + char szUnit[2]; + szUnit[0] = ' '; + szUnit[1] = ' '; + for (int k = 1; k >= 0; --k) { + if (i < len && pszString[i] != '\0' && (pszString[i] < '0' || pszString[i] > '9') && pszString[i] != '.') { + szUnit[k] = pszString[i++]; + if 
(szUnit[k] >= 'A' && szUnit[k] <= 'Z') { + szUnit[k] = szUnit[k] - 'A' + 'a'; + } + } else { + break; + } + } + + int tag = (int)*(WORD *)(void *)szUnit; + DrTimeInterval units = 0; + int fracKeep = 0; // # of fraction digits to keep + DrTimeInterval fracunits; + switch(tag) { + default: + return DrError_InvalidTimeInterval; + + case 'q ': + units = DrTimeInterval_Quantum; + fracunits = DrTimeInterval_Quantum; + fracKeep = 0; + break; + + case 'us': + units = DrTimeInterval_Microsecond; + fracunits = DrTimeInterval_Quantum; + fracKeep = 1; // 10 microseconds per interval + break; + + case 'ms': + units = DrTimeInterval_Millisecond; + fracunits = DrTimeInterval_Quantum; + fracKeep = 4; // 10,000 + break; + + case 's ': + units = DrTimeInterval_Second; + fracunits = 1; + fracKeep = 7; // 10,000,000 + break; + + case 'm ': + units = DrTimeInterval_Minute; + fracunits = DrTimeInterval_Minute / 100000000; + fracKeep = 8; // 600,000,000 + break; + + case 'h ': + units = DrTimeInterval_Hour; + fracunits = DrTimeInterval_Hour / 1000000000; + fracKeep = 9; // 36,000,000,000 + break; + + case 'd ': + units = DrTimeInterval_Day; + fracunits = DrTimeInterval_Day / 1000000000; + fracKeep = 9; // 864,000,000,000 + break; + + case 'w ': + units = DrTimeInterval_Week; + fracunits = DrTimeInterval_Week / 1000000000; + fracKeep = 9; // 6,048,000,000,000 + break; + + case 'y ': + units = DrTimeInterval_Year; + fracunits = DrTimeInterval_Year / 1000000000; + fracKeep = 9; // 314,496,000,000,000 + break; + + } + + // Normalize the fraction to the correct number of digits + while (nFrac > fracKeep) { + frac = frac / 10; + nFrac--; + } + while (nFrac < fracKeep) { + frac = frac * 10; + nFrac++; + } + + // result should be n*units + frac*fracunits + + err = DrAddDeltaUnitTimeInterval(&val, (Int64)n, units); + if (err != DrError_OK) { + return err; + } + err = DrAddDeltaUnitTimeInterval(&val, (Int64)frac, fracunits); + if (err != DrError_OK) { + return err; + } + } + + if (neg) { + val = 
-val; + } + + *pTimeInterval = val; + + return DrError_OK; +} + +DrError DrTimeStampToString(DrTimeStamp timeStamp, char *pBuffer, int buffLen, DrTimeInterval bias, Int32 nFracDig) +{ + SYSTEMTIME st; + + if (timeStamp == DrTimeStamp_Never) { + if (buffLen < 6) { + return HRESULT_FROM_WIN32( ERROR_INSUFFICIENT_BUFFER ); + } + strncpy(pBuffer, "never", 6); + return DrError_OK; + } else if (timeStamp == DrTimeStamp_LongAgo) { + if (buffLen < 9) { + return HRESULT_FROM_WIN32( ERROR_INSUFFICIENT_BUFFER ); + } + strncpy(pBuffer, "longago", 9); + return DrError_OK; + } + + if (bias == DrTimeInterval_Infinite) { + bias = DrGetCurrentLocalTimeZoneBias(); + } + + timeStamp -= bias; + + DrError err = DrTimeStampToSystemTime(timeStamp, &st, false); + if (err != DrError_OK) { + return err; + } + + if (nFracDig < 0) { + if (st.wMilliseconds == 0) { + nFracDig = 0; + } else { + nFracDig = 3; + } + } + + char szFrac[16]; + if (nFracDig == 0) { + szFrac[0] = '\0'; + } else { + szFrac[0] = '.'; + LogAssert(st.wMilliseconds < 1000); + sprintf(szFrac+1, "%03u", st.wMilliseconds); + if (nFracDig > 3) { + if (nFracDig > 7) { + nFracDig = 7; + } + // Compute quantum ticks in excess of 1ms boundary + UInt32 remainder = (UInt32)(timeStamp % 10000); + sprintf(szFrac+4, "%04u", remainder); + } + szFrac[nFracDig+1] = '\0'; + } + + + if (bias == DrTimeInterval_Zero) { + // No local time bias -- Zulu time + _snprintf(pBuffer, (size_t)buffLen, "%04u-%02u-%02uT%02u:%02u:%02u%sZ", + st.wYear, + st.wMonth, + st.wDay, + st.wHour, + st.wMinute, + st.wSecond, + szFrac); + } else { + // Local time zone bias + char szSuffix[k_DrTimeIntervalStringBufferSize+2]; + const char *suffix; + err = DrGenerateTimeZoneBiasSuffix(bias, szSuffix, ELEMENTCOUNT(szSuffix)); + LogAssert(err == DrError_OK); + suffix = szSuffix; + _snprintf(pBuffer, (size_t)buffLen, "%04u-%02u-%02uT%02u:%02u:%02u%s%s", + st.wYear, + st.wMonth, + st.wDay, + st.wHour, + st.wMinute, + st.wSecond, + szFrac, + suffix); + } + + return 
DrError_OK; +} + +// Converts a Dryad timeinterval to a human-readable string. +// The generated string may be fed back into DrStringToTimeInterval +DrError DrTimeIntervalToString(DrTimeInterval timeInterval, char *pBuffer, size_t buffLen) +{ + LogAssert(pBuffer != NULL); + LogAssert(buffLen != 0); + char tempBuff[k_DrTimeIntervalStringBufferSize]; + char *pBuff = pBuffer; + if (buffLen < k_DrTimeIntervalStringBufferSize) { + // use a temporary buffer if we aren't sure it will fit + pBuff = tempBuff; + } + DrError err = DrError_OK; + + if (timeInterval == DrTimeInterval_Infinite) { + strcpy(pBuff, "infinite"); + } else if (timeInterval == DrTimeInterval_NegativeInfinite) { + strcpy(pBuff, "negativeinfinite"); + } else { + bool neg = (timeInterval < DrTimeInterval_Zero); + UInt64 v; + if (neg) { + v = (UInt64)(-timeInterval); + } else { + v = (UInt64)timeInterval; + } + UInt32 frac100ns = (UInt32)(v % DrTimeInterval_Second); + v = v / DrTimeInterval_Second; + UInt32 sec = (UInt32)(v % (UInt32)60); + v = v / 60; + UInt32 min = (UInt32)(v % (UInt32)60); + v = v / 60; + UInt32 hr = (UInt32)(v % (UInt32)24); + v = v / 24; + // v now contains days + + + int i =0; + if (neg) { + pBuff[i++] = '-'; + } + int ret = 0; + bool fOutput = false; + if (v != 0) { + ret = _snprintf(pBuff+i, k_DrTimeIntervalStringBufferSize-i-1, "%I64ud", v); + LogAssert(ret > 0); + i += ret; + fOutput = true; + } + if (hr != 0 || (fOutput && (min != 0 || sec != 0 || frac100ns != 0))) { + fOutput = true; + ret = _snprintf(pBuff+i, k_DrTimeIntervalStringBufferSize-i-1, "%uh", hr); + LogAssert(ret > 0); + i += ret; + } + if (min != 0 || (fOutput && (sec != 0 || frac100ns != 0))) { + fOutput = true; + ret = _snprintf(pBuff+i, k_DrTimeIntervalStringBufferSize-i-1, "%um", min); + LogAssert(ret > 0); + i += ret; + } + if (frac100ns == 0) { + // whole number of seconds + if (sec != 0 || !fOutput) { + fOutput = true; + ret = _snprintf(pBuff+i, k_DrTimeIntervalStringBufferSize - i - 1, "%us", sec); + 
LogAssert(ret > 0); + i += ret; + } + } else { + // fractional seconds + fOutput = true; + ret = _snprintf(pBuff+i, k_DrTimeIntervalStringBufferSize - i - 1, "%u.%07u", sec, frac100ns); + LogAssert(ret > 0); + i += ret; + + // remove traling "0" characters + while (i > 0 && pBuff[i-1] == '0') { + --i; + } + + LogAssert((size_t)i+2 < k_DrTimeIntervalStringBufferSize); + pBuff[i++] = 's'; + pBuff[i] = '\0'; + } + } + + if (err == DrError_OK && pBuff == tempBuff) { + size_t n = strlen(tempBuff) + 1; + if (n <= buffLen) { + memcpy(pBuffer, tempBuff, n); + } else { + err = DrError_StringTooLong; + pBuffer[0] = '\0'; + } + } + + return err; +} + +DrError DrTimeStampToString(DrTimeStamp timeStamp, char *pBuffer, int buffLen, bool fToLocalTimeZone, Int32 nFracDig) +{ + return DrTimeStampToString(timeStamp, pBuffer, buffLen, fToLocalTimeZone ? DrTimeInterval_Infinite : DrTimeInterval_Zero, nFracDig); +} + +DrError DrStringToTimeStamp(const char *pszTime, DrTimeStamp *pTimeStampOut, DrTimeInterval defaultTimeZoneBias) +{ + // 2006-04-18 13:06:44 + // 2006-04-18T13:06:44.mmmL+8h + // 2006-04-18T13:06:44.mmm-08:00 + // 2006-04-18T13:06:44.mmmZ + // 01234567890123456789 + // +timeinterval + // -timeinterval + DrError err = DrError_OK; + + UInt32 uYear; + UInt32 uMonth; + UInt32 uDay; + UInt32 uHour = 0; + UInt32 uMinute = 0; + UInt32 uSecond = 0; + UInt32 uFrac = 0; + UInt32 nFracDigs = 0; + DrTimeInterval bias = defaultTimeZoneBias; + + // If the first character is "+" or "-", it is a relative time interval to the current time + if (pszTime != NULL && (pszTime[0] == '-' || pszTime[0] == '+')) { + DrTimeInterval ti; + err = DrStringToTimeInterval(pszTime, &ti); + if (err == DrError_OK) { + *pTimeStampOut = DrGetCurrentTimeStamp() + ti; + } + return err; + } + + // If the string consists entirely of digits, it is a simple decimal encoding of a DrTimeStamp. 
We optimize for this case: + if (pszTime != NULL) { + char c1 = *pszTime; + // An early detector, numbers that don't start with 1 or 2 (for year) are usually simple numbers + if ((c1 >= '0' && c1 < '1') || (c1 >= '3' && c1 <= '9')) { + goto trySimple; + } + // If first 5 chars are numeric, it is probably a simple number + for (UInt32 i = 0; i < 5; i++) { + if (pszTime[i] == '\0') { + break; + } else if (pszTime[i] < '0' || pszTime[i] > '9') { + goto notSimple; + } + } + +trySimple: + // might be a simple number + err = DrStringToUInt64(pszTime, pTimeStampOut); + if (err == DrError_OK) { + return err; + } + + // If not a simple number, fall through to try string forms... + + } + +notSimple: + + DrStr32 strTime(pszTime); + + if (strTime.GetLength() == 7 && strTime == "longago") { + *pTimeStampOut = DrTimeStamp_LongAgo; + return DrError_OK; + } else if (strTime == "never") { + *pTimeStampOut = DrTimeStamp_Never; + return DrError_OK; + } + + if (strTime.GetLength() < 10 || strTime.GetLength() > 40) { + return DrError_InvalidParameter; + } + + char *psz = &(strTime[0]); + + if (psz[4] != '-') { + return DrError_InvalidParameter; + } + + psz[4] = '\0'; + err = DrStringToUInt32(psz, &uYear); + if (err != DrError_OK) { + return err; + } + if (uYear < 1600 || uYear > 9999) { + return DrError_InvalidParameter; + } + psz += 5; + + if (psz[2] != '-') { + return DrError_InvalidParameter; + } + + psz[2] = '\0'; + err = DrStringToUInt32(psz, &uMonth); + if (err != DrError_OK) { + return err; + } + if (uMonth < 1 || uMonth > 12) { + return DrError_InvalidParameter; + } + psz += 3; + + + char chNext = psz[2]; + psz[2] = '\0'; + + err = DrStringToUInt32(psz, &uDay); + if (err != DrError_OK) { + return err; + } + if (uDay < 1 || uDay > 31) { + return DrError_InvalidParameter; + } + + if (chNext == '\0') { + psz += 2; + } else { + psz += 3; + } + + if (chNext == ' ' || chNext == 'T') { + // there is HH:MM:SS + if (strTime.GetLength() < 19) { + return DrError_InvalidParameter; + } + if 
(psz[2] != ':') { + return DrError_InvalidParameter; + } + psz[2] = '\0'; + err = DrStringToUInt32(psz, &uHour); + if (err != DrError_OK) { + return err; + } + if (uHour > 23) { + return DrError_InvalidParameter; + } + psz += 3; + if (psz[2] != ':') { + return DrError_InvalidParameter; + } + psz[2] = '\0'; + err = DrStringToUInt32(psz, &uMinute); + if (err != DrError_OK) { + return err; + } + if (uMinute > 59) { + return DrError_InvalidParameter; + } + psz += 3; + chNext = psz[2]; + psz[2] = '\0'; + + err = DrStringToUInt32(psz, &uSecond); + if (err != DrError_OK) { + return err; + } + if (uSecond > 59) { + return DrError_InvalidParameter; + } + + if (chNext == '\0') { + psz += 2; + } else { + psz += 3; + } + + if (chNext == '.') { + // fraction + while (*psz >= '0' && *psz <= '9') { + nFracDigs++; + if (nFracDigs > 9) { + return DrError_InvalidParameter; + } + uFrac = (10 * uFrac) + (UInt32) (*psz - '0'); + psz++; + } + chNext = *psz; + if (chNext != '\0') { + psz++; + } + } + } + + if (chNext == 'Z') { + bias = 0; + } else if (chNext == 'L') { + bool fNeg = false; + if (*psz == '+' || *psz == '-') { + fNeg = (*psz == '-'); + psz++; + err = DrStringToTimeInterval(psz, &bias); + if (err != DrError_OK) { + return err; + } + if (fNeg) { + bias = -bias; + } + psz += strlen(psz); + } + } + //case where time end with [-/+]HH:MM + else if (chNext == '+' || chNext == '-') + { + UInt32 ubiashour = 0; + UInt32 ubiasMinute = 0; + + while(isdigit((int)(*psz))) + { + ubiashour = ubiashour * 10 + (*psz-'0'); + psz++; + } + + if(*psz != ':') + { + return DrError_InvalidParameter; + } + psz++; + + while(isdigit((int)(*psz))) + { + ubiasMinute= ubiasMinute* 10 + (*psz-'0'); + psz++; + } + + bias = ubiashour * DrTimeInterval_Hour + ubiasMinute * DrTimeInterval_Minute; + + //If timezone is -08:00, we have to add 8 hrs to find utc time. 
+ if(chNext == '+') + { + bias = -bias; + } + } + + if (*psz != '\0') { + return DrError_InvalidParameter; + } + + SYSTEMTIME st; + st.wDayOfWeek = 0; // unknown + st.wYear = (WORD) uYear; + st.wMonth = (WORD) uMonth; + st.wDay = (WORD) uDay; + st.wHour = (WORD) uHour; + st.wMinute = (WORD) uMinute; + st.wSecond = (WORD) uSecond; + st.wMilliseconds = 0; // we handle milliseconds ourselves to get better resolution... + + while (nFracDigs > 7 && uFrac != 0) { + uFrac = uFrac / 10; + nFracDigs --; + } + while (nFracDigs < 7 && uFrac != 0) { + uFrac = 10 * uFrac; + nFracDigs++; + } + + DrTimeStamp ts; + err = DrSystemTimeToTimeStamp(&st, &ts, false); + if (err != DrError_OK) { + return err; + } + + ts += uFrac; + ts += bias; + + *pTimeStampOut = ts; + + return DrError_OK; +} + +DrError DrStringToTimeStamp(const char *pszTime, DrTimeStamp *pTimeStampOut, bool fDefaultLocalTimeZone) +{ + DrTimeInterval bias; + if (fDefaultLocalTimeZone) { + bias = DrGetCurrentLocalTimeZoneBias(); + } else { + bias = 0; + } + return DrStringToTimeStamp(pszTime, pTimeStampOut, bias); +} + +DrError DrTimeStampToSystemTime(DrTimeStamp timeStamp, SYSTEMTIME *pSystemTime, bool fToLocalTimeZone) +{ + union { + FILETIME ft; + DrTimeStamp ts; + }; + + if (fToLocalTimeZone && timeStamp != DrTimeStamp_LongAgo && timeStamp != DrTimeStamp_Never) { + if (!FileTimeToLocalFileTime((const FILETIME *)(const void *)&timeStamp, &ft)) { + return DrGetLastError(); + } + } else { + ts = timeStamp; + } + + if (!FileTimeToSystemTime(&ft, pSystemTime)) { + return DrGetLastError(); + } + return DrError_OK; +} + + +// +// Get an environment variable +// +DrError DrGetEnvironmentVariable(const WCHAR *pszVarName, WCHAR ppszValue[]) +{ + DrError err; + WCHAR * psz = NULL; + DWORD nb2 = 0; + + + // + // Get length of environment variable value + // + DWORD nb = GetEnvironmentVariableW(pszVarName, NULL, 0); + if (nb == 0) + { + err = DrGetLastError(); + goto done; + } + + psz = (WCHAR *)malloc(sizeof(WCHAR) * nb); + + 
// + // Get environment variable value + // + nb2 = GetEnvironmentVariableW(pszVarName, psz, nb); + if (nb2 == 0) + { + err = DrGetLastError(); + goto done; + } + + err = DrError_OK; + + // Fail if more than MAX_PATH characters + if(MAX_PATH <= nb2) + { + err = DrError_Fail; + } + +done: + if (err != DrError_OK || psz == NULL) + { + // GLE may return wrong results in mixed mode code. If we catch S_OK we need to return a meaningful code instead + if (err == S_OK) err = ERROR_ENVVAR_NOT_FOUND; + + // + // If there has been an error, set value to null + // and free any allocated resources + // + if (psz != NULL) + { + free(psz); + } + + *ppszValue = NULL; + } + else + { + LogAssert(psz != NULL); + // use length + 1 to get null character ending + wcsncpy(ppszValue, psz, nb2+1); + free(psz); + } + + return err; +} + +// +// Get an environment variable +// +DrError DrGetEnvironmentVariable(const char *pszVarName, const char **ppszValue) +{ + DrError err; + LPWSTR myenvname; + LPWSTR psz = NULL; + + int charLen = lstrlenA(pszVarName); + int wcharLen; + + // + // Get length of variable name + // + wcharLen = ::MultiByteToWideChar(CP_ACP, NULL, pszVarName, charLen, NULL, NULL); + if (wcharLen > 0) + { + // + // Get converted variable name + // + myenvname = ::SysAllocStringLen(0, wcharLen); + ::MultiByteToWideChar(CP_ACP, 0, pszVarName, charLen, myenvname, wcharLen); + } + else + { + // + // If unable to get length, fail + // + err = DrGetLastError(); + goto done; + } + + + // + // Get length of environment variable value + // + DWORD nb = GetEnvironmentVariableW(myenvname, NULL, 0); + if (nb == 0) { + err = DrGetLastError(); + goto done; + } + + psz = (LPWSTR )malloc(sizeof(char) * nb); + + // + // Get environment variable value + // + DWORD nb2 = GetEnvironmentVariableW(myenvname, psz, nb); + if (nb2 == 0) { + err = DrGetLastError(); + goto done; + } + + err = DrError_OK; + +done: + if (err != DrError_OK) + { + // + // If there has been an error, set value to null + // 
and free any allocated resources + // + if (psz != NULL) + { + free(psz); + } + + *ppszValue = NULL; + } + else + { + LogAssert(psz != NULL); + *ppszValue = (char*) psz; + } + + return err; +} + +DrError DrGetSidForUser(LPCWSTR domainUserName, PSID* ppSid) +{ + // Create buffers for SID and domain. If size > default, will retry with new size + DWORD bufferSizeSid = 64; + DWORD bufferSizeDomain = 64; + DWORD newBufferSizeSid = 64; + DWORD newBufferSizeDomain = 64; + WCHAR* pDomainName = NULL; + SID_NAME_USE sidType; + + // Check SID pointer fo null before using + if(ppSid == NULL) + { + return DrError_Fail; + } + + // Create buffers for the SID and domain name. + *ppSid = (PSID) new BYTE[bufferSizeSid]; + if (*ppSid == NULL) + { + return DrError_Fail; + } + memset(*ppSid, 0, bufferSizeSid); + + pDomainName = new WCHAR[bufferSizeDomain]; + if (pDomainName == NULL) + { + FreeSid(*ppSid); + return DrError_Fail; + } + memset(pDomainName, 0, bufferSizeDomain*sizeof(WCHAR)); + + // Try to get SID with default buffer size + if (LookupAccountNameW(NULL, domainUserName, *ppSid, &newBufferSizeSid, pDomainName, &newBufferSizeDomain, &sidType)) + { + delete [] pDomainName; + + if (IsValidSid(*ppSid) == FALSE) + { + return DrError_Fail; + } + + return DrError_OK; + } + + // If unable to get account name, check for insufficient buffer + DWORD err = GetLastError(); + while (err == ERROR_INSUFFICIENT_BUFFER) + { + if (newBufferSizeSid > bufferSizeSid) + { + // Free and reallocate buffer for SID + FreeSid(*ppSid); + *ppSid = (PSID) new BYTE[newBufferSizeSid]; + if (*ppSid == NULL) + { + delete [] pDomainName; + return DrError_Fail; + } + + bufferSizeSid = newBufferSizeSid; + + memset(*ppSid, 0, bufferSizeSid); + } + + if (newBufferSizeDomain > bufferSizeDomain) + { + // Free and reallocate buffer for domain + delete [] pDomainName; + pDomainName = new WCHAR[newBufferSizeDomain]; + if (pDomainName == NULL) + { + FreeSid(*ppSid); + return DrError_Fail; + } + + bufferSizeDomain = 
newBufferSizeDomain; + + memset(pDomainName, 0, bufferSizeDomain*sizeof(WCHAR)); + } + + // Try to get SID with new buffer size + if (LookupAccountNameW(NULL, domainUserName, *ppSid, &bufferSizeSid, pDomainName, &bufferSizeDomain, &sidType)) + { + delete [] pDomainName; + + if (IsValidSid(*ppSid) == FALSE) + { + return DrError_Fail; + } + + return DrError_OK; + } + + err = GetLastError(); + } + + // If outside loop, failed to lookup SID + return DrError_Fail; +} + +DrError DrGetComputerName(WCHAR ppszValue[]) +{ + + WCHAR azureFlag[MAX_PATH+1] = {0}; + DrError err = DrGetEnvironmentVariable(L"CCP_ONAZURE", azureFlag); + if(err == DrError_OK) + { + // This process is running on Azure + err = DrGetEnvironmentVariable(L"HPC_NODE_NAME", ppszValue); + if(err != DrError_OK) + { + DrLogE( "Error retrieving HPC_NODE_NAME environment variable. Error: %s", DrGetErrorText(err)); + return err; + } + } + else + { + // This process is not running on Azure + + // swap with current lines for DNS hostname support + //DWORD hostLength = DNS_MAX_LABEL_BUFFER_LENGTH; + //if (!GetComputerNameExW(ComputerNameDnsHostname, ppszValue, &hostLength)) + + DWORD hostLength = MAX_COMPUTERNAME_LENGTH + 1; + if (!GetComputerName(ppszValue, &hostLength)) + { + DrLogE( "Error calling GetComputerName. 
ErrorCode: %u", GetLastError()); + return DrError_Fail; + } + } + + return DrError_OK; +} + +//JC +#if 0 + +DrError DrGetEnvironmentVariable(const char *pszVarName, /* out */ const char **ppszValue) +{ + // We jump through hoops to use unicode API here and convert to UTF-8 + + DrError err; + WCHAR *psz = NULL; + DrWStr256 wstr; + DrWStr64 wstrVarName; + wstrVarName = pszVarName; + DrStr256 str; + char *pszRet = NULL; + + DWORD nb = GetEnvironmentVariableW(wstrVarName.GetString(), NULL, 0); + + if (nb == 0) { + err = DrGetLastError(); + goto done; + } + + psz = wstr.GetWritableBuffer(nb); + + DWORD nb2 = GetEnvironmentVariableW(wstrVarName.GetString(), psz, nb); + if (nb2 == 0) { + err = DrGetLastError(); + goto done; + } + LogAssert(nb2 < nb); + wstr.UpdateLength(nb2); + jstr.Set( wstr ); + LogAssert(str.GetString() != NULL); + + pszRet = (char *)malloc(str.GetLength()+1); + LogAssert(pszRet != NULL); + memcpy(pszRet, str.GetString(), str.GetLength()+1); + + err = DrError_OK; + +done: + if (err != DrError_OK) { + if (pszRet != NULL) { + free(pszRet); + } + *ppszValue = NULL; + } else { + LogAssert(pszRet != NULL); + *ppszValue = pszRet; + } + + return err; +} + +// If pszBaseDir is null, the current working directory is used. If pszRelDir is +// fully qualified, pszBaseDir is ignored. 
+DrError DrCanonicalizeFilePath(DrStr& strOut, const char *pszRelDir, const char *pszBaseDir) +{ + DrWStr128 wstrRelDir; + DrWStr128 wstrBaseDir; + wstrRelDir = pszRelDir; + wstrBaseDir = pszBaseDir; + + if (pszRelDir != NULL && !PathIsRelativeW(wstrRelDir)) { + strOut = pszRelDir; + return DrError_OK; + } + + if (wstrBaseDir == NULL) { + DrStr256 strBase; + DrGetCurrentDirectory(strBase); + wstrBaseDir = strBase; + } + + WCHAR szBuff1[MAX_PATH]; + WCHAR szBuff2[MAX_PATH]; + + if (wstrRelDir != NULL) { + const WCHAR *pszResult = PathCombineW( + szBuff1, + wstrBaseDir, + wstrRelDir); + + if (pszResult == NULL) { + return DrError_InvalidPathname; + } + + BOOL fRet = PathCanonicalizeW(szBuff2, szBuff1); + if (!fRet) { + return DrError_InvalidPathname; + } + } else { + BOOL fRet = PathCanonicalizeW(szBuff2, wstrBaseDir); + if (!fRet) { + return DrError_InvalidPathname; + } + } + + strOut.Set(szBuff2); + + return DrError_OK; +} +#endif diff --git a/DryadVertex/VertexHost/system/classlib/src/DrGuid.cpp b/DryadVertex/VertexHost/system/classlib/src/DrGuid.cpp new file mode 100644 index 0000000..81eb566 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/src/DrGuid.cpp @@ -0,0 +1,359 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include "DrCommon.h" +#include + + +#pragma unmanaged + +const GUID g_DrInvalidGuid = { 0xFFFFFFFF, 0xFFFF, 0xFFFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}; +const GUID g_DrNullGuid = { 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0 }; + +// Generate a new guid +void DrGuid::Generate() +{ + HRESULT success = CoCreateGuid((GUID *)this); + if (FAILED(success)) + { + LogAssert("Fatal error, failed to create guid"); + } +} + + +static +const UInt8 * +GuidWriteWord ( + char *pDst, + const UInt8 *pSrc, + uxint uBytes +) +{ + static const char Hex[] = "0123456789ABCDEF"; + uxint c; + + pDst += uBytes * 2; + *pDst = '-'; + do + { + c = *pSrc++; + pDst -= 2; + pDst[1] = Hex[c & 15]; + pDst[0] = Hex[c >> 4]; + } + while (--uBytes); + + return (pSrc); +} + +static +void +GuidWrite ( + char *pDst, + const GUID *pGuid, + bool fBraces = true +) +{ + const UInt8 *p; + + p = (const UInt8 *) pGuid; + + if (fBraces) { + *pDst++ = '{'; + } + + p = GuidWriteWord (pDst, p, 4); + pDst += 9; + + p = GuidWriteWord (pDst, p, 2); + pDst += 5; + + p = GuidWriteWord (pDst, p, 2); + pDst += 5; + + GuidWriteWord (pDst, p, 1); + GuidWriteWord (pDst + 2, p + 1, 1); + pDst += 5; + p += 2; + + GuidWriteWord (pDst + 0*2, p + 0, 1); + GuidWriteWord (pDst + 1*2, p + 1, 1); + GuidWriteWord (pDst + 2*2, p + 2, 1); + GuidWriteWord (pDst + 3*2, p + 3, 1); + GuidWriteWord (pDst + 4*2, p + 4, 1); + GuidWriteWord (pDst + 5*2, p + 5, 1); + + pDst += 6*2; + if (fBraces) { + *pDst++ = '}'; + } + *pDst = 0; +} + +static +const UInt8 * +GuidWriteWord ( + WCHAR *pDst, + const UInt8 *pSrc, + uxint uBytes +) +{ + static const WCHAR Hex[] = L"0123456789ABCDEF"; + uxint c; + + pDst += uBytes * 2; + *pDst = L'-'; + do + { + c = *pSrc++; + pDst -= 2; + pDst[1] = Hex[c & 15]; + pDst[0] = Hex[c >> 4]; + } + while (--uBytes); + + return (pSrc); +} + +#pragma warning (disable: 4505) // unreferenced local function +static +void +GuidWrite ( + WCHAR *pDst, + const GUID *pGuid, + bool fBraces = true +) 
+{ + const UInt8 *p; + + p = (const UInt8 *) pGuid; + + if (fBraces) { + *pDst++ = L'{'; + } + + p = GuidWriteWord (pDst, p, 4); + pDst += 9; + + p = GuidWriteWord (pDst, p, 2); + pDst += 5; + + p = GuidWriteWord (pDst, p, 2); + pDst += 5; + + GuidWriteWord (pDst, p, 1); + GuidWriteWord (pDst + 2, p + 1, 1); + pDst += 5; + p += 2; + + GuidWriteWord (pDst + 0*2, p + 0, 1); + GuidWriteWord (pDst + 1*2, p + 1, 1); + GuidWriteWord (pDst + 2*2, p + 2, 1); + GuidWriteWord (pDst + 3*2, p + 3, 1); + GuidWriteWord (pDst + 4*2, p + 4, 1); + GuidWriteWord (pDst + 5*2, p + 5, 1); + + pDst += 6*2; + if (fBraces) { + *pDst++ = L'}'; + } + *pDst = 0; +} + +static +bool +GuidReadWord ( + const char *pSrc, + UInt8 *pDst, + uxint uBytes +) +{ + uxint c0, c1; + + pDst += uBytes; + do + { + c0 = (uxint) *pSrc++; + if (c0 >= '0' && c0 <= '9') + c0 -= '0'; + else if (c0 >= 'A' && c0 <= 'F') + c0 -= 'A' - 10; + else if (c0 >= 'a' && c0 <= 'f') + c0 -= 'a' - 10; + else + return (false); + + c1 = (uxint) *pSrc++; + if (c1 >= '0' && c1 <= '9') + c1 -= '0'; + else if (c1 >= 'A' && c1 <= 'F') + c1 -= 'A' - 10; + else if (c1 >= 'a' && c1 <= 'f') + c1 -= 'a' - 10; + else + return (false); + + *--pDst = (UInt8) ((c0 << 4) + c1); + } + while (--uBytes); + + return (true); +} + + +static const char* GuidRead(GUID *pGuid, const char *pSrc, bool allowBraces, bool requireBraces, bool requireEOL) +{ + UInt8 *p; + bool closingBrace; + + if (requireBraces && !allowBraces) + { + goto failed; + } + + if (pSrc == NULL) + { + goto failed; + } + + p = (UInt8 *) pGuid; + + closingBrace = false; + if (allowBraces) + { + if (*pSrc == '{') + { + pSrc++; + closingBrace = true; + } + else if (requireBraces) + { + goto failed; + } + } + + if (!GuidReadWord (pSrc, p, 4) || pSrc[8] != '-') + { + goto failed; + } + pSrc += 9; + p += 4; + + if (!GuidReadWord (pSrc, p, 2) || pSrc[4] != '-') + { + goto failed; + } + pSrc += 5; + p += 2; + + if (!GuidReadWord (pSrc, p, 2) || pSrc[4] != '-') + { + goto failed; + } + pSrc 
+= 5; + p += 2; + + if (!GuidReadWord (pSrc, p, 1) || !GuidReadWord (pSrc + 2, p + 1, 1) || pSrc[4] != '-') + { + goto failed; + } + pSrc += 5; + p += 2; + + if (!GuidReadWord (pSrc + 0*2, p + 0, 1) || + !GuidReadWord (pSrc + 1*2, p + 1, 1) || + !GuidReadWord (pSrc + 2*2, p + 2, 1) || + !GuidReadWord (pSrc + 3*2, p + 3, 1) || + !GuidReadWord (pSrc + 4*2, p + 4, 1) || + !GuidReadWord (pSrc + 5*2, p + 5, 1)) + { + goto failed; + } + + pSrc += 6*2; + + if (closingBrace) + { + if (*pSrc != '}') + { + goto failed; + } + ++pSrc; + } + + if (requireEOL && *pSrc != 0) + { + goto failed; + } + + return pSrc; + +failed: + FillMemory(pGuid, sizeof(GUID), 0xFF); + return NULL; +} + +// appends the guid in string form; {EFF6744C-7143-11cf-A51B-080036F12502}. If fBraces +// is false, the braces are omitted. +DrStr& DrGuid::AppendToString(DrStr& strOut, bool fBraces) const +{ + size_t len = strOut.GetLength(); + char *pDest = strOut.GetWritableAppendBuffer(GuidStringLength-1); + GuidWrite(pDest, this, fBraces); + strOut.UpdateLength(len + strlen(pDest)); + return strOut; +} + +/* JC +// appends the guid in string form; {EFF6744C-7143-11cf-A51B-080036F12502}. If fBraces +// is false, the braces are omitted. +DrWStr& DrGuid::AppendToString(DrWStr& strOut, bool fBraces) const +{ + size_t len = strOut.GetLength(); + WCHAR *pDest = strOut.GetWritableAppendBuffer(GuidStringLength-1); + GuidWrite(pDest, this, fBraces); + strOut.UpdateLength(len + wcslen(pDest)); + return strOut; +} +*/ + + // Output the guid in string form; {EFF6744C-7143-11cf-A51B-080036F12502}. If fBraces + // is false, the braces are omitted. + // String must be able to hold DrGuid::GuidStringLength (39 characters = 38 + null terminator) +char *DrGuid::ToString (char *string, bool fBraces) const +{ + GuidWrite (string, this, fBraces); + return (string); +} + + +// Parse a guid from a string. 
Acceptes guids either with or without braces +// EFF6744C-7143-11cf-A51B-080036F12502 +BOOL DrGuid::Parse(const char *string) +{ + return GuidRead(this, string, true, false, true) != NULL; +} + +//returns pointer to the next char after guid on success, NULL on failure +const char* DrGuid::Parse(const char *string, bool allowBraces, bool requireBraces, bool requireEOL) +{ + return GuidRead(this, string, allowBraces, requireBraces, requireEOL); +} + diff --git a/DryadVertex/VertexHost/system/classlib/src/DrHash.cpp b/DryadVertex/VertexHost/system/classlib/src/DrHash.cpp new file mode 100644 index 0000000..21a1a91 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/src/DrHash.cpp @@ -0,0 +1,528 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "DrHash.h" + +#pragma unmanaged + +/* + * Copied and modified from http://burtleburtle.net/bob/c/lookup3.c, + * where it is Public Domain. + */ + +#define UPPER(c) ((UInt32)((((c) >= 'a') && ((c) <= 'z')) ? (c) - ('a' - 'A') : (c))) + + +// Compute2: Compute two hashes of an array of bytes. +// The first hash is slightly better mixed than the second hash. 
+void DrHash32::Compute2 ( + const void *pData, // byte array to hash; may be null if uSize==0 + Size_t uSize, // length of pData + UInt32 uSeed1, // first seed + UInt32 uSeed2, // second seed + UInt32 *uHash1, // OUT: first hash (may not be null) + UInt32 *uHash2) // OUT: second hash (may not be null) +{ + UInt32 a,b,c; + + // Set up the internal state + a = b = c = 0xdeadbeef + ((UInt32)uSize) + uSeed1; + c += uSeed2; + + if (((((UInt8 *)pData)-(UInt8 *)0) & 0x3) == 0) { + const UInt32 *k = (const UInt32 *)pData; // read 32-bit chunks + const UInt8 *k1; + + // all but last block: aligned reads and affect 32 bits of (a,b,c) + while (uSize > 12) + { + a += k[0]; + b += k[1]; + c += k[2]; + Mix(a,b,c); + uSize -= 12; + k += 3; + } + + // handle the last (probably partial) block + k1 = (const UInt8 *)k; + switch(uSize) + { + case 12: c+=k[2]; b+=k[1]; a+=k[0]; break; + case 11: c+=((UInt32)k1[10])<<16; // fall through + case 10: c+=((UInt32)k1[9])<<8; // fall through + case 9 : c+=(UInt32)k1[8]; // fall through + case 8 : b+=k[1]; a+=k[0]; break; + case 7 : b+=((UInt32)k1[6])<<16; // fall through + case 6 : b+=((UInt32)k1[5])<<8; // fall through + case 5 : b+=((UInt32)k1[4]); // fall through + case 4 : a+=k[0]; break; + case 3 : a+=((UInt32)k1[2])<<16; // fall through + case 2 : a+=((UInt32)k1[1])<<8; // fall through + case 1 : a+=k1[0]; break; + case 0 : + *uHash1 = c; + *uHash2 = b; + return; + } + + } else if (((((UInt8 *)pData)-(UInt8 *)0) & 0x1) == 0) { + const UInt16 *k = (const UInt16 *) pData; // read 16-bit chunks + const UInt8 *k1; + + // all but last block: aligned reads and different mixing + while (uSize > 12) + { + a += k[0] + (((UInt32)k[1])<<16); + b += k[2] + (((UInt32)k[3])<<16); + c += k[4] + (((UInt32)k[5])<<16); + Mix(a,b,c); + uSize -= 12; + k += 6; + } + + // handle the last (probably partial) block + k1 = (const UInt8 *)k; + switch(uSize) + { + case 12: c+=k[4]+(((UInt32)k[5])<<16); + b+=k[2]+(((UInt32)k[3])<<16); + 
a+=k[0]+(((UInt32)k[1])<<16); + break; + case 11: c+=((UInt32)k1[10])<<16; // fall through + case 10: c+=k[4]; + b+=k[2]+(((UInt32)k[3])<<16); + a+=k[0]+(((UInt32)k[1])<<16); + break; + case 9 : c+=k1[8]; // fall through + case 8 : b+=k[2]+(((UInt32)k[3])<<16); + a+=k[0]+(((UInt32)k[1])<<16); + break; + case 7 : b+=((UInt32)k1[6])<<16; // fall through + case 6 : b+=k[2]; + a+=k[0]+(((UInt32)k[1])<<16); + break; + case 5 : b+=k1[4]; // fall through + case 4 : a+=k[0]+(((UInt32)k[1])<<16); + break; + case 3 : a+=((UInt32)k1[2])<<16; // fall through + case 2 : a+=k[0]; + break; + case 1 : a+=k1[0]; + break; + case 0 : + *uHash1 = c; + *uHash2 = b; + return; + } + + } else { // need to read the key one byte at a time + const UInt8 *k = (const UInt8 *)pData; + + // all but the last block: affect some 32 bits of (a,b,c) + while (uSize > 12) + { + a += k[0]; + a += ((UInt32)k[1])<<8; + a += ((UInt32)k[2])<<16; + a += ((UInt32)k[3])<<24; + b += k[4]; + b += ((UInt32)k[5])<<8; + b += ((UInt32)k[6])<<16; + b += ((UInt32)k[7])<<24; + c += k[8]; + c += ((UInt32)k[9])<<8; + c += ((UInt32)k[10])<<16; + c += ((UInt32)k[11])<<24; + Mix(a,b,c); + uSize -= 12; + k += 12; + } + + // last block: affect all 32 bits of (c) + switch(uSize) // all the case statements fall through + { + case 12: c+=((UInt32)k[11])<<24; + case 11: c+=((UInt32)k[10])<<16; + case 10: c+=((UInt32)k[9])<<8; + case 9 : c+=k[8]; + case 8 : b+=((UInt32)k[7])<<24; + case 7 : b+=((UInt32)k[6])<<16; + case 6 : b+=((UInt32)k[5])<<8; + case 5 : b+=k[4]; + case 4 : a+=((UInt32)k[3])<<24; + case 3 : a+=((UInt32)k[2])<<16; + case 2 : a+=((UInt32)k[1])<<8; + case 1 : a+=k[0]; + break; + case 0 : + *uHash1 = c; + *uHash2 = b; + return; + } + } + + Final(a,b,c); + *uHash1 = c; + *uHash2 = b; + return; +} + + +// Hash a string of unknown length case sensitive. I can't just call +// Compute() without allocating a copy of the string, which could have +// complications because there's no max length for strings. 
+void DrHash32::StringI2 ( + const char *pString, // bytes to hash; every byte is passed through UPPER() before mixing + Size_t uSize, // number of bytes to hash + UInt32 uSeed1, // first seed (added to a, b and c) + UInt32 uSeed2, // second seed (added to c only) + UInt32 *uHash1, // OUT: first hash, the final c (dereferenced unconditionally) + UInt32 *uHash2) // OUT: second hash, the final b (dereferenced unconditionally) +{ + UInt32 a,b,c; + const UInt8 *k; + + k = (const UInt8 *) pString; + + // Set up the internal state (same seeding as Compute2) + a = b = c = 0xdeadbeef + ((UInt32)uSize) + uSeed1; + c += uSeed2; + + // all but the last block: affect some 32 bits of (a,b,c); each byte is widened to UInt32 before shifting, as Compute2 does, so <<24 cannot overflow a signed int + while (uSize > 12) + { + a += UPPER(k[0]); + a += ((UInt32)(UPPER(k[1])))<<8; + a += ((UInt32)(UPPER(k[2])))<<16; + a += ((UInt32)(UPPER(k[3])))<<24; + b += UPPER(k[4]); + b += ((UInt32)(UPPER(k[5])))<<8; + b += ((UInt32)(UPPER(k[6])))<<16; + b += ((UInt32)(UPPER(k[7])))<<24; + c += UPPER(k[8]); + c += ((UInt32)(UPPER(k[9])))<<8; + c += ((UInt32)(UPPER(k[10])))<<16; + c += ((UInt32)(UPPER(k[11])))<<24; + Mix(a,b,c); + uSize -= 12; + k += 12; + } + + // last block: affect all 32 bits of (c) + switch(uSize) // all the case statements fall through + { + case 12: c+=((UInt32)(UPPER(k[11])))<<24; + case 11: c+=((UInt32)(UPPER(k[10])))<<16; + case 10: c+=((UInt32)(UPPER(k[9])))<<8; + case 9 : c+=UPPER(k[8]); + case 8 : b+=((UInt32)(UPPER(k[7])))<<24; + case 7 : b+=((UInt32)(UPPER(k[6])))<<16; + case 6 : b+=((UInt32)(UPPER(k[5])))<<8; + case 5 : b+=UPPER(k[4]); + case 4 : a+=((UInt32)(UPPER(k[3])))<<24; + case 3 : a+=((UInt32)(UPPER(k[2])))<<16; + case 2 : a+=((UInt32)(UPPER(k[1])))<<8; + case 1 : a+=UPPER(k[0]); + break; + case 0 : // nothing left to add: return the state as-is, skipping Final() + *uHash1 = c; + *uHash2 = b; + return; + } + + Final(a,b,c); + *uHash1 = c; + *uHash2 = b; + return; +} + + + +// +// Self-test to check that the hash behaves as advertized +// Or you can plug your favorite hash in here and see how it fares!
+// +#if 0 + +#include + +// used for timings +static void driver1() +{ + UInt8 buf[256]; + UInt32 i; + UInt64 h=0; + time_t a,z; + + time(&a); + for (i=0; i<256; ++i) buf[i] = 'x'; + + // increase the loop size until you can measure wall-clock time taken + for (i=0; i<1; ++i) + { + h = DrHash64::Compute(&buf[0],1,h); + } + time(&z); + if (z-a > 0) printf("time %ld %.8x\n", z-a, h); +} + +// check that every input bit changes every output bit half the time +#define HASHSTATE 1 +#define HASHLEN 1 +#define MAXPAIR 60 +#define MAXLEN 70 +static void driver2() +{ + UInt8 qa[MAXLEN+1], qb[MAXLEN+2], *a = &qa[0], *b = &qb[1]; + UInt64 c[HASHSTATE], d[HASHSTATE]; + UInt32 i=0, j=0, k, l, m=0, z; + UInt64 e[HASHSTATE],f[HASHSTATE],g[HASHSTATE],h[HASHSTATE]; + UInt64 x[HASHSTATE],y[HASHSTATE]; + Size_t hlen; + + printf("No more than %d trials should ever be needed \n",MAXPAIR/2); + for (hlen=0; hlen < MAXLEN; ++hlen) + { + z=0; + for (i=0; i>(8-j)); + c[0] = DrHash64::Compute(a, hlen, m); + b[i] ^= ((k+1)<>(8-j)); + d[0] = DrHash64::Compute(b, hlen, m); + // check every bit is 1, 0, set, and not set at least once + for (l=0; lz) z=k; + if (k==MAXPAIR) + { + printf("Some bit didn't change: "); + printf("%.8x.8x %.8x.8x %.8x.8x %.8x.8x %.8x.8x %.8x.8x ", + (UInt32)(e[0] >> 32), (UInt32)e[0], + (UInt32)(f[0] >> 32), (UInt32)f[0], + (UInt32)(g[0] >> 32), (UInt32)g[0], + (UInt32)(h[0] >> 32), (UInt32)h[0], + (UInt32)(x[0] >> 32), (UInt32)x[0], + (UInt32)(y[0] >> 32), (UInt32)y[0]); + printf("i %ld j %ld m %ld len %ld\n",i,j,m,hlen); + } + if (z==MAXPAIR) goto done; + } + } + } + done: + if (z < MAXPAIR) + { + printf("Mix success %2ld bytes %2ld initvals ",i,m); + printf("required %ld trials\n",z/2); + } + } + printf("\n"); +} + +// Check for reading beyond the end of the buffer and alignment problems +static void driver3() +{ + UInt8 buf[MAXLEN+20], *b; + UInt32 len; + UInt8 q[] = "This is the time for all good men to come to the aid of their country..."; + UInt32 h; + UInt8 
qq[] = "xThis is the time for all good men to come to the aid of their country..."; + UInt32 i; + UInt8 qqq[] = "xxThis is the time for all good men to come to the aid of their country..."; + UInt32 j; + UInt8 qqqq[] = "xxxThis is the time for all good men to come to the aid of their country..."; + UInt64 ref,x,y; + UInt8 *p; + + printf("Endianness. These lines should all be the same (for values filled in):\n"); + p = q; + printf("%.8x %.8x %.8x %.8x %.8x %.8x %.8x %.8x %.8x %.8x %.8x %.8x\n", + DrHash32::Compute(p, sizeof(q)-1, 13), + DrHash32::Compute(p, sizeof(q)-2, 13), + DrHash32::Compute(p, sizeof(q)-3, 13), + DrHash32::Compute(p, sizeof(q)-4, 13), + DrHash32::Compute(p, sizeof(q)-5, 13), + DrHash32::Compute(p, sizeof(q)-6, 13), + DrHash32::Compute(p, sizeof(q)-7, 13), + DrHash32::Compute(p, sizeof(q)-8, 13), + DrHash32::Compute(p, sizeof(q)-9, 13), + DrHash32::Compute(p, sizeof(q)-10, 13), + DrHash32::Compute(p, sizeof(q)-11, 13), + DrHash32::Compute(p, sizeof(q)-12, 13)); + p = &qq[1]; + printf("%.8x %.8x %.8x %.8x %.8x %.8x %.8x %.8x %.8x %.8x %.8x %.8x\n", + DrHash32::Compute(p, sizeof(q)-1, 13), + DrHash32::Compute(p, sizeof(q)-2, 13), + DrHash32::Compute(p, sizeof(q)-3, 13), + DrHash32::Compute(p, sizeof(q)-4, 13), + DrHash32::Compute(p, sizeof(q)-5, 13), + DrHash32::Compute(p, sizeof(q)-6, 13), + DrHash32::Compute(p, sizeof(q)-7, 13), + DrHash32::Compute(p, sizeof(q)-8, 13), + DrHash32::Compute(p, sizeof(q)-9, 13), + DrHash32::Compute(p, sizeof(q)-10, 13), + DrHash32::Compute(p, sizeof(q)-11, 13), + DrHash32::Compute(p, sizeof(q)-12, 13)); + p = &qqq[2]; + printf("%.8x %.8x %.8x %.8x %.8x %.8x %.8x %.8x %.8x %.8x %.8x %.8x\n", + DrHash32::Compute(p, sizeof(q)-1, 13), + DrHash32::Compute(p, sizeof(q)-2, 13), + DrHash32::Compute(p, sizeof(q)-3, 13), + DrHash32::Compute(p, sizeof(q)-4, 13), + DrHash32::Compute(p, sizeof(q)-5, 13), + DrHash32::Compute(p, sizeof(q)-6, 13), + DrHash32::Compute(p, sizeof(q)-7, 13), + DrHash32::Compute(p, sizeof(q)-8, 13), + 
DrHash32::Compute(p, sizeof(q)-9, 13), + DrHash32::Compute(p, sizeof(q)-10, 13), + DrHash32::Compute(p, sizeof(q)-11, 13), + DrHash32::Compute(p, sizeof(q)-12, 13)); + p = &qqqq[3]; + printf("%.8x %.8x %.8x %.8x %.8x %.8x %.8x %.8x %.8x %.8x %.8x %.8x\n", + DrHash32::Compute(p, sizeof(q)-1, 13), + DrHash32::Compute(p, sizeof(q)-2, 13), + DrHash32::Compute(p, sizeof(q)-3, 13), + DrHash32::Compute(p, sizeof(q)-4, 13), + DrHash32::Compute(p, sizeof(q)-5, 13), + DrHash32::Compute(p, sizeof(q)-6, 13), + DrHash32::Compute(p, sizeof(q)-7, 13), + DrHash32::Compute(p, sizeof(q)-8, 13), + DrHash32::Compute(p, sizeof(q)-9, 13), + DrHash32::Compute(p, sizeof(q)-10, 13), + DrHash32::Compute(p, sizeof(q)-11, 13), + DrHash32::Compute(p, sizeof(q)-12, 13)); + printf("\n"); + for (h=0, b=buf+1; h<8; ++h, ++b) + { + for (i=0; i> 32), (UInt32)ref, + (UInt32)(x >> 32), (UInt32)x, + (UInt32)(y >> 32), (UInt32)y, + h,i); + } + } + } +} + +// check for problems with nulls +static void driver4() +{ + UInt32 i; + UInt64 h,state[HASHSTATE]; + + + for (i=0; i> 32), (UInt32)h); + } +} + +// Check that StringI really is case insensitive +// and equivalent to Compute on an uppercased string +static void driver5() +{ + const char x1[] = "mares eat oats and does eat oats and little lambs eat ivy\n"; + const char x2[] = "Mares Eat Oats And Does Eat Oats And Little Lambs Eat Ivy\n"; + const char x3[] = "MARES EAT OATS AND DOES EAT OATS AND LITTLE LAMBS EAT IVY\n"; + const char y1[] = "bob"; + const char y2[] = "Bob"; + const char y3[] = "BOB"; + printf("\nStringI: Columns are the same, rows are different\n"); + printf("%.8x%.8x %.8x%.8x\n", + (UInt32)(DrHash64::StringI( x1, strlen(x1), 666) >> 32), + (UInt32)(DrHash64::StringI( x1, strlen(x1), 666)), + (UInt32)(DrHash64::StringI( y1, strlen(y1), 666) >> 32), + (UInt32)(DrHash64::StringI( y1, strlen(y1), 666))); + printf("%.8x%.8x %.8x%.8x\n", + (UInt32)(DrHash64::StringI( x2, strlen(x2), 666) >> 32), + (UInt32)(DrHash64::StringI( x2, strlen(x2), 
666)), + (UInt32)(DrHash64::StringI( y2, strlen(y2), 666) >> 32), + (UInt32)(DrHash64::StringI( y2, strlen(y2), 666))); + printf("%.8x%.8x %.8x%.8x\n", + (UInt32)(DrHash64::StringI( x3, strlen(x3), 666) >> 32), + (UInt32)(DrHash64::StringI( x3, strlen(x3), 666)), + (UInt32)(DrHash64::StringI( y3, strlen(y3), 666) >> 32), + (UInt32)(DrHash64::StringI( y3, strlen(y3), 666))); + printf("%.8x%.8x %.8x%.8x\n", + (UInt32)(DrHash64::Compute( (const void *)x3, strlen(x3), 666) >> 32), + (UInt32)(DrHash64::Compute( (const void *)x3, strlen(x3), 666)), + (UInt32)(DrHash64::Compute( (const void *)y3, strlen(y3), 666) >> 32), + (UInt32)(DrHash64::Compute( (const void *)y3, strlen(y3), 666))); +} + +int __cdecl main(int argc, char **argv) +{ + driver1(); // test that the key is hashed: used for timings + driver2(); // test that whole key is hashed thoroughly + driver3(); // test that nothing but the key is hashed + driver4(); // test hashing multiple buffers (all buffers are null) + driver5(); // test that StringI really is case insensitive + return 0; +} + + +#endif diff --git a/DryadVertex/VertexHost/system/classlib/src/DrHeap.cpp b/DryadVertex/VertexHost/system/classlib/src/DrHeap.cpp new file mode 100644 index 0000000..66fb52b --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/src/DrHeap.cpp @@ -0,0 +1,182 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include "DrCommon.h" + + +// Recursively ensure that the node is larger than its two children +void DryadHeap::DownHeapify(DWORD index) +{ + DWORD swapWith = index; + + // If we have to swap, it will be with the larger of the two children + // Find the largest child. The right child may not exist. + if (Exists(RightChild(index)) && m_entries[RightChild(index)]->IsHigherPriorityThan(m_entries[LeftChild(index)])) + swapWith = RightChild(index); + else if (Exists(LeftChild(index))) + swapWith = LeftChild(index); + + // Swap and recurse if necessary + if (swapWith != index) + { + if (m_entries[swapWith]->IsHigherPriorityThan(m_entries[index])) + { + // Child is higher priority, which violates heap rule, so swap + HeapSwap(swapWith, index); + DownHeapify(swapWith); + } + } +} + +void DryadHeap::GrowHeap() // reallocates m_entries to 2*m_heapAllocSize + c_heapGrowAmount slots, copying the old slots across +{ + size_t newHeapSize = m_heapAllocSize * 2 + c_heapGrowAmount; + + DryadHeapItem **newHeap = new DryadHeapItem* [newHeapSize]; + LogAssert(newHeap != NULL); + + if (m_entries != NULL) + { + memcpy(newHeap, m_entries, m_heapAllocSize * sizeof(m_entries[0])); + delete[] m_entries; + } + + m_entries = newHeap; + m_heapAllocSize = newHeapSize; +} + +// Insert the given value into the heap +void DryadHeap::InsertHeapEntry(DryadHeapItem *entry) +{ + DWORD insertLoc = m_numEntries+1; + DWORD parent; + + if (insertLoc >= m_heapAllocSize) + GrowHeap(); + + // While we are not the root, and inserting at our current position would make us > the parent + // (which violates the heap rule)... 
+ while (Exists((parent = ParentOf(insertLoc))) && entry->IsHigherPriorityThan(m_entries[parent])) + { + // Copy the parent into the location we would have inserted + // Now plan to insert in the parent's location + m_entries[insertLoc] = m_entries[parent]; + m_entries[insertLoc]->m_heapIndex = insertLoc; + insertLoc = parent; + } + + m_entries[insertLoc] = entry; + m_entries[insertLoc]->m_heapIndex = insertLoc; + m_numEntries++; +} + +// Peek at the first item in the heap (does not remove it) +// Returns NULL if the heap is empty +DryadHeapItem *DryadHeap::PeekHeapRoot() +{ + if (m_numEntries == 0) + return NULL; + + return m_entries[1]; +} + +// Extract the first item from the heap +// Returns NULL if the heap is empty +DryadHeapItem *DryadHeap::DequeueHeapRoot() +{ + if (m_numEntries == 0) + return NULL; + + DryadHeapItem *value = m_entries[1]; + + // Move last node to top + m_entries[1] = m_entries[m_numEntries]; + m_entries[1]->m_heapIndex = 1; + m_numEntries--; + DownHeapify(1); + + return value; +} + +// Preserves heap property by moving item up or down as appropriate +void DryadHeap::Heapify(DWORD index) +{ + DWORD parent = ParentOf(index); + + // If the node has a parent, but is out of order with respect to the parent, then swap with it + if (Exists(parent) && m_entries[index]->IsHigherPriorityThan(m_entries[parent])) + { + // Heap compare says node's value is such that it should be extracted before parent + HeapSwap(index, parent); + UpHeapify(parent); + } + else + { + // Didn't violate anything in the up direction, but may violate in the down direction + DownHeapify(index); + } +} + +// Bubble up the entry starting at the given index, if necessary +void DryadHeap::UpHeapify(DWORD index) +{ + DWORD parent = ParentOf(index); + + // If the node has a parent, but is out of order with respect to the parent, then swap with it + if (Exists(parent) && m_entries[index]->IsHigherPriorityThan(m_entries[parent])) + { + HeapSwap(index, parent); + UpHeapify(parent); + 
} +} + +// Remove the entry at the given heap position (1-based: slot 0 is never used, so position 0 and positions past the last entry are ignored) +// To do this, swap the last entry in the heap into the position we want to remove +void DryadHeap::RemoveHeapEntry(DWORD position) +{ + if (position == 0 || position > m_numEntries) + return; + + if (position == m_numEntries) + { + // Entry to remove was already the last entry in the list + m_numEntries--; + } + else + { + HeapSwap(position, m_numEntries); + m_numEntries--; + + // Make sure we haven't violated the heap in either the up or down direction + Heapify(position); + } +} + + +void DryadHeap::HeapSwap(DWORD index1, DWORD index2) // exchanges two slots and refreshes both items' m_heapIndex +{ + DryadHeapItem *temp = m_entries[index1]; + m_entries[index1] = m_entries[index2]; + m_entries[index2] = temp; + m_entries[index1]->m_heapIndex = index1; + m_entries[index2]->m_heapIndex = index2; +} + + diff --git a/DryadVertex/VertexHost/system/classlib/src/DrLogging.cpp b/DryadVertex/VertexHost/system/classlib/src/DrLogging.cpp new file mode 100644 index 0000000..5ea0b0b --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/src/DrLogging.cpp @@ -0,0 +1,195 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include "DrCommon.h" + +#pragma unmanaged + +// +// there are race conditions in the accessor functions for these global bools, however +// we are ignoring them since their effects will merely be to change the number of log +// statements printed if the logging level is dynamically changed during a program run, +// which is benign compared to the overhead of acquiring a lock while testing for logging +// being enabled +// + +// +// Default level is warning +// +static LogLevel s_loggingType = LogLevel_Warning; + +// +// Update the logging level to all at supplied level and more severe +// +void DrLogging::SetLoggingLevel(LogLevel type) +{ + s_loggingType = type; +} + +// +// Check whether logging is enabled at a certain event level +// +bool DrLogging::Enabled(LogLevel type) +{ + return ((s_loggingType & type) == type); +} + +// +// Flush the log +// +void DrLogging::FlushLog() +{ + fflush(m_logFile); +} + +// +// Return log file - used by drlogginghelper +// +FILE* DrLogging::GetLogFile() +{ + return m_logFile; +} + +// +// Create the vertex host log file +// +FILE* DrLogging::CreateLogFile() +{ + WCHAR szCurrentDir[MAX_PATH + 1] = {0}; + WCHAR szLogFile[MAX_PATH + 1] = {0}; + + if (GetCurrentDirectory(MAX_PATH, szCurrentDir) != 0) + { + if (S_OK == StringCchPrintf(szLogFile, MAX_PATH, L"%s\\VertexHostLog.txt", szCurrentDir)) + { + FILE * logFile = _wfsopen(szLogFile, L"w", _SH_DENYWR); + if(logFile != NULL) + { + // If log file created successfully, use it + return logFile; + } + } + } + + // if there is an error creating the log file, fall back to stderr + return stderr; +} + +// +// Initialize log file +// +FILE* DrLogging::m_logFile = CreateLogFile(); + +// +// Log the provided string +// +void DrLogHelper::operator()(const char* format, ...) 
+{ + va_list args; + va_start(args, format); + + SYSTEMTIME utc, local; + FILETIME ft; + GetSystemTimeAsFileTime(&ft); + FileTimeToSystemTime(&ft, &utc); + SystemTimeToTzSpecificLocalTime(NULL, &utc, &local); + + // + // Get character for event logging level + // + char initial = 0; + switch (m_type) + { + case LogLevel_Assert: + initial = 'a'; + break; + case LogLevel_Error: + initial = 'e'; + break; + case LogLevel_Warning: + initial = 'w'; + break; + case LogLevel_Info: + initial = 'i'; + break; + case LogLevel_Debug: + initial = 'd'; + break; + } + + + // + // Get formatted message + // + DrStr128 s = ""; + if (format != NULL) + { + s.VSetF(format, args); + } + + // + // Strip path from filename, if present. + // + const char * filename =m_file; + const char * lastBackslash = strrchr(filename, '\\'); + if (lastBackslash != NULL) + { + filename = lastBackslash + 1; + } + + // + // Print out message to the log file (stderr only as the CreateLogFile fallback), using the path-stripped file name + // + fprintf(DrLogging::GetLogFile(), + "%c, %02d/%02d/%04d %02d:%02d:%02d.%03d, TID=%u,%s,%s:%d,%s\n", + initial, + local.wMonth, + local.wDay, + local.wYear, + local.wHour, + local.wMinute, + local.wSecond, + local.wMilliseconds, + GetCurrentThreadId(), + m_function, + filename, + m_line, + s.GetString() + ); + + + va_end(args); + + DrLogging::FlushLog(); + // + // If assert level, assert(false) after logging and flushing the stream + // + if (m_type == LogLevel_Assert) + { + DrLogging::FlushLog(); + + if (IsDebuggerPresent()) + { + ::DebugBreak(); + } + + TerminateProcess(GetCurrentProcess(), DrError_Fail); + } +} diff --git a/DryadVertex/VertexHost/system/classlib/src/DrMemory.cpp b/DryadVertex/VertexHost/system/classlib/src/DrMemory.cpp new file mode 100644 index 0000000..efd71bf --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/src/DrMemory.cpp @@ -0,0 +1,386 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "DrCommon.h" + +#pragma unmanaged + +/* + + Implementation of core cosmos memory/buffer management + +*/ + +// +// Get address and size of available memory chunk allocating more memory if necessary +// +// Note that this may return memory beyond the current available size; it is +// the caller's responsibility to SetAvailableSize() if necessary. +// +// Note also that the returned *puSize may be greater than uDataSize (callers in this file clamp it and loop, so a shorter contiguous run is handled too) +// +void *DrMemoryBuffer::GetWriteAddress( + Size_t uOffset, // offset at which to return a write pointer + Size_t uDataSize, // minimum number of bytes to ensure is available (not necessarily contiguous) starting at the specified offset + /* out */ Size_t *puSize, // Number of contiguous bytes starting at the returned pointer (written unconditionally, so must not be NULL) + /* out */ Size_t *puPreceedingSize // Number of contiguous bytes that precede the returned pointer (optional; may be NULL) +) +{ + LogAssert(IsWritable()); + + Size_t uEnd = uOffset + uDataSize; + LogAssert(uEnd >= uOffset); + if (uEnd > m_uAllocatedSize) { + IncreaseAllocatedSize(uEnd); + LogAssert(uEnd <= m_uAllocatedSize); + } + + void *pAddress; + Size_t uSize; + Size_t uPrec; + + pAddress = GetDataAddress(uOffset, &uSize, &uPrec); + LogAssert(pAddress != NULL); + + *puSize = uSize; + if (puPreceedingSize != NULL) { + *puPreceedingSize = uPrec; + } + + return pAddress; +}; + +// +// Copy uDataSize bytes from pData array into the buffer at 
the +// specified offset. Grows the available data size if necessary to include the written data. +// +void DrMemoryBuffer::Write( + Size_t uOffset, // starting offset + const void *pData, // data buffer + Size_t uDataSize // number of bytes to copy +) +{ + const BYTE *pSrc = (const BYTE *)pData; + BYTE *pDst; + Size_t uSize; + + while (uDataSize != 0) { + pDst = (BYTE *)GetWriteAddress(uOffset, uDataSize, &uSize); + if (uSize > uDataSize) { + uSize = uDataSize; + } + memcpy(pDst, pSrc, uSize); + pSrc += uSize; + uOffset += uSize; + uDataSize -= uSize; + } + + if (uOffset > GetAvailableSize()) { + SetAvailableSize(uOffset); + } +}; + +// +// Zero uDataSize bytes in the buffer at the +// specified offset. Grows the available data size if necessary to include the zeroed data. +// +void DrMemoryBuffer::Zero( + Size_t uOffset, // starting offset + Size_t uDataSize // number of bytes to set to 0 +) +{ + BYTE *pDst; + Size_t uSize; + + while (uDataSize != 0) { + pDst = (BYTE *)GetWriteAddress(uOffset, uDataSize, &uSize); + if (uSize > uDataSize) { + uSize = uDataSize; + } + memset(pDst, 0, uSize); + uOffset += uSize; + uDataSize -= uSize; + } + + if (uOffset > GetAvailableSize()) { + SetAvailableSize(uOffset); + } +}; + +// +// Get the address and size of contiguous available readable memory area at offset uOffset (asserts that uOffset is below the available size) +// +const void *DrMemoryBuffer::GetReadAddress( + Size_t uOffset, + /* out */ Size_t *puSize, // Number of contiguous available bytes beginning at uOffset (clamped to the available size) + /* out */ Size_t *puPreceedingSize // Number of contiguous readable bytes that precede the returned pointer (optional; may be NULL) +) +{ + const void *pData; + Size_t uSize; + Size_t uPrec; + + LogAssert(uOffset < GetAvailableSize()); + + pData = GetDataAddress(uOffset, &uSize, &uPrec); + if (uSize + uOffset > GetAvailableSize()) + uSize = GetAvailableSize() - uOffset; + + *puSize = uSize; + if (puPreceedingSize != NULL) { + *puPreceedingSize = uPrec; + } + + return pData; +}; + +// +// Read uDataSize bytes into pData from the 
buffer starting at uOffset. +// +// It is a fatal error to attempt to read beyond the available size of the buffer +// +void DrMemoryBuffer::Read( + Size_t uOffset, + void *pData, + Size_t uDataSize +) +{ + BYTE *pDst = (BYTE *)pData; + + while (uDataSize != 0) { + Size_t uSize; + const BYTE *pSrc = (const BYTE *)GetReadAddress(uOffset, &uSize); + if (uSize > uDataSize) + uSize = uDataSize; + memcpy(pDst, pSrc, uSize); + pDst += uSize; + uDataSize -= uSize; + uOffset += uSize; + } +}; + +// +// Compares uDataSize bytes from pData with buffer contents starting at uOffset. +// +// Return 0 on match, < 0 if contents of the buffer is less than contents of pData, > 0 otherwise +// +// It is a fatal error to attempt to read beyond the available size of the buffer +// +int DrMemoryBuffer::Compare( + Size_t uOffset, + const void *pData, + Size_t uDataSize +) +{ + int iResult; + const BYTE *pDst = (const BYTE *)pData; + + iResult = 0; + + while (uDataSize != 0) { + Size_t uSize; + const BYTE *pSrc = (const BYTE *)GetReadAddress(uOffset, &uSize); + if (uSize > uDataSize) + uSize = uDataSize; + iResult = memcmp(pSrc, pDst, uSize); + if (iResult != 0) + break; + pDst += uSize; + uDataSize -= uSize; + uOffset += uSize; + } + + return (iResult); +}; + +// +// Copy data from one buffer to another +// +void DrMemoryBuffer::CopyBuffer( + Size_t uDstOffset, // starting offset in this (destination) buffer + DrMemoryBuffer *pSrcBuffer, // source buffer + Size_t uSrcOffset, // starting offset in source buffer + Size_t uDataSize // number of bytes to copy +) +{ + if (uDataSize != 0) { + IncreaseAllocatedSize(uDstOffset + uDataSize); + while (uDataSize != 0) { + Size_t uSize; + const BYTE *pSrc = (const BYTE *)pSrcBuffer->GetReadAddress(uSrcOffset, &uSize); + if (uSize > uDataSize) + uSize = uDataSize; + Write(uDstOffset, pSrc, uSize); + uDataSize -= uSize; + uSrcOffset += uSize; + uDstOffset += uSize; + } + } +} + +DrSimpleHeapBuffer::DrSimpleHeapBuffer() +{ + m_pData = NULL; +} + 
+DrSimpleHeapBuffer::DrSimpleHeapBuffer(Size_t uSize) +{ + m_pData = NULL; + if (uSize != 0) { + m_pData = (BYTE *)malloc(uSize); + LogAssert(m_pData != NULL); + m_uAllocatedSize = uSize; + } else { + m_pData = NULL; + } +} + +DrSimpleHeapBuffer::~DrSimpleHeapBuffer() +{ + if (m_pData != NULL) { + free(m_pData); + } +} + +// Detaches the underlying heap object (if any) and returns it to the caller, who +// must call free() on the memory when done with it. +// +// returns NULL if there is no underlying heap object (allocedSize = 0) +// +// After this call, the buffer is a new buffer with no data in it. +// +void *DrSimpleHeapBuffer::DetachHeapItem() +{ + void *pRet = m_pData; + m_pData = NULL; + m_uAllocatedSize = 0; + SetAvailableSize(0); + return pRet; +} + +// Attaches an external heap item to the buffer. Any previous heap item is +// freed. The buffer becomes the owner of the heap item. +// +void DrSimpleHeapBuffer::AttachHeapItem(void *pHeapItem, Size_t allocedSize, Size_t dataSize) +{ + LogAssert(allocedSize >= dataSize); + LogAssert(allocedSize == 0 || pHeapItem != NULL); + if (m_pData != NULL) { + free(m_pData); + } + m_pData = (BYTE *)pHeapItem; + m_uAllocatedSize = allocedSize; + SetAvailableSize(dataSize); +} + +// +// Retrieve pointer to the data stored in memory block at uOffset and max size available in this block +// +// Returns NULL if no data at this offset, valid pointer otherwise +// +void *DrSimpleHeapBuffer::GetDataAddress( + Size_t uOffset, // starting offset + Size_t *puSize, // number of bytes available (0 in case of failure) + Size_t *puPriorSize // optional; size of contiguous memory area prior to (*GetDataAddress()) +) +{ + BYTE *pRet; + Size_t prec; + + if (uOffset >= m_uAllocatedSize) { + pRet = NULL; + *puSize = 0; + prec = 0; + } else { + pRet = m_pData + uOffset; + *puSize = m_uAllocatedSize - uOffset; + prec = uOffset; + } + + if (puPriorSize != NULL) { + *puPriorSize = prec; + } + + return (void *)pRet; +} + +// +// Preallocate 
enough memory to fit uSize bytes of data (grows to at least 32 bytes and at least double the current allocation). +// +void DrSimpleHeapBuffer::IncreaseAllocatedSize( + Size_t uSize // preallocate memory blocks to fit at least uSize bytes of data +) +{ + if (uSize > m_uAllocatedSize) { + if (uSize < 32) { + uSize = 32; + } + if (uSize < 2 * m_uAllocatedSize) { + uSize = 2 * m_uAllocatedSize; + } + if (m_pData == NULL) { + m_pData = (BYTE *)malloc(uSize); + LogAssert(m_pData != NULL); + } else { + BYTE *pNew = (BYTE *)realloc(m_pData, uSize); + LogAssert(pNew != NULL); + m_pData = pNew; + } + m_uAllocatedSize = uSize; + } +} + +// +// Retrieve pointer to the data stored in memory block at uOffset and max size available in this block +// +// Returns NULL if no data at this offset, valid pointer otherwise +// +void *DrFixedMemoryBuffer::GetDataAddress( + Size_t uOffset, // starting offset + Size_t *puSize, // number of bytes available (0 in case of failure) + Size_t *puPriorSize // optional; size of contiguous memory area prior to (*GetDataAddress()) +) +{ + void *pRet; + Size_t uPrior; + if (uOffset >= m_uAllocatedSize) { + *puSize = 0; + uPrior = 0; + pRet = NULL; + } else { + *puSize = m_uAllocatedSize - uOffset; + uPrior = uOffset; + pRet = (void *)(m_pData + uOffset); + } + if (puPriorSize != NULL) { + *puPriorSize = uPrior; + } + return pRet; +} + +// +// A fixed buffer never grows: this only asserts that uSize bytes already fit in the existing allocation. +// +void DrFixedMemoryBuffer::IncreaseAllocatedSize( + Size_t uSize // number of bytes that must already fit in the allocation +) +{ + LogAssert(uSize <= m_uAllocatedSize); +} diff --git a/DryadVertex/VertexHost/system/classlib/src/DrMemoryStream.cpp b/DryadVertex/VertexHost/system/classlib/src/DrMemoryStream.cpp new file mode 100644 index 0000000..44437f0 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/src/DrMemoryStream.cpp @@ -0,0 +1,1218 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "DrCommon.h" + +#include +#include +#include + +#pragma unmanaged + +void DrMemoryStream::DiscardMemoryStreamContext() +{ + pBlockBase = NULL; + blockLength = 0; + pData = NULL; + uBlockBasePhysicalStreamPosition = 0; +} + +DrMemoryStream::DrMemoryStream() +{ + status = DrError_OK; + pBlockBase = NULL; + blockLength = 0; + pData = NULL; + uBlockBasePhysicalStreamPosition = 0; +} + +DrMemoryStream::~DrMemoryStream() +{ + DiscardMemoryStreamContext(); + status = DrError_Impossible; +} + +__declspec(deprecated) DrError DrMemoryStream::Close() +{ + DiscardMemoryStreamContext(); + return status; +} + +DrMemoryWriter::DrMemoryWriter() +{ + m_pendingBufferredWriteOldAvailableSize = 0; + m_fMemoryWriterIsClosed = false; + m_fIgnoreMemoryWriterCloseFailureInDestructor = false; +} + +void DrMemoryWriter::InternalFree() +{ + m_pPendingBufferedWriteBuffer = NULL; + m_pendingBufferredWriteOldAvailableSize = 0; +} + +// Clears all buffer context values to initial defaults (including setting the physical stream position to 0 +// and discarding abandoned temporary buffered write buffers), but does not change the current status code +// or "closed" status. +// May be overridden by a subclass if it needs to free resources controlled by the +// context pointers before delegating to this method implementation.
+void DrMemoryWriter::DiscardMemoryWriterContext() +{ + DrMemoryStream::DiscardMemoryStreamContext(); + InternalFree(); +} + +DrError DrMemoryWriter::CloseMemoryWriter() +{ + // Note, if we got here, we were chained to by any subclass implementation of CloseMemoryWriter. + + // We only allow CloseMemoryWriter to be called once + if (!m_fMemoryWriterIsClosed) + { + // Before we close, we call FlushMemoryWriter. + // We do not bother flushing if we are in an error condition. + // This is a virtual method which can be implemented by a subclass. + if (status == DrError_OK) + { + SetStatus(FlushMemoryWriter()); + } + + // Release any abandoned temporary buffered write buffers, and clear the current contiguous block context. + // The current status remains unchanged. + // We do this regardless of status because it frees unneeded resources even on failure. + DiscardMemoryWriterContext(); + + // We set the closed flag regardless of the status. CloseMemoryWriter can only be called once. + m_fMemoryWriterIsClosed = true; + } + + // At this point the only meaningful state we maintain is the status and the closed flag. + return status; +} + +void DrMemoryWriter::MemoryWriterDestructorClose() +{ + if (!m_fMemoryWriterIsClosed) + { + // If status is already failing, then it is OK to fail close at destructor time, since no new error + // is occurring, so flushing is unnecessary. + bool fIgnoreFailures = (m_fIgnoreMemoryWriterCloseFailureInDestructor || status != DrError_OK); + + // We call CloseMemoryWriter even if status is not DrError_OK because implementations + // use CloseMemoryWriter to free resources even in an error state. + // NOTE that this call goes through the virtual call chain, executing on the most-derived class first. 
+ SetStatus(CloseMemoryWriter()); + + // if m_fMemoryWriterIsClosed is not set, it means that someone forgot to chain to their parent class's + // CloseMemoryWriter() + LogAssert(m_fMemoryWriterIsClosed); + + // If a new error occurred on close and we are not ignoring errors, then it is a fatal error. + LogAssert(fIgnoreFailures || status == DrError_OK); + } +} + +DrMemoryWriter::~DrMemoryWriter() +{ + //======================================================================================================= + // NOTE: each subclass should duplicate this code if it implements CloseMemoryWriter or FlushMemoryWriter + // + if (!MemoryWriterIsClosed()) + { + // Close the memory writer. We cannot tolerate new failures in CloseMemoryWriter at this point because we are in a destructor! + // This will call FlushMemoryWriter() before closing, but because of virtual destructor unwinding, it will call the + // current class's implementation rather than the subclass's overriding implementation. So each subclass needs to implement this call in their + // virtual destructor if they implement CloseMemoryWriter or FlushMemoryWriter + MemoryWriterDestructorClose(); + } + // + // NOTE: End duplicated destructor code + //======================================================================================================= + + InternalFree(); +} + +// Write a narrow (char[]) string "pstr" as a property, with an explicit provided string length. pstr may be NULL, which is +// properly encoded as a distinct value from a zero-length string. Typically, a non-NULL pstr points to a +// UTF-8 string (not-necessarily '\0'-terminated) that is "length" bytes long, without any embedded '\0' bytes; +// however, this function may be used to encode arbitrary byte blocks, including blocks that contain embedded '\0' bytes, +// so it can be used, e.g., to encode a concatenated list of '\0'-terminated strings. 
+// The primary distinctions between this function and WriteBlob are:
+//   a) NULL is a distinct value from a zero-length string.
+//   b) The property length is 0 for NULL values, and "length"+1 for non-NULL values.
+//   c) For non-NULL values, a '\0' is appended to the string bytes as the last byte of the property data. This does not change
+//      the value of the string, but allows a reader to use the serialized data as a null terminated string without copying it, and
+//      allows a reader to distinguish between a NULL string and an empty string.
+//
+// The provided length does not include the terminating '\0' byte; however, a non-NULL pstr
+// is always '\0'-terminated in the output stream, and a NULL pstr is written as an empty property.
+// The actual length of the property data will be "length" + 1 if pstr is not NULL.
+//
+// These semantics are consistent with round-trip encoding of a DrStr value.
+//
+// This version only works with LONGATOM properties. This is asserted, to help detect bugs where strings that may
+// occasionally exceed 254 bytes are accidentally tagged as SHORTATOM. If you have a string property that you know
+// *must* always be less than 255 bytes long, you can use WriteShortStringPropertyWithLength.
+//
+// Causes an assertion failure if:
+//   a) enumId is a SHORTATOM
+//   b) length >= _UI32_MAX
+//   c) pstr == NULL && length != 0
+DrError DrMemoryWriter::WriteLongStringPropertyWithLength(UInt16 enumId, const char *pstr, Size_t length)
+{
+    // NULL string: encoded as a property with no payload at all.
+    if (pstr == NULL)
+    {
+        LogAssert(length == 0);
+        WriteEmptyLongProperty(enumId);
+        return status;
+    }
+
+    // Non-NULL string: payload is the bytes followed by one terminating '\0',
+    // so the encoded length is length + 1 and must still fit the long length field.
+    LogAssert(length < (Size_t)_UI32_MAX);
+    if (WritePropertyTagLong(enumId, length + 1) != DrError_OK)
+    {
+        return status;
+    }
+
+    WriteBytes(pstr, length);
+    WriteChar('\0');
+    return status;
+}
+
+// This version writes a SHORTATOM string value.
+//
+// This version only works with SHORTATOM properties. If you have a string property that you know
+// *must* always be less than 255 bytes long, you can use a SHORTATOM property ID and this method.
+// Otherwise, you should use WriteLongStringPropertyWithLength().
+//
+// Causes an assertion failure if:
+//   a) enumId is a LONGATOM
+//   b) length >= 255
+//   c) pstr == NULL && length != 0
+DrError DrMemoryWriter::WriteShortStringPropertyWithLength(UInt16 enumId, const char *pstr, Size_t length)
+{
+    // NULL string: encoded as a property with no payload at all.
+    if (pstr == NULL)
+    {
+        LogAssert(length == 0);
+        WriteEmptyShortProperty(enumId);
+        return status;
+    }
+
+    // Non-NULL string: bytes plus the terminating '\0' must fit the one-byte length field.
+    LogAssert(length < (Size_t)_UI8_MAX);
+    if (WritePropertyTagShort(enumId, length + 1) != DrError_OK)
+    {
+        return status;
+    }
+
+    WriteBytes(pstr, length);
+    WriteChar('\0');
+    return status;
+}
+
+// This method should be overridden by memory writers that know how to allocate a new block and keep writing.
+// The implementation should update pData, pBlockBase, blockLength, and uBlockBasePhysicalStreamPosition to point to the new block.
+// The current block will be completely filled before calling this method, since there is no way to back up.
+// After this call, the old block can be disposed of in any way the underlying implementation chooses (e.g., flushing).
+// Returns DrError_EndOfStream if a new block can't or shouldn't be allocated.
+// The default implementation always returns DrError_EndOfStream, which is appropriate for single-block writers.
+// If an error is returned, status has been set.
+DrError DrMemoryWriter::AdvanceToNextBlock()
+{
+    // Base class: there is never a next block.
+    return SetStatus(DrError_EndOfStream);
+}
+
+// This method checks whether an attempt to write beyond the end of the current block
+// will succeed. The implementation should return true if the buffer is indefinitely growable.
+// The default implementation returns false, which is appropriate for single-block writers.
+// Note that this method does not set or check status.
+bool DrMemoryWriter::FutureBlocksCanBeWritten(size_t length)
+{
+    // The requested size is irrelevant for the base class: nothing past the current block is writable.
+    (void)length;
+    return false;
+}
+
+// Copies "length" bytes into the stream, moving on to further blocks (via AdvanceToNextBlock)
+// whenever the current contiguous block is full.
+// Returns the current status if the write cannot be satisfied, DrError_OK once every byte is copied.
+DrError DrMemoryWriter::CrossBlockWriteBytes(const BYTE *pBytes, size_t length)
+{
+    // Refuse up front if the whole request cannot be written.
+    if (EnsureCanBeWritten(length) != DrError_OK) {
+        return status;
+    }
+
+    while (length != 0) {
+        // Current block exhausted: try to obtain a fresh one before copying.
+        if (NumContiguousBytesRemaining() == (size_t)0 && AdvanceToNextBlock() != DrError_OK) {
+            return status;
+        }
+
+        size_t nr = NumContiguousBytesRemaining();
+        LogAssert(nr > 0);
+
+        // Copy as much as fits in the current block, but never more than what is left to write.
+        size_t cbChunk = length;
+        if (cbChunk > nr) {
+            cbChunk = nr;
+        }
+
+        memcpy(pData, pBytes, cbChunk);
+        pData += cbChunk;
+        pBytes += cbChunk;
+        length -= cbChunk;
+    }
+
+    return DrError_OK;
+}
+
+// Pulls "length" bytes out of pReader and appends them to this writer.
+// The reader drives the copy; its result is folded into this writer's status.
+DrError DrMemoryWriter::WriteBytesFromReader(DrMemoryReader *pReader, Size_t length)
+{
+    SetStatus(pReader->ReadBytesIntoWriter(this, length));
+    return status;
+}
+
+// Appends the available contents of pBuffer to this writer by reading them through a
+// temporary DrMemoryBufferReader. Does nothing if the writer has already failed or the
+// buffer reports no available bytes.
+DrError DrMemoryWriter::WriteBytesFromBuffer(DrMemoryBuffer *pBuffer, bool fAllowCopyBufferByReference)
+{
+    if (status != DrError_OK) {
+        return status;
+    }
+
+    Size_t cbAvailable = pBuffer->GetAvailableSize();
+    if (cbAvailable != 0) {
+        // The reader lives only for the duration of the copy.
+        DrMemoryBufferReader reader(pBuffer, fAllowCopyBufferByReference);
+        WriteBytesFromReader(&reader, cbAvailable);
+    }
+
+    return status;
+}
+
+
+// DrMemoryReader
+
+// Starts with no temporary-result blocks, no peek blocks, by-reference buffer copying
+// disabled, and the reader open.
+DrMemoryReader::DrMemoryReader()
+{
+    pFirstTempMemHeader = NULL;
+    pFirstPeekMemHeader = NULL;
+    pLastPeekMemHeader = NULL;
+    fAllowCopyBufferByReference = false;
+    m_fMemoryReaderIsClosed = false;
+}
+
+// Releases everything the reader has allocated: temporary results and peek blocks.
+void DrMemoryReader::InternalFree()
+{
+    DiscardTemporaryResults();
+    DiscardPeekBlocks();
+}
+
+// Clears all buffer context values to initial defaults (including setting the physical stream position to 0
+// and discarding peekahead buffers), but does not discard temporary results and does not change the current status code
+// or "closed" status.
+// may be overridden by a subclass if it needs to free resources controlled by the
+// context pointers before delegating to this method implementation.
+void DrMemoryReader::DiscardMemoryReaderContext()
+{
+    // Reset the base stream context, then drop every peek-ahead block.
+    DrMemoryStream::DiscardMemoryStreamContext();
+    DiscardPeekBlocks();
+}
+
+// This method clears all buffer context values to initial defaults (including setting the physical stream position to 0),
+// discarding peekahead buffers and temporary results, clearing the "closed" state, and setting the status to DrError_OK.
+// A subclass may override this to reset its own state along with forwarding the request to this class.
+// Calls through the virtual DiscardMemoryReaderContext() before clearing the status and the closed flag.
+void DrMemoryReader::ResetMemoryReader()
+{
+    DiscardMemoryReaderContext();
+    DiscardTemporaryResults();
+    status = DrError_OK;
+    m_fMemoryReaderIsClosed = false;
+}
+
+// Closes the reader: releases the block context, peek blocks and temporary results, and
+// sets the closed flag. Only the first call does any work. Returns the current status,
+// which this method itself never changes.
+DrError DrMemoryReader::CloseMemoryReader()
+{
+    // Note, if we got here, we were chained to by any subclass implementation of CloseMemoryReader.
+
+    // We only allow CloseMemoryReader to be called once
+    if (!m_fMemoryReaderIsClosed)
+    {
+        // Release any peek-ahead blocks and temporary results, and clear the current contiguous block context.
+        // The current status remains unchanged.
+        // We do this regardless of status because it frees unneeded resources even on failure.
+        DiscardMemoryReaderContext();
+        DiscardTemporaryResults();
+
+        // We set the closed flag regardless of the status. CloseMemoryReader can only be called once.
+        m_fMemoryReaderIsClosed = true;
+    }
+
+    // At this point the only meaningful state we maintain is the status and the closed flag.
+    return status;
+}
+
+// Destructor-time close. Unlike the writer's equivalent, a new failure raised while
+// closing a reader is deliberately ignored rather than asserted on.
+void DrMemoryReader::MemoryReaderDestructorClose()
+{
+    if (!m_fMemoryReaderIsClosed)
+    {
+        // Remember whether we were already failing before the close is attempted.
+        bool fAlreadyFailing = (status != DrError_OK);
+
+        // We call CloseMemoryReader even if status is not DrError_OK because implementations
+        // use CloseMemoryReader to free resources even in an error state.
+        // NOTE that this call goes through the virtual method chain, executing on the most-derived class first.
+        SetStatus(CloseMemoryReader());
+
+        // if m_fMemoryReaderIsClosed is not set, it means that someone forgot to chain to their parent class's
+        // CloseMemoryReader()
+        LogAssert(m_fMemoryReaderIsClosed);
+
+        if (!fAlreadyFailing && (status != DrError_OK))
+        {
+            // CloseMemoryReader introduced a new failure. But we are in a destructor, so we can't return an error code.
+            // Close errors on reader streams are generally harmless, so we will ignore the error.
+            ;
+        }
+    }
+}
+
+// Frees every temporary-result block, unlinking them one at a time from the head of the list.
+void DrMemoryReader::DiscardTemporaryResults()
+{
+    while (pFirstTempMemHeader != NULL) {
+        TempMemHeader *p = pFirstTempMemHeader;
+        // Detach() hands back the header that becomes the new list head.
+        pFirstTempMemHeader = p->Detach();
+        delete p;
+    }
+}
+
+// Frees every peek block, unlinking them one at a time from the head of the list,
+// and clears the tail pointer so the list is consistently empty.
+void DrMemoryReader::DiscardPeekBlocks()
+{
+    while (pFirstPeekMemHeader != NULL) {
+        PeekMemHeader *p = pFirstPeekMemHeader;
+        pFirstPeekMemHeader = p->Detach();
+        delete p;
+    }
+    pLastPeekMemHeader = NULL;
+}
+
+// The destructor for DrMemoryReader frees all the temporary
+// return values allocated since the reader was created.
+DrMemoryReader::~DrMemoryReader()
+{
+    //=======================================================================================================
+    // NOTE: each subclass should duplicate this code if it implements CloseMemoryReader and needs to have it
+    //       called at destruct time
+    //
+    if (!MemoryReaderIsClosed())
+    {
+        // Close the memory reader. This will call CloseMemoryReader() before closing, but because of virtual destructor unwinding, it will call the
+        // current class's implementation rather than the subclass's overriding implementation. So each subclass needs to implement this call in their
+        // virtual destructor if they implement CloseMemoryReader.
+        MemoryReaderDestructorClose();
+    }
+    //
+    // NOTE: End duplicated destructor code
+    //=======================================================================================================
+
+    InternalFree();
+}
+
+
+// ReadNextPropertyTag. Reads the next property from the bag, along
+// with its length (either 1- or 4-byte depending on the length
+// bit in the property name). Returns DrError_EndOfStream if there is not enough
+// data in the bag to read out the property name and length, but
+// does not check that there are *pDataLen more bytes remaining.
+// *pDataLen is written only if the length field was read successfully.
+DrError DrMemoryReader::ReadNextPropertyTag(
+    /* out */ UInt16 *pEnumId,
+    /* out */ UInt32 *pDataLen)
+{
+    if (ReadUInt16(pEnumId) != DrError_OK) {
+        return status;
+    }
+
+    // The length-size bit embedded in the property id selects a 1-byte or a 4-byte length field.
+    if (((*pEnumId) & PropLengthMask) == PropLength_Short) {
+        UInt8 lengthByte;
+        if (ReadUInt8(&lengthByte) == DrError_OK) {
+            *pDataLen = lengthByte;
+        }
+    } else {
+        ReadUInt32(pDataLen);
+    }
+
+    return status;
+}
+
+// PeekNextPropertyTag. Peeks at the next property from the bag, along
+// with its length (either 1- or 4-byte depending on the length
+// bit in the property name). Returns DrError_EndOfStream if there is not enough
+// data in the bag to read out the property name and length, but
+// does not check that there are *pDataLen more bytes remaining.
+DrError DrMemoryReader::PeekNextPropertyTag(
+    /* out */ UInt16 *pEnumId,
+    /* out */ UInt32 *pDataLen)
+{
+    // The id has to be peeked first, because it tells us how wide the length field is.
+    if (PeekUInt16(pEnumId) != DrError_OK) {
+        return status;
+    }
+
+    // Large enough for the id plus the widest (4-byte) length field.
+    BYTE tmp[sizeof(UInt16) + sizeof(UInt32)];
+
+    // Peek the whole header (id + length) again from the current position and pick the
+    // length out of the bytes that follow the id. *pDataLen is written only on success.
+    if (((*pEnumId) & PropLengthMask) == PropLength_Short) {
+        if (PeekBytes(tmp, sizeof(UInt16) + sizeof(UInt8)) == DrError_OK) {
+            *pDataLen = tmp[sizeof(UInt16)];
+        }
+    } else {
+        if (PeekBytes(tmp, sizeof(UInt16) + sizeof(UInt32)) == DrError_OK) {
+            memcpy(pDataLen, tmp+sizeof(UInt16), sizeof(UInt32));
+        }
+    }
+
+    return status;
+}
+
+// PeekNextProperty: Peeks at the next property in the bag and fills
+// in its name and length to pEnumId and pDataLen respectively,
+// not advancing the read pointer.
+// Returns a pointer to the contiguous property value (either in
+// the buffer or copied to make it contiguous). If PeekNextProperty returns an error,
+// the values of *pEnumId,
+// *pDataLen and *data are undefined.
+DrError DrMemoryReader::PeekNextProperty(
+    /* out */ UInt16 *pEnumId,
+    /* out */ UInt32 *pDataLen,
+    /* out */ const void **data)
+{
+    if (PeekNextPropertyTag(pEnumId, pDataLen) == DrError_OK) {
+        // Header = 2-byte id + 1- or 4-byte length, depending on the id's length-size bit.
+        // The sizes are kept in Size_t so that adding the stream-supplied data length
+        // cannot wrap a 32-bit value: a wrapped sum would peek fewer bytes than
+        // *pDataLen claims and hand the caller a pointer to too little data.
+        Size_t hdrLen = sizeof(UInt16);
+        if (((*pEnumId) & PropLengthMask) == PropLength_Short) {
+            hdrLen += sizeof(UInt8);
+        } else {
+            hdrLen += sizeof(UInt32);
+        }
+
+        // Peek header and value together so the value is contiguous, then skip past the header.
+        const BYTE *pProp;
+        if (PeekBytes(hdrLen + (Size_t)(*pDataLen), &pProp) == DrError_OK) {
+            *data = pProp + hdrLen;
+        }
+    }
+
+    return status;
+}
+
+
+// ReadNextKnownProperty: Reads the next property which is of
+// known ID and length into a preallocated buffer, placing the read pointer
+// after the property value. If ReadNextKnownProperty returns error,
+// the position of the read pointer is undefined.
+// returns DrError_InvalidProperty if the enum id or length don't match.
+DrError DrMemoryReader::ReadNextKnownProperty(
+    UInt16 enumId,
+    UInt32 dataLen,
+    void *pDest)
+{
+    UInt16 tagId;
+    UInt32 tagDataLen;
+
+    // Consume the tag; bail out with the recorded status if that fails.
+    if (ReadNextPropertyTag(&tagId, &tagDataLen) != DrError_OK) {
+        return status;
+    }
+
+    // Both the id and the length must be exactly what the caller expects
+    // before the value is copied out.
+    if (tagId == enumId && tagDataLen == dataLen) {
+        ReadData(dataLen, pDest);
+    } else {
+        SetStatus(DrError_InvalidProperty);
+    }
+
+    return status;
+}
+
+// PeekNextKnownProperty: Peeks at the next property which is of
+// known ID and length into a preallocated buffer.
+// returns DrError_InvalidProperty if the enum id or length don't match.
+DrError DrMemoryReader::PeekNextKnownProperty(
+    UInt16 enumId,
+    UInt32 dataLen,
+    void *pDest)
+{
+    UInt16 tagId;
+    UInt32 tagDataLen;
+
+    if (PeekNextPropertyTag(&tagId, &tagDataLen) != DrError_OK) {
+        return status;
+    }
+
+    if (tagId != enumId || tagDataLen != dataLen) {
+        SetStatus(DrError_InvalidProperty);
+        return status;
+    }
+
+    // Header = 2-byte id followed by a 1-byte (short) or 4-byte (long) length field.
+    const bool fShortLength = ((enumId & PropLengthMask) == PropLength_Short);
+    UInt32 hdrLen = sizeof(UInt16);
+    hdrLen += fShortLength ? (UInt32)sizeof(UInt8) : (UInt32)sizeof(UInt32);
+
+    // Peek header and value as one contiguous run, then copy only the value.
+    const BYTE *pProp;
+    if (PeekBytes(hdrLen + dataLen, &pProp) == DrError_OK) {
+        memcpy(pDest, pProp + hdrLen, dataLen);
+    }
+
+    return status;
+}
+
+
+// Reads a string property that has been encoded with WriteStringProperty.
+// If the string in the stream is longer than maxLength (not including null), DrError_StringTooLong is returned
+// NOTE(review): the check below compares the encoded property length against maxLength; for strings written by
+// Write{Long,Short}StringPropertyWithLength that length includes the trailing '\0' -- confirm the intended limit.
+// A zero-length property yields *ppStr == NULL. *ppStr is not written at all if the tag cannot be read,
+// the id does not match, or the string is too long.
+DrError DrMemoryReader::ReadNextStringProperty(UInt16 enumId, /* out */ const char **ppStr, Size_t maxLength)
+{
+    UInt32 length;
+    UInt16 realEnumId;
+
+    if (ReadNextPropertyTag(&realEnumId, &length) == DrError_OK) {
+        if (realEnumId != enumId) {
+            SetStatus(DrError_InvalidProperty);
+        } else if ((Size_t)length > maxLength) {
+            SetStatus(DrError_StringTooLong);
+        } else {
+            // Zero-length property means a NULL string; this is also the result if the read below fails.
+            *ppStr = NULL;
+
+            if (length > 0) {
+                BYTE *pBytes;
+                // The string can be returned in place only if it lies entirely inside the current
+                // contiguous block AND its last byte is already a '\0'. The short-circuit keeps
+                // pData[length-1] from being examined when the bytes are not all in this block.
+                if (length > NumContiguousBytesRemaining() || pData[length-1] != (BYTE)0) {
+                    // Must allocate temporary buffer with enough room for null terminator
+                    pBytes = ReserveTempMemory(length+1);
+                    if (ReadBytes(pBytes, length) == DrError_OK) {
+                        pBytes[length] = (BYTE)0;
+                    }
+                } else {
+                    // In-place case: this overload hands back a pointer to the bytes rather than copying them.
+                    ReadBytes(length, (const BYTE **)&pBytes);
+                }
+
+                if (status == DrError_OK) {
+                    *ppStr = (const char *)(const void *)pBytes;
+                }
+            }
+        }
+    }
+    return status;
+}
+
+// Reads a string property that has been encoded with WriteStringProperty.
+// If the string in the stream is longer than maxLength (not including null), DrError_StringTooLong is returned +DrError DrMemoryReader::ReadOrAppendNextStringProperty(bool fAppend, UInt16 enumId, /* out */ DrStr& strOut, Size_t maxLength) +{ + UInt32 length; + UInt16 realEnumId; + + if (ReadNextPropertyTag(&realEnumId, &length) == DrError_OK) { + if (realEnumId != enumId) { + SetStatus(DrError_InvalidProperty); + } else if (length == 0) { + // null string + if (fAppend) { + strOut.EnsureNotNull(); + } else { + strOut = NULLSTR; + } + } else if ((Size_t)(length) > maxLength) { + SetStatus(DrError_StringTooLong); + } else { + strOut.EnsureNotNull(); + size_t oldlen = strOut.GetLength(); + char *pOut = strOut.GetWritableAppendBuffer((size_t)length); + ReadBytes((BYTE *)pOut, (Size_t)length); + + if (status == DrError_OK) { + strOut.UpdateLength(oldlen + strlen(pOut)); + } else { + strOut.UpdateLength(oldlen); + } + } + } + return status; +} + + +/* Read a string property from the buffer into a preallocated buffer. + +If the embedded string is NULL, an empty string is returned. 
+
+If the string in the stream is longer than buffLength (including null), DrError_StringTooLong is returned
+
+A zero buffLength always yields DrError_StringTooLong, because not even the terminator fits.
+Returns DrError_InvalidProperty if the next property id is not enumId.
+*/
+DrError DrMemoryReader::ReadNextStringProperty(UInt16 enumId, char *pStr, Size_t buffLength)
+{
+    UInt32 length;
+    UInt16 realEnumId;
+
+    if (ReadNextPropertyTag(&realEnumId, &length) == DrError_OK) {
+        if (realEnumId != enumId) {
+            SetStatus(DrError_InvalidProperty);
+        } else if (buffLength == 0) {
+            SetStatus(DrError_StringTooLong);
+        } else if (length == 0) {
+            // NULL string
+            pStr[0] = '\0';
+        } else if ((Size_t)(length) > buffLength) {
+            // The property data (which may include its own trailing null) must fit in the buffer as-is
+            SetStatus(DrError_StringTooLong);
+        } else {
+            // Pre-terminate the last byte that will be read, then read the data over it.
+            pStr[length-1] = '\0';
+            ReadBytes((BYTE *)pStr, length);
+
+            if (status == DrError_OK) {
+                // If we read more than buffLength-1 bytes, the last byte has to be a null
+                if ((Size_t)length >= buffLength) {
+                    // The data filled the buffer exactly, so there is no room to add a terminator.
+                    if (pStr[length-1] != '\0') {
+                        pStr[0] = '\0';
+                        SetStatus(DrError_StringTooLong);
+                    }
+                } else {
+                    // There is spare room: terminate right after the data.
+                    pStr[length] = '\0';
+                }
+            }
+        }
+    }
+    return status;
+}
+
+
+// Pushes a new temporary-memory block of at least minLength bytes (and never less than
+// DEFAULT_TEMP_MEM_ALLOC_SIZE) onto the front of the temp block list.
+// Allocation failure is fatal (asserted).
+void DrMemoryReader::AllocTempMemBlock(size_t minLength)
+{
+    if (minLength < DEFAULT_TEMP_MEM_ALLOC_SIZE) {
+        minLength = DEFAULT_TEMP_MEM_ALLOC_SIZE;
+    }
+
+    // The previous head is passed in so the new block can link to it.
+    pFirstTempMemHeader = TempMemHeader::Alloc(minLength, pFirstTempMemHeader);
+    LogAssert(pFirstTempMemHeader != NULL);
+}
+
+// Reserves a block of temporary memory that stays valid until the temporary results are
+// discarded (DiscardTemporaryResults), which happens when this DrMemoryReader is closed,
+// reset or destroyed.
+BYTE *DrMemoryReader::ReserveTempMemory(size_t length)
+{
+    // Only the head block is considered; if it cannot satisfy the request a new head is allocated.
+    if (pFirstTempMemHeader == NULL || pFirstTempMemHeader->GetLength() < length) {
+        AllocTempMemBlock(length);
+        LogAssert (pFirstTempMemHeader != NULL && pFirstTempMemHeader->GetLength() >= length);
+    }
+    return pFirstTempMemHeader->ReserveData(length);
+}
+
+// Reads data from blocks starting *after* the current block, without advancing the current read pointer.
+// Returns DrError_EndOfStream if the stream reaches the end before all data can be read (partial data
+// may still be written into the byte array).
+// The default implementation has no future blocks: it ignores its arguments and always
+// sets and returns DrError_EndOfStream.
+DrError DrMemoryReader::FutureBlockPeekBytes(/* out */ void *pBytes, size_t length)
+{
+    (void)pBytes;
+    (void)length;
+    return SetStatus(DrError_EndOfStream);
+}
+
+// Reads data from memory without advancing the current read pointer.
+// Handles cross-block cases.
+// Returns DrError_EndOfStream if the stream reaches the end before all data can be read (partial data
+// is still written into the byte array).
+// Does nothing if status is already failing or length is 0.
+DrError DrMemoryReader::CrossBlockPeekBytes(/* out */ BYTE *pBytes, size_t length)
+{
+    Size_t nb;
+    PeekMemHeader *ph;
+    const BYTE *pb;
+
+    if (status == DrError_OK && length > (size_t)0) {
+        // Phase 1: copy from the peek blocks already queued. The first peek block is the
+        // current block, so it is read from pData rather than from the start of its data.
+        if (pFirstPeekMemHeader != NULL) {
+            ph = pFirstPeekMemHeader;
+            pb = pData;
+            nb = NumContiguousBytesRemaining();
+
+            do {
+                if (nb > length) {
+                    nb = length;
+                }
+
+                if (nb > 0) {
+                    memcpy(pBytes, pb, nb);
+                    pBytes += nb;
+                    length -= nb;
+                }
+
+                if (length == 0) {
+                    break;
+                }
+
+                // Later peek blocks are copied from their beginning.
+                ph = ph->GetNext();
+                if (ph != NULL) {
+                    pb = ph->GetData();
+                    nb = ph->GetLength();
+                }
+            } while (ph != NULL);
+        }
+
+        // Phase 2: whatever is still missing must come from blocks not yet queued.
+        if (length > 0) {
+            if (FutureBlockPeekBytes(pBytes, length) == DrError_PeekTooFar) {
+                // stream doesn't support peeking
+                // Fall back to pulling real blocks onto the peek list and copying from each new tail.
+                // NOTE(review): AppendNextBlock() only proceeds while status == DrError_OK, so this
+                // fallback presumes the override returned DrError_PeekTooFar without setting status -- confirm.
+                while (length > 0 && AppendNextBlock() == DrError_OK) {
+                    ph = pLastPeekMemHeader;
+                    pb = ph->GetData();
+                    nb = ph->GetLength();
+                    if (nb > length) {
+                        nb = length;
+                    }
+
+                    if (nb > 0) {
+                        memcpy(pBytes, pb, nb);
+                        pBytes += nb;
+                        length -= nb;
+                    }
+                }
+            }
+        }
+
+    }
+
+    return status;
+}
+
+// Reads data from memory, advancing the current read pointer.
+// Handles cross-block cases.
+// Returns DrError_EndOfStream if the stream reaches the end before all data can be read (partial data
+// is still written into the byte array).
+DrError DrMemoryReader::CrossBlockReadBytes(/* out */ BYTE *pBytes, size_t length)
+{
+    while (length != 0) {
+        size_t nr = NumContiguousBytesRemaining();
+        if (nr == (size_t)0) {
+            // Current block is drained: step to the next one, or report why we cannot.
+            const DrError advanceResult = AdvanceToNextPeekBlock();
+            if (advanceResult != DrError_OK) {
+                return advanceResult;
+            }
+            nr = NumContiguousBytesRemaining();
+        }
+        LogAssert(nr > 0);
+
+        // Take whichever is smaller: what this block still holds, or what is still wanted.
+        const size_t cbChunk = (nr < length) ? nr : length;
+        memcpy(pBytes, pData, cbChunk);
+        pData += cbChunk;
+        pBytes += cbChunk;
+        length -= cbChunk;
+    }
+    return DrError_OK;
+}
+
+// Streams "length" bytes from this reader into pWriter, one contiguous run at a time,
+// advancing the read pointer as it goes. A failure reported by the writer is recorded
+// in this reader's status and returned.
+DrError DrMemoryReader::ReadBytesIntoWriter(DrMemoryWriter *pWriter, Size_t length)
+{
+    while (length != 0) {
+        size_t nr = NumContiguousBytesRemaining();
+        if (nr == (size_t)0) {
+            const DrError advanceResult = AdvanceToNextPeekBlock();
+            if (advanceResult != DrError_OK) {
+                return advanceResult;
+            }
+            nr = NumContiguousBytesRemaining();
+        }
+        LogAssert(nr > 0);
+
+        const size_t cbChunk = (nr < length) ? nr : length;
+
+        // Hand the run to the writer before consuming it, so nothing is skipped on failure.
+        const DrError writeResult = pWriter->WriteBytes(pData, cbChunk);
+        if (writeResult != DrError_OK) {
+            return SetStatus(writeResult);
+        }
+        pData += cbChunk;
+        length -= cbChunk;
+    }
+    return DrError_OK;
+}
+
+// Skips data in memory, advancing the current read pointer.
+// Handles cross-block cases.
+// Returns DrError_EndOfStream if the stream reaches the end before all data can be skipped (partial data
+// is still skipped).
+DrError DrMemoryReader::CrossBlockSkipBytes(size_t length)
+{
+    while (length != 0) {
+        size_t nr = NumContiguousBytesRemaining();
+        if (nr == (size_t)0) {
+            const DrError advanceResult = AdvanceToNextPeekBlock();
+            if (advanceResult != DrError_OK) {
+                return advanceResult;
+            }
+            nr = NumContiguousBytesRemaining();
+        }
+        LogAssert(nr > 0);
+
+        // Nothing is copied; just move the read pointer past the run.
+        const size_t cbChunk = (nr < length) ? nr : length;
+        pData += cbChunk;
+        length -= cbChunk;
+    }
+    return DrError_OK;
+}
+
+// Appends the next block from the underlying stream to the list of peekable data blocks.
+// Does nothing if status is already failing. The result of ReadNextBlock() is recorded in status.
+// If the peek list was empty, the new block also becomes the current block.
+DrError DrMemoryReader::AppendNextBlock()
+{
+    const BYTE *pBytes;
+    Size_t length;
+
+    if (status == DrError_OK) {
+        // Read the next block from the stream
+        if (SetStatus(ReadNextBlock(&pBytes, &length)) == DrError_OK) {
+            // We have a new block. Append it to the end of the list of peek blocks.
+            PeekMemHeader *ph;
+            // The old tail is passed to the constructor; presumably it links the new header after it.
+            ph = new PeekMemHeader(length, pLastPeekMemHeader, pBytes);
+            LogAssert(ph != NULL);
+            pLastPeekMemHeader = ph;
+            if (pFirstPeekMemHeader == NULL) {
+                // We are appending the very first peek block.
+                pFirstPeekMemHeader = ph;
+
+                // The new block is now the current block.
+                // We need to update the base class current block pointers
+                pData = pBlockBase = pFirstPeekMemHeader->GetData();
+                blockLength = pFirstPeekMemHeader->GetLength();
+
+                // Since this is the first peek block, the prior block was empty and there is
+                // no effect on uBlockBasePhysicalStreamPosition
+            }
+        }
+    }
+
+    return status;
+}
+
+// Advances the current block to the next available peek block, reading a new block if necessary
+// Does nothing if status is already failing. On success there is always a current peek block (asserted).
+DrError DrMemoryReader::AdvanceToNextPeekBlock()
+{
+    if (status == DrError_OK) {
+        // remove current peek block, if any
+        if (pFirstPeekMemHeader != NULL) {
+            // We have at least one peek block. The first one is the "current" block. We need to remove it.
+            PeekMemHeader *ph = pFirstPeekMemHeader;
+            pFirstPeekMemHeader = ph->Detach();
+            if (pFirstPeekMemHeader == NULL) {
+                // we removed the last peek block. we now have no current block.
+                pLastPeekMemHeader = NULL;
+                pData = pBlockBase = NULL;
+                blockLength = 0;
+            } else {
+                // the next peek block is now the current block
+                pData = pBlockBase = pFirstPeekMemHeader->GetData();
+                blockLength = pFirstPeekMemHeader->GetLength();
+            }
+            // Since we advanced to a new current block, we need to advance uBlockBasePhysicalStreamPosition
+            // by the length of the previous current block
+            uBlockBasePhysicalStreamPosition += ph->GetLength();
+            delete ph;
+        }
+
+        if (pFirstPeekMemHeader == NULL) {
+            // if no more peek blocks, add a new one
+            // (on failure AppendNextBlock records the error in status, which is what we return)
+            AppendNextBlock();
+        }
+
+        if (status == DrError_OK) {
+            LogAssert(pFirstPeekMemHeader != NULL);
+        }
+    }
+    return status;
+}
+
+
+// This method should be overridden by memory readers that know how to advance to a new block.
+// Returns DrError_EndOfStream if there are no more blocks to be read.
+// The default implementation always returns DrError_EndOfStream, which is appropriate for
+// single-block readers. Neither output parameter is written by the default implementation.
+DrError DrMemoryReader::ReadNextBlock(/* out */ const BYTE **pBytes, /* out */ Size_t *pLength)
+{
+    return SetStatus(DrError_EndOfStream);
+}
+
+// This method should be overridden by memory readers that know how to advance to a new block.
+// The implementation should return true if there are at least "length" readable bytes
+// following the current block.
+//
+// The default implementation always returns false, which is appropriate for single-block readers
+bool DrMemoryReader::FutureBlocksCanBeRead(size_t length)
+{
+    return false;
+}
+
+
+// Consumes the (BeginTag, desiredTagType) property and closing (EndTag, desiredTagType) property, and calls you back
+// on parser->OnParseProperty() for each decoded property. Each property it calls you back on has only been peeked,
+// so you will need to read or skip over it. If another BeginTag appears, you will be called back with that.
+DrError DrMemoryReader::ReadAggregate(UInt16 desiredTagType, DrPropertyParser *parser, void *cookie)
+{
+    // The aggregate must open with a BeginTag carrying exactly the tag type the caller wants.
+    UInt16 openingTagType;
+    if (ReadNextKnownProperty(Prop_Dryad_BeginTag, sizeof(UInt16), &openingTagType) != DrError_OK)
+    {
+        return status;
+    }
+    if (openingTagType != desiredTagType)
+    {
+        return SetStatus(DrError_InvalidProperty);
+    }
+
+    // Walk the members until the matching EndTag is consumed or something fails.
+    for (;;)
+    {
+        UInt16 memberId;
+        UInt32 memberLen;
+        if (PeekNextPropertyTag(&memberId, &memberLen) != DrError_OK)
+        {
+            return status;
+        }
+
+        if (memberId != Prop_Dryad_EndTag)
+        {
+            // Ordinary member (possibly a nested BeginTag). It has only been peeked: the
+            // callback is responsible for reading it, skipping it, or recursing via
+            // ReadAggregate() / SkipNextPropertyOrAggregate().
+            const DrError parseResult = parser->OnParseProperty(this, memberId, memberLen, cookie);
+            if (parseResult != DrError_OK)
+            {
+                return SetStatus(parseResult);
+            }
+            continue;
+        }
+
+        // EndTag: consume it and make sure it closes the BeginTag we opened with.
+        UInt16 closingTagType;
+        if (ReadNextUInt16Property(Prop_Dryad_EndTag, &closingTagType) != DrError_OK)
+        {
+            return status;
+        }
+        if (closingTagType != desiredTagType)
+        {
+            return SetStatus(DrError_InvalidProperty);
+        }
+
+        // We're done
+        return DrError_OK;
+    }
+}
+
+// If the next property is not a BeginTag, it simply skips it.
+// If the next property is a BeginTag, then it skips everything through and including the EndTag, +// and handles recursion +// @TODO Limit recursion depth +DrError DrMemoryReader::SkipNextPropertyOrAggregate() +{ + UInt32 dataLen; + UInt16 propertyType; + UInt16 beginTagType; + + if (PeekNextPropertyTag(&propertyType, &dataLen) != DrError_OK) + return status; + + // If it's not a begin tag, just skip the property and return + if (propertyType != Prop_Dryad_BeginTag) + return SkipNextProperty(); + + // Read the begin tag type + if (ReadNextUInt16Property(Prop_Dryad_BeginTag, &beginTagType) != DrError_OK) + return status; + + // Skip until corresponding end tag + // If another BeginTag is encountered, recurse as appropriate + while (TRUE) + { + if (PeekNextPropertyTag(&propertyType, &dataLen) != DrError_OK) + return status; + + if (propertyType == Prop_Dryad_BeginTag) + { + if (SkipNextPropertyOrAggregate() != DrError_OK) + return status; + } + else if (propertyType == Prop_Dryad_EndTag) + { + UInt16 endTagType; + + if (ReadNextUInt16Property(Prop_Dryad_EndTag, &endTagType) != DrError_OK) + return status; + + if (endTagType != beginTagType) + return SetStatus(DrError_InvalidProperty); + + return DrError_OK; + } + else + { + if (SkipNextProperty() != DrError_OK) + return status; + } + } +} + + +DrError DrMemoryBufferWriter::SetBufferOffset(Size_t offset) +{ + if (status == DrError_OK && !MemoryWriterIsClosed()) { + FlushMemoryWriter(); + if (status == DrError_OK) { + Size_t uSize; + BYTE *p = (BYTE *)m_pBuffer->GetWriteAddress(offset, (Size_t)1, &uSize); + if (p == NULL) { + SetStatus(DrError_EndOfStream); + } else { + uBlockBasePhysicalStreamPosition = (UInt64)offset; + pData = pBlockBase = p; + blockLength = uSize; + } + } + } + + return status; +} + +DrError DrMemoryBufferWriter::FlushMemoryWriter() +{ + if (status == DrError_OK && !MemoryWriterIsClosed()) + { + if (m_pBuffer != NULL) + { + Size_t newPos = GetBufferOffset(); + if (m_fTruncateOnFlush || newPos > 
m_pBuffer->GetAvailableSize()) { + m_pBuffer->SetAvailableSize(newPos); + } + } + // DrMemoryWriter::FlushMemoryWriter() + } + return status; +} + +void DrMemoryBufferWriter::InternalFree() +{ + m_pBuffer = NULL; +} + +DrError DrMemoryBufferWriter::CloseMemoryWriter() +{ + if (!MemoryWriterIsClosed()) + { + DrMemoryWriter::CloseMemoryWriter(); // this will call FlushMemoryWriter + InternalFree(); + } + return status; +} + +DrMemoryBufferWriter::~DrMemoryBufferWriter() +{ + //======================================================================================================= + // NOTE: each subclass should duplicate this code if it implements CloseMemoryWriter or FlushMemoryWriter + // + if (!MemoryWriterIsClosed()) + { + // Close the memory writer. We cannot tolerate new failures in CloseMemoryWriter at this point because we are in a destructor! + // This will call FlushMemoryWriter() before closing, but because of virtual destructor unwinding, it will call the + // current class's implementation rather than the subclass's overriding implementation. 
So each subclass needs to implement this call in their + // virtual destructor if they implement CloseMemoryWriter or FlushMemoryWriter + MemoryWriterDestructorClose(); + } + // + // NOTE: End dublicated destructor code + //======================================================================================================= + InternalFree(); +} + +DrError DrMemoryBufferWriter::AdvanceToNextBlock() +{ + if (status == DrError_OK) { + Size_t newPos = GetBufferOffset(); + m_pBuffer->SetAvailableSize(newPos); + Size_t uNewBlockSize; + BYTE *p = (BYTE *)m_pBuffer->GetWriteAddressIfPossible(newPos, (Size_t)1, &uNewBlockSize); + if (p == NULL) { + SetStatus(DrError_EndOfStream); + } else { + + // Update DrMemoryWriter context for the new contiguous block + pBlockBase = p; + blockLength = uNewBlockSize; + pData = pBlockBase; + uBlockBasePhysicalStreamPosition = (UInt64)newPos; + } + } + return status; +} + +bool DrMemoryBufferWriter::FutureBlocksCanBeWritten(Size_t length) +{ + (void)length; + return true; +} + +void DrMemoryBufferReader::InternalFree() +{ + m_pBuffer = NULL; // decrefs the buffer + nextReadOffset = 0; +} + +DrError DrMemoryBufferReader::CloseMemoryReader() +{ + if (!MemoryReaderIsClosed()) + { + DrMemoryReader::CloseMemoryReader(); + InternalFree(); + } + return status; +} + +DrMemoryBufferReader::~DrMemoryBufferReader() +{ + InternalFree(); +} + +DrError DrMemoryBufferReader::SetBufferOffset(Size_t offset) +{ + if (status == DrError_OK && !MemoryReaderIsClosed()) { + Size_t uStreamSize = 0; + if (m_pBuffer != NULL) { + uStreamSize = m_pBuffer->GetAvailableSize(); + } + if (m_pBuffer == NULL || offset > uStreamSize) { + SetStatus(DrError_EndOfStream); + } else { + DiscardMemoryReaderContext(); + const BYTE *p = NULL; + Size_t uSize = 0; + if (offset < uStreamSize) { + p = (const BYTE *)m_pBuffer->GetReadAddress(offset, &uSize); + LogAssert (p != NULL && uSize > 0); + } + uBlockBasePhysicalStreamPosition = (UInt64)offset; + nextReadOffset = offset + 
uSize; + pData = pBlockBase = (BYTE *)p; + blockLength = uSize; + } + } + return status; +} + + +DrError DrMemoryBufferReader::ReadNextBlock(/* out */ const BYTE **ppBytes, /* out */ Size_t *pLength) +{ + *ppBytes = NULL; + *pLength = 0; + if (status == DrError_OK) { + Size_t uSize; + if (m_pBuffer == NULL || nextReadOffset >= m_pBuffer->GetAvailableSize()) { + SetStatus(DrError_EndOfStream); + } else { + const BYTE *p = (const BYTE *)m_pBuffer->GetReadAddress(nextReadOffset, &uSize); + LogAssert (p != NULL && uSize > 0); + *ppBytes = p; + *pLength = uSize; + nextReadOffset += uSize; + } + } + return status; +} + +DrError DrMemoryBufferReader::FutureBlockPeekBytes(/* out */ void *pBytes, Size_t length) +{ + if (status == DrError_OK) { + Size_t cbAvailable = m_pBuffer->GetAvailableSize(); + if (m_pBuffer == NULL || nextReadOffset > cbAvailable || cbAvailable - nextReadOffset < length) { + SetStatus(DrError_EndOfStream); + } else { + m_pBuffer->Read(nextReadOffset, pBytes, length); + } + } + return status; +} + +bool DrMemoryBufferReader::FutureBlocksCanBeRead(Size_t length) +{ + if (status != DrError_OK) { + return false; + } + Size_t cbAvailable = m_pBuffer->GetAvailableSize(); + return (m_pBuffer != NULL && nextReadOffset <= cbAvailable && cbAvailable - nextReadOffset >= length); +} diff --git a/DryadVertex/VertexHost/system/classlib/src/DrNodeAddress.cpp b/DryadVertex/VertexHost/system/classlib/src/DrNodeAddress.cpp new file mode 100644 index 0000000..233191a --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/src/DrNodeAddress.cpp @@ -0,0 +1,973 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+#include "DrCommon.h"
+
+#pragma warning (push)
+#pragma warning (disable:4365)
+// NOTE(review): the header name is missing from the #include below - restore it (TODO confirm which header).
+#include
+#pragma warning (pop)
+
+#pragma prefast(disable:24002, "struct sockaddr not ipv6 compatible")
+
+#pragma unmanaged
+
+DrLastAccessTable *g_pDrLastAccessTable;
+
+// Allocates the last-access table and stores it in g_pDrLastAccessTable.
+void DrInitLastAccessTable()
+{
+    DrLastAccessTable *pTable = new DrLastAccessTable();
+    LogAssert(pTable != NULL);
+    g_pDrLastAccessTable = pTable;
+}
+
+// Initializes this node address from an IPv4 socket address.
+// Returns false (leaving this object untouched) if the pointer is NULL, the supplied length is
+// smaller than a sockaddr_in, or the address family is not AF_INET.
+bool DrNodeAddress::Set(const struct sockaddr_in *pSockAddr, Size_t addrLen)
+{
+    const bool usable =
+        pSockAddr != NULL &&
+        addrLen >= sizeof(struct sockaddr_in) &&
+        pSockAddr->sin_family == AF_INET;
+
+    if (usable) {
+        // sin_port is in network byte order; the two-argument Set() is given the host-order value.
+        Set(pSockAddr->sin_addr, ntohs(pSockAddr->sin_port));
+    }
+
+    return usable;
+}
+
+// Looks up a host name using DNS
+// Note that this is a blocking request
+// Returns up to addressBuffLen entries. If there are more entries than this, the list is truncated without error.
+// getaddrinfo() is retried immediately, with no delay or retry limit, for as long as it reports EAI_AGAIN.
+// Only IPv4 (PF_INET) results are returned, as host-order addresses.
+// Returns DrError_OK with *pNumReturnedAddresses >= 1 on success.
+// Returns DrError_HostNotFound, with *pNumReturnedAddresses == 0, when the lookup yields no data or
+// no usable IPv4 address could be copied into the buffer.
+// Any other getaddrinfo() failure code is returned unchanged.
+DrError DrNodeAddress::LookupHostName(
+    const char *pszHostName,
+    /* out */ DrIpAddress *pAddressBuff,
+    UInt32 addressBuffLen,
+    /* out */ UInt32 *pNumReturnedAddresses)
+{
+    struct addrinfo hints;
+    struct addrinfo *pResults = NULL;
+    UInt32 n = 0;
+    memset(&hints, 0, sizeof(hints));
+
+    *pNumReturnedAddresses = 0;
+
+    hints.ai_family = PF_INET;
+    int ret = EAI_AGAIN;
+
+    while (ret == EAI_AGAIN) {
+        ret = getaddrinfo(pszHostName, NULL, &hints, &pResults);
+
+        if (ret == 0 && pResults == NULL) {
+            ret = EAI_NODATA;
+        }
+    }
+
+    if (ret == 0) {
+        // Copy out the usable IPv4 entries until the caller's buffer is full.
+        for (struct addrinfo *pResultEntry = pResults;
+             pResultEntry != NULL && addressBuffLen > 0;
+             pResultEntry = pResultEntry->ai_next) {
+            if (pResultEntry->ai_family == PF_INET && pResultEntry->ai_addr != NULL && pResultEntry->ai_addrlen == sizeof(sockaddr_in)) {
+                const struct sockaddr_in *paddr = (const struct sockaddr_in *)(const void *)(pResultEntry->ai_addr);
+                DrIpAddress addr = ntohl(paddr->sin_addr.S_un.S_addr);
+                *pAddressBuff = addr;
+                pAddressBuff++;
+                addressBuffLen--;
+                n++;
+            }
+        }
+
+        // A lookup that produced nothing usable is a failure. ret must not be reset to 0 here:
+        // callers assert that a successful return carries at least one address.
+        if (n == 0) {
+            ret = EAI_NODATA;
+        }
+    }
+
+    *pNumReturnedAddresses = n;
+
+    if (pResults != NULL) {
+        freeaddrinfo(pResults);
+    }
+
+    if (ret == EAI_NODATA) {
+        ret = DrError_HostNotFound;
+    }
+
+    return ret;
+}
+
+
+// Converts the contained IP/port address to a string of the form "#.#.#.#:port". If the contained port number matches defaultPort, the
+// port number is not included in the string.
+DrStr& DrNodeAddress::AppendToString(DrStr& strOut, DrPortNumber defaultPort) const
+{
+    // The four address bytes, widened for the "%u" conversions below.
+    const DrIpAddress b1 = (DrIpAddress ) m_ina.S_un.S_un_b.s_b1;
+    const DrIpAddress b2 = (DrIpAddress ) m_ina.S_un.S_un_b.s_b2;
+    const DrIpAddress b3 = (DrIpAddress ) m_ina.S_un.S_un_b.s_b3;
+    const DrIpAddress b4 = (DrIpAddress ) m_ina.S_un.S_un_b.s_b4;
+
+    if (m_wPort != defaultPort) {
+        // Non-default port: spell it out after the address.
+        strOut.AppendF("%u.%u.%u.%u:%u", b1, b2, b3, b4, (DrIpAddress ) m_wPort);
+    } else {
+        strOut.AppendF("%u.%u.%u.%u", b1, b2, b3, b4);
+    }
+
+    return strOut;
+}
+/* JC
+// Converts the contained IP/port address to a string of the form "#.#.#.#:port". If the contained port number matches defaultPort, the
+// port number is not included in the string.
+DrWStr& DrNodeAddress::AppendToString(DrWStr& strOut, DrPortNumber defaultPort) const
+{
+    if (m_wPort == defaultPort) {
+        strOut.AppendF(L"%u.%u.%u.%u",
+            (DrIpAddress ) m_ina.S_un.S_un_b.s_b1,
+            (DrIpAddress ) m_ina.S_un.S_un_b.s_b2,
+            (DrIpAddress ) m_ina.S_un.S_un_b.s_b3,
+            (DrIpAddress ) m_ina.S_un.S_un_b.s_b4);
+    } else {
+        strOut.AppendF(L"%u.%u.%u.%u:%u",
+            (DrIpAddress ) m_ina.S_un.S_un_b.s_b1,
+            (DrIpAddress ) m_ina.S_un.S_un_b.s_b2,
+            (DrIpAddress ) m_ina.S_un.S_un_b.s_b3,
+            (DrIpAddress ) m_ina.S_un.S_un_b.s_b4,
+            (DrIpAddress ) m_wPort);
+    }
+
+    return strOut;
+}
+*/
+
+// Converts the contained IP/port address to a string of the form "#.#.#.#:port". If the contained port number matches defaultPort, the
+// port number is not included in the string.
+// buffSize must be at least 22 or DrError_StringTooLong is returned.
+DrError DrNodeAddress::ToAddressPortString(char *pBuffer, Size_t buffSize, DrPortNumber defaultPort) const
+{
+    // The four address bytes, widened for the "%u" conversions below.
+    const DrIpAddress b1 = (DrIpAddress ) m_ina.S_un.S_un_b.s_b1;
+    const DrIpAddress b2 = (DrIpAddress ) m_ina.S_un.S_un_b.s_b2;
+    const DrIpAddress b3 = (DrIpAddress ) m_ina.S_un.S_un_b.s_b3;
+    const DrIpAddress b4 = (DrIpAddress ) m_ina.S_un.S_un_b.s_b4;
+
+    HRESULT hr;
+    if (m_wPort != defaultPort) {
+        hr = StringCbPrintfA(pBuffer, buffSize, "%u.%u.%u.%u:%u", b1, b2, b3, b4, (DrIpAddress ) m_wPort);
+    } else {
+        hr = StringCbPrintfA(pBuffer, buffSize, "%u.%u.%u.%u", b1, b2, b3, b4);
+    }
+
+    // Every formatting failure is reported as "string too long".
+    if (!SUCCEEDED(hr)) {
+        return DrError_StringTooLong;
+    }
+    return DrError_OK;
+}
+
+// Wide-character variant of ToAddressPortString: same output format, same default-port rule.
+// Returns DrError_StringTooLong if formatting fails.
+// NOTE(review): buffSize is handed to StringCbPrintfW, which takes a size in bytes; the "at least 22"
+// guidance of the narrow variant would then be too small here - confirm what callers pass.
+DrError DrNodeAddress::ToAddressPortString(WCHAR *pBuffer, Size_t buffSize, DrPortNumber defaultPort) const
+{
+    const DrIpAddress b1 = (DrIpAddress ) m_ina.S_un.S_un_b.s_b1;
+    const DrIpAddress b2 = (DrIpAddress ) m_ina.S_un.S_un_b.s_b2;
+    const DrIpAddress b3 = (DrIpAddress ) m_ina.S_un.S_un_b.s_b3;
+    const DrIpAddress b4 = (DrIpAddress ) m_ina.S_un.S_un_b.s_b4;
+
+    HRESULT hr;
+    if (m_wPort != defaultPort) {
+        hr = StringCbPrintfW(pBuffer, buffSize, L"%u.%u.%u.%u:%u", b1, b2, b3, b4, (DrIpAddress ) m_wPort);
+    } else {
+        hr = StringCbPrintfW(pBuffer, buffSize, L"%u.%u.%u.%u", b1, b2, b3, b4);
+    }
+
+    if (!SUCCEEDED(hr)) {
+        return DrError_StringTooLong;
+    }
+    return DrError_OK;
+}
+
+
+// Parses whatever follows the host part of pszName: an optional "!instance-num" and an optional ":port".
+// hostLength is the length of the host part, as computed by InternalHostLength.
+// The instance number must be decimal digits only and at most 65535.
+// The port is taken from ":port" when present, otherwise defaultPort. A non-zero instance number is
+// added to the port unless the port is DrInvalidPortNumber or DrAnyPortNumber.
+// NOTE(review): the sum is cast straight back to DrPortNumber, with no range check - confirm this is intended.
+// Returns DrError_InvalidParameter for an empty host or malformed instance number, or the error from
+// DrStringToPortNumber for a bad port.
+static DrError InternalParseHostPortName(
+    Size_t hostLength,
+    /*out */ DrPortNumber *pPort,
+    const char *pszName,
+    DrPortNumber defaultPort,
+    UInt32 *pInstanceNumOut)
+{
+    LogAssert(pPort != NULL);
+
+    if (hostLength == 0) {
+        return DrError_InvalidParameter;
+    }
+
+    // 'pos' walks from the end of the host part to the ':' or the terminating NUL.
+    Size_t pos = hostLength;
+    UInt32 instanceNum = 0;
+    if (pszName[pos] == '!') {
+        ++pos;
+        while (pszName[pos] != '\0' && pszName[pos] != ':') {
+            const char ch = pszName[pos];
+            if (ch < '0' || ch > '9') {
+                return DrError_InvalidParameter;
+            }
+            instanceNum = (10 * instanceNum) + (UInt32)(ch - '0');
+            if (instanceNum > 65535) {
+                return DrError_InvalidParameter;
+            }
+            ++pos;
+        }
+    }
+
+    DrPortNumber resolvedPort;
+
+    if (pszName[pos] != ':') {
+        resolvedPort = defaultPort;
+    } else {
+        DrError err = DrStringToPortNumber(pszName + pos + 1, &resolvedPort);
+        if (err != DrError_OK) {
+            return err;
+        }
+    }
+
+    // The instance offset is only applied to ordinary port numbers.
+    const bool isSpecialPort = (resolvedPort == DrInvalidPortNumber || resolvedPort == DrAnyPortNumber);
+    if (instanceNum != 0 && !isSpecialPort) {
+        resolvedPort = (DrPortNumber)(resolvedPort + instanceNum);
+    }
+
+    *pPort = resolvedPort;
+    if (pInstanceNumOut != NULL) {
+        *pInstanceNumOut = instanceNum;
+    }
+    return DrError_OK;
+}
+
+// Returns the number of leading characters of pszName that belong to the host part, i.e. everything
+// before the first ':' or '!' or the end of the string.
+inline static Size_t InternalHostLength(const char *pszName)
+{
+    LogAssert(pszName != NULL);
+
+    Size_t count = 0;
+    for (;;) {
+        const char ch = pszName[count];
+        if (ch == '\0' || ch == ':' || ch == '!') {
+            break;
+        }
+        ++count;
+    }
+
+    return count;
+}
+
+// Parses a name in the forms:
+//     "#.#.#.#"
+//     "dns-name"
+//     "#.#.#.#!instance-num"
+//     "dns-name!instance-num"
+//     "#.#.#.#:port"
+//     "dns-name:port"
+//     "#.#.#.#!instance-num:port"
+//     "dns-name!instance-num:port"
+// and splits out the host name and port.
+// If ":port" is missing, uses the default port.
+// if instance-num is present, it is added to the final port number.
+// Returns DrError_InvalidParameter if the string is malformed.
+// The host part is copied, NUL-terminated, into pHostNameBuffer; DrError_InvalidParameter is also
+// returned when the host part is empty or does not fit in buffLen bytes including the terminator.
+DrError DrNodeAddress::ParseHostPortName(
+    /* out */ char *pHostNameBuffer,
+    Size_t buffLen,
+    /*out */ DrPortNumber *pPort,
+    const char *pszName,
+    DrPortNumber defaultPort,
+    UInt32 *pInstanceNumOut)
+{
+    LogAssert(pHostNameBuffer != NULL);
+    const Size_t cchHost = InternalHostLength(pszName);
+
+    if (cchHost == 0) {
+        return DrError_InvalidParameter;
+    }
+    if (buffLen < cchHost+1) {
+        // No room for the host name plus its terminator.
+        return DrError_InvalidParameter;
+    }
+
+    memcpy(pHostNameBuffer, pszName, cchHost);
+    pHostNameBuffer[cchHost] = '\0';
+
+    return InternalParseHostPortName(cchHost, pPort, pszName, defaultPort, pInstanceNumOut);
+}
+
+// DrStr flavor of ParseHostPortName: accepts the same name forms and splits out the host name and port.
+// If ":port" is missing, uses the default port.
+// strOut is replaced with the parsed host name; it is left untouched when the host part is empty.
+// Returns DrError_InvalidParameter if the string is malformed.
+DrError DrNodeAddress::ParseHostPortName(DrStr& strOut, /*out */ DrPortNumber *pPort, const char *pszName, DrPortNumber defaultPort, UInt32 *pInstanceNumOut)
+{
+    const Size_t cchHost = InternalHostLength(pszName);
+
+    if (cchHost != 0) {
+        strOut.Set(pszName, cchHost);
+        return InternalParseHostPortName(cchHost, pPort, pszName, defaultPort, pInstanceNumOut);
+    }
+
+    return DrError_InvalidParameter;
+}
+
+// Parses a stringified IP address in the form "#.#.#.#" into a host-order IP address.
+// Returns DrError_InvalidParameter if the string is malformed.
+DrError DrNodeAddress::ParseIpAddress(const char *pszIpAddress, /* out */ DrIpAddress *pIpAddr)
+{
+    const char *pCur = pszIpAddress;
+    DrIpAddress result = 0;
+
+    for (int octet = 0; octet < 4; ++octet) {
+        // Every octet after the first must be introduced by a '.'
+        if (octet != 0) {
+            if (*pCur != '.') {
+                return DrError_InvalidParameter;
+            }
+            ++pCur;
+        }
+
+        // An octet needs at least one digit...
+        if (*pCur < '0' || *pCur > '9') {
+            return DrError_InvalidParameter;
+        }
+
+        // ...and consumes at most three, never exceeding 255.
+        UInt32 value = 0;
+        int digits = 0;
+        while (digits < 3 && *pCur >= '0' && *pCur <= '9') {
+            value = 10 * value + (UInt32)(*pCur - '0');
+            if (value > 255) {
+                return DrError_InvalidParameter;
+            }
+            ++pCur;
+            ++digits;
+        }
+
+        result = (result << 8) | value;
+    }
+
+    // Nothing may follow the fourth octet.
+    if (*pCur != '\0') {
+        return DrError_InvalidParameter;
+    }
+
+    // The output is written only on success.
+    *pIpAddr = result;
+    return DrError_OK;
+}
+
+// Sets this address from a "host[!instance][:port]" string, where host is either a dotted-quad
+// IP address or a DNS name.
+// If ":port" is missing, uses the default port.
+// A DNS name is resolved with LookupHostName and its first returned address is used.
+// Returns the error from ParseHostPortName or LookupHostName on failure.
+// Note that this method may block for DNS resolution if a DNS name is used.
+DrError DrNodeAddress::Set(const char *pszName, DrPortNumber defaultPort)
+{
+    char hostName[k_MaxHostNameLength+1];
+    DrPortNumber port;
+    DrIpAddress ipAddr;
+
+    DrError err = ParseHostPortName(hostName, sizeof(hostName), &port, pszName, defaultPort);
+    if (err != DrError_OK) {
+        return err;
+    }
+
+    // Numeric addresses are handled locally; anything else goes to DNS.
+    if (ParseIpAddress(hostName, &ipAddr) != DrError_OK) {
+        UInt32 nRet = 0;
+        err = LookupHostName(hostName, &ipAddr, 1, &nRet);
+        if (err != DrError_OK) {
+            return err;
+        }
+        LogAssert(nRet == 1);
+    }
+
+    Set(ipAddr, port);
+    return DrError_OK;
+}
+
+// This call may block for DNS
+// It resolves the specified host name (with optional ":port") to a list of IP addresses and *appends* those to this DrNodeAddressList, filling in
+// the port number for each.
+// Note that since this request appends to the existing list, you must Clear() or Discard() the list before you make this
+// call if you want the results to replace the existing set.
+// Returns DrError_HostNotFound if no hosts match the name.
+// At most 32 addresses are taken from a single DNS lookup.
+DrError DrNodeAddressList::ResolveHostName(const char *pszHostName, DrPortNumber defaultPort)
+{
+    char hostName[k_MaxHostNameLength+1];
+    DrPortNumber port;
+    DrIpAddress ipAddrs[32];
+    UInt32 cAddrs = 0;
+
+    DrError err = DrNodeAddress::ParseHostPortName(hostName, sizeof(hostName), &port, pszHostName, defaultPort);
+    if (err != DrError_OK) {
+        return err;
+    }
+
+    if (DrNodeAddress::ParseIpAddress(hostName, &(ipAddrs[0])) == DrError_OK) {
+        // A literal dotted-quad address needs no lookup.
+        cAddrs = 1;
+    } else {
+        err = DrNodeAddress::LookupHostName(hostName, ipAddrs, (UInt32)(sizeof(ipAddrs) / sizeof(ipAddrs[0])), &cAddrs);
+        if (err != DrError_OK) {
+            return err;
+        }
+        LogAssert(cAddrs > 0);
+    }
+
+    GrowTo(cAddrs);
+
+    // Append one node address per resolved IP, all sharing the parsed port.
+    const DrIpAddress *pEnd = ipAddrs + cAddrs;
+    for (const DrIpAddress *pIp = ipAddrs; pIp != pEnd; ++pIp) {
+        DrNodeAddress nodeAddr;
+        nodeAddr.Set(*pIp, port);
+        AddEntry(&nodeAddr);
+    }
+
+    return DrError_OK;
+}
+
+// This call may block for DNS
+// It resolves the host name to a list of IP addresses and *appends* those to the specified DrNodeAddressList, filling in
+// the port number for each.
+// Note that since this request appends to the existing list, you must Clear() or Discard() the list before you make this
+// call if you want the results to replace the existing set.
+// Returns DrError_HostNotFound if no hosts match the name.
+DrError DrHostAndPort::ResolveToAddresses(DrNodeAddressList *pAddresses)
+{
+    // Delegates to the list, using this entry's host name and its port as the default port.
+    DrError err = pAddresses->ResolveHostName(m_pszHostName, m_portNumber);
+    return err;
+}
+
+/* JC
+DrError DrPodNameToFaultDomain(__in PCSTR pszPodName, __out XsFaultDomain *pFaultDomainOut)
+{
+    DrError err = DrError_OK;
+    *pFaultDomainOut = 0;
+    if (pszPodName == NULL || pszPodName[0] == '\0') {
+        // just return 0
+    } else if (_strnicmp(pszPodName, "pod", 3) == 0) {
+        err = DrStringToUInt16(pszPodName+3, pFaultDomainOut);
+    } else {
+        *pFaultDomainOut = 0;
+        err = DrError_InvalidParameter;
+    }
+    return err;
+}
+
+PCSTR DrFaultDomainToPodName( __in XsFaultDomain faultDomain)
+{
+    DrStr32 strPod("pod");
+    strPod.AppendUInt32((UInt32)faultDomain);
+    return g_DrInternalizedStrings.InternalizeString(strPod);
+}
+
+DrError DrHostAndPort::ReplaceFaultDomainFromPod()
+{
+    m_faultDomain = 0;
+    DrError err = DrPodNameToFaultDomain(m_pszPodName, &m_faultDomain);
+    m_fValidFaultDomain = (err == DrError_OK);
+    return err;
+}
+*/
+
+// Property parser that fills in a DrHostAndPort while its serialized aggregate is being read.
+// It handles the port, short/long host name and pod name properties and skips everything else.
+class DrHostAndPortParser : public DrPropertyParser
+{
+public:
+    // pEntry is the object to populate; the parser only stores the pointer.
+    DrHostAndPortParser(DrHostAndPort *pEntry)
+    {
+        m_pEntry = pEntry;
+    }
+
+    virtual ~DrHostAndPortParser()
+    {
+    }
+
+    // Called for each property in the aggregate. A recognized property is consumed from the reader
+    // and applied to the entry only if the read succeeds; an unrecognized one is skipped.
+    // dataLen and cookie are not used. Returns the reader's status.
+    virtual DrError OnParseProperty(DrMemoryReader *reader, UInt16 enumID, UInt32 dataLen, void *cookie)
+    {
+        if (reader->GetStatus() == DrError_OK) {
+            switch(enumID) {
+                case Prop_Dryad_Port:
+                    {
+                        UInt16 port;
+
+                        if (reader->ReadNextUInt16Property(Prop_Dryad_Port, &port) == DrError_OK) {
+                            m_pEntry->SetPort(port);
+                        }
+                    }
+                    break;
+
+                // Serialize() writes one or the other depending on the host name length; both are read the same way.
+                case Prop_Dryad_ShortHostName:
+                case Prop_Dryad_LongHostName:
+                    {
+                        const char *pszHost;
+                        if (reader->ReadNextStringProperty(enumID, &pszHost) == DrError_OK) {
+                            m_pEntry->SetHostName(pszHost);
+                        }
+                    }
+                    break;
+
+/*                case Prop_Dryad_UpgradeDomain:
+                    {
+                        UInt16 upgradeDomain;
+
+                        if (reader->ReadNextUInt16Property(Prop_Dryad_UpgradeDomain, &upgradeDomain) == DrError_OK) {
+                            m_pEntry->SetUpgradeDomain(upgradeDomain);
+                        }
+                    }
+                    break;
+*/
+                // NOTE: current code serializes Prop_Dryad_PodName. This will eventually be deprecated in favor
+                // of Prop_Dryad_FaultDomain. Here we accept either in preparation for future deprecation.
+                // NOTE(review): the Prop_Dryad_FaultDomain case below is commented out, so only the pod name
+                // is handled at present.
+                case Prop_Dryad_PodName:
+                    {
+                        const char *pszPodName;
+                        if (reader->ReadNextStringProperty(Prop_Dryad_PodName, &pszPodName) == DrError_OK) {
+                            m_pEntry->SetPodNameNoFaultDomainUpdate(pszPodName);
+                        }
+                    }
+                    break;
+
+/*JC
+                // NOTE: old versions don't serialize this way, but eventually in XStore we will switch to this after all unserializers have been updated.
+                // To allow for that while maintaining compatibility with the current API, we will reencode the fault domain into pod name form if
+                // the pod is not provided.
+                case Prop_Dryad_FaultDomain:
+                    {
+                        UInt16 uFaultDomain = 0;
+                        DrError err = reader->ReadNextUInt16Property(Prop_Dryad_FaultDomain, &uFaultDomain);
+                        if (err == DrError_OK) {
+                            m_pEntry->SetFaultDomainNoPodUpdate(uFaultDomain);
+                        }
+                    }
+                    break;
+*/
+
+                default:
+                    reader->SkipNextPropertyOrAggregate();
+                    break;
+            }
+
+        }
+
+        return reader->GetStatus();
+    }
+
+private:
+    DrHostAndPort *m_pEntry;
+};
+
+// Reads a DrTag_DrHostAndPort aggregate from pReader into this object.
+// The pod state is cleared first; the other fields are overwritten only by the properties present.
+// Returns the result of ReadAggregate.
+DrError DrHostAndPort::Unserialize(DrMemoryReader *pReader)
+{
+//JC    // We keep track of whether pod and/or fault domain were explicitly provided.
+    // We keep track of whether pod was explicitly provided.
+
+//JC    m_fValidFaultDomain = false;
+    m_fValidPod = false;
+//JC    m_faultDomain = 0;
+    m_pszPodName = NULL;
+
+    DrHostAndPortParser p(this);
+    DrError err = pReader->ReadAggregate(DrTag_DrHostAndPort, &p, NULL);
+
+/*JC
+    // If fault domain was provided, but pod was not, synthesize the pod. Otherwise, if fault domain
+    // was not provided, synthesize it.
+    if (m_fValidFaultDomain) {
+        if (m_fValidPod) {
+            ReplacePodFromFaultDomain();
+        }
+    } else {
+        ReplaceFaultDomainFromPod();
+    }
+    m_fValidPod = true;
+*/
+
+    return err;
+}
+
+// Writes this object as a DrTag_DrHostAndPort aggregate: begin tag, host name, port, optional pod
+// name, end tag, in that order. Returns the writer's status.
+DrError DrHostAndPort::Serialize(DrMemoryWriter *pWriter) const
+{
+    pWriter->WriteUInt16Property(Prop_Dryad_BeginTag, DrTag_DrHostAndPort);
+
+    // BUGBUG: should have been LONGATOM. We serialize a long version if length >= 255
+    {
+
+        // A NULL host name is written as a zero-length short string.
+        size_t length = 0;
+        if (m_pszHostName != NULL) {
+            length = strlen(m_pszHostName);
+        }
+        if (length >= (size_t)_UI8_MAX) {
+            // BUGBUG: breaks compat for long host names, but no worse than before
+            pWriter->WriteLongStringPropertyWithLength(Prop_Dryad_LongHostName, m_pszHostName, length);
+        } else {
+            pWriter->WriteShortStringPropertyWithLength(Prop_Dryad_ShortHostName, m_pszHostName, length);
+        }
+    }
+
+    pWriter->WriteUInt16Property(Prop_Dryad_Port, m_portNumber);
+/*JC    if (m_fValidFaultDomain) {
+        pWriter->WriteUInt16Property(Prop_Dryad_FaultDomain, m_faultDomain);
+    }
+*/
+    // The pod name is written only when it is both flagged valid and non-NULL.
+    if (m_fValidPod && m_pszPodName != NULL) {
+        // Note: Prop_Dryad_PodName will eventually be deprecated in favor of Prop_Dryad_FaultDomain. We will continue to serialize
+        // this way until all unserializers have been updated.
+        pWriter->WriteLongStringProperty(Prop_Dryad_PodName, m_pszPodName);
+    }
+//JC    pWriter->WriteUInt16Property(Prop_Dryad_UpgradeDomain, m_upgradeDomain);
+
+    pWriter->WriteUInt16Property(Prop_Dryad_EndTag, DrTag_DrHostAndPort);
+
+    return pWriter->GetStatus();
+}
+
+// If this is an update, preserve the ordering of everything already in the list except for the primary, which must be at the top
+// If forceReordering is true, overwrite the existing ordering with the new one (but primary must be first).
This is called
+// when we DemoteHost() and are putting the new table back
+// If 'other' is empty, no entries are copied and m_primary is left untouched; only m_numEntries is reset to 0.
+// Returns *this.
+DrHostNameList& DrHostNameList::Set(const DrHostNameList& other, bool forceReordering)
+{
+    // Since strings are internalized, we can just copy the pointers
+    GrowTo(other.m_numEntries);
+    if (other.m_numEntries > 0)
+    {
+        if (m_numEntries == 0 || forceReordering)
+        {
+            // There wasn't anything in the existing list, or we want to force reordering, so do a direct copy
+            for (UInt32 i = 0; i < other.m_numEntries; i++)
+            {
+                m_pMultipleHosts[i] = other.m_pMultipleHosts[i];
+            }
+
+            if ((other.m_primary >= other.m_numEntries) || (other.m_primary == INVALID_PRIMARY_HOST))
+            {
+                // Invalid primary, ignore
+                m_primary = INVALID_PRIMARY_HOST;
+            }
+            else
+            {
+                if (other.m_primary != 0)
+                {
+                    // Move primary to top
+                    DrHostAndPort temp = m_pMultipleHosts[0];
+                    m_pMultipleHosts[0] = m_pMultipleHosts[other.m_primary];
+                    m_pMultipleHosts[other.m_primary] = temp;
+                }
+
+                m_primary = 0;
+            }
+        }
+        else
+        {
+            // We're updating on top of something we already have
+            // Ensure that order is preserved, except that if the primary changes, it is always at the top
+
+            // We will change this below when we encounter the primary
+            // Right now we don't know what index this will be in the new list
+            m_primary = INVALID_PRIMARY_HOST;
+
+            // Remove hosts which do not exist in the new list
+            for (UInt32 i = 0; i < m_numEntries; ++i)
+            {
+                UInt32 j;
+
+                // Is the entry present in both the local and the new list?
+                for (j = 0; j < other.m_numEntries; ++j)
+                {
+                    if (m_pMultipleHosts[i] == other.m_pMultipleHosts[j])
+                        break;
+                }
+
+                // If host is not found in new list, remove it locally
+                if (j == other.m_numEntries)
+                {
+                    --m_numEntries;
+
+                    // Close the gap by shifting the remaining entries down one slot
+                    for (j = i; j < m_numEntries; ++j)
+                        m_pMultipleHosts[j] = m_pMultipleHosts[j+1];
+
+                    // Re-examine slot i, which now holds the next entry. i is unsigned, so at i == 0
+                    // this wraps and the loop's ++i brings it back to 0.
+                    i--;
+                }
+            }
+
+            // New entries go at the end
+            UInt32 newEntries = m_numEntries;
+
+            // Add new hosts
+            for (UInt32 i = 0; i < other.m_numEntries; i++)
+            {
+                UInt32 j;
+
+                // Does the entry already exist?
+                // (only the surviving original entries are searched, not the ones appended by this loop)
+                for (j = 0; j < m_numEntries; j++)
+                {
+                    if (m_pMultipleHosts[j] == other.m_pMultipleHosts[i])
+                        break;
+                }
+
+                if (j < m_numEntries)
+                {
+                    // Found - update fields that aren't checked by the == operator above (e.g. upgrade domain)
+                    m_pMultipleHosts[j] = other.m_pMultipleHosts[i];
+
+                    if (other.m_primary == i)
+                        m_primary = j;
+                }
+                else
+                {
+                    LogAssert(newEntries < m_numAllocated);
+
+                    // New entry, append to end of list
+                    m_pMultipleHosts[newEntries] = other.m_pMultipleHosts[i];
+
+                    // Convert primary from old index to new index
+                    if (other.m_primary == i)
+                        m_primary = newEntries;
+
+                    newEntries++;
+                }
+            }
+
+            // Now move primary to top of list
+            if ((m_primary != INVALID_PRIMARY_HOST) && (m_primary != 0))
+            {
+                DrHostAndPort temp = m_pMultipleHosts[0];
+                m_pMultipleHosts[0] = m_pMultipleHosts[m_primary];
+                m_pMultipleHosts[m_primary] = temp;
+
+                m_primary = 0;
+            }
+        }
+    }
+
+    // NOTE(review): on the merge path this presumably equals newEntries, which holds only if neither
+    // list contains duplicate hosts - confirm.
+    m_numEntries = other.m_numEntries;
+    return *this;
+}
+
+// Writes the list as a DrTag_DrHostNameList aggregate: begin tag, entry count, each entry, the
+// primary index, end tag. Returns the writer's status.
+DrError DrHostNameList::Serialize(DrMemoryWriter *pWriter) const{
+    pWriter->WriteUInt16Property(Prop_Dryad_BeginTag, DrTag_DrHostNameList);
+
+    pWriter->WriteUInt32Property(Prop_Dryad_NumEntries, m_numEntries);
+    for (UInt32 i = 0; i < m_numEntries; i++) {
+        m_pMultipleHosts[i].Serialize(pWriter);
+    }
+    pWriter->WriteUInt32Property(Prop_Dryad_PrimaryHost, m_primary);
+    //pWriter->WriteUInt32Property(Prop_Dryad_NextHost, m_nextHost);
+
+    pWriter->WriteUInt16Property(Prop_Dryad_EndTag, DrTag_DrHostNameList);
+
+    return pWriter->GetStatus();
+}
+// Property callback for a serialized host name list.
+// A nested DrTag_DrHostAndPort aggregate is unserialized into a newly added entry; any other nested
+// tag sets DrError_InvalidProperty on the reader. Unrecognized properties are skipped.
+// dataLen and cookie are not used. Returns the reader's status.
+DrError DrHostNameList::OnParseProperty(DrMemoryReader *reader, UInt16 property, UInt32 dataLen, void *cookie){
+    if (reader->GetStatus() == DrError_OK) {
+        switch(property) {
+            case Prop_Dryad_BeginTag:
+                {
+                    // Peek rather than read: the nested Unserialize call consumes the begin tag itself.
+                    UInt16 tagId;
+                    if (reader->PeekNextUInt16Property(Prop_Dryad_BeginTag, &tagId) == DrError_OK) {
+                        switch(tagId) {
+                            case DrTag_DrHostAndPort:
+                                {
+                                    DrHostAndPort *pHost = AddEntry();
+                                    pHost->Unserialize(reader);
+                                }
+                                break;
+
+                            default:
+                                reader->SetStatus(DrError_InvalidProperty);
+                                break;
+                        }
+                    }
+                }
+                break;
+            case Prop_Dryad_NumEntries:
+                {
+                    // The count is only passed to GrowTo(); entries are added as their aggregates arrive.
+                    UInt32 numEntries;
+                    if (reader->ReadNextUInt32Property(Prop_Dryad_NumEntries, &numEntries) == DrError_OK) {
+                        GrowTo(numEntries);
+                    }
+                }
+                break;
+            case Prop_Dryad_PrimaryHost:
+                {
+                    UInt32 primary;
+                    if(reader->ReadNextUInt32Property(Prop_Dryad_PrimaryHost, &primary) == DrError_OK){
+                        SetPrimary(primary);
+                    }
+                }
+                break;
+            case Prop_Dryad_NextHost:
+                {
+                    // Read and discarded: the value is consumed but no longer applied.
+                    UInt32 nextHost;
+                    if(reader->ReadNextUInt32Property(Prop_Dryad_NextHost, &nextHost) == DrError_OK){
+                        //SetNextHost(nextHost);
+                    }
+                }
+                break;
+            default:
+                reader->SkipNextPropertyOrAggregate();
+                break;
+        }
+
+    }
+    return reader->GetStatus();
+}
+
+
+// Copies the first entry of the list into 'host'. wantPrimary is not used.
+// NOTE(review): entry 0 is read without checking that the list is non-empty - presumably callers guarantee this.
+void DrHostNameList::SelectOneHost(DrHostAndPort &host, bool wantPrimary)
+{
+    // The primary (or next node to try) is always the first entry in the list
+    host.Set(m_pMultipleHosts[0]);
+}
+
+
+// This call may block for DNS
+// It resolves the list of host names to a list of IP addresses and *appends* those to the specified DrNodeAddressList, filling in
+// the port number for each.
+// Note that since this request appends to the existing list, you must Clear() or Discard() the list before you make this
+// call if you want the results to replace the existing set.
+// Returns DrError_HostNotFound if no hosts match the name.
+DrError DrHostNameList::ResolveToAddresses(DrNodeAddressList *pAddresses)
+{
+    bool errSet = false;
+    DrError err = DrError_HostNotFound;
+    for (UInt32 i = 0; i < m_numEntries; i++) {
+        DrError err2 = m_pMultipleHosts[i].ResolveToAddresses(pAddresses);
+        // we succeed this call if at least one host was resolved successfully. Otherwise, we fail with the first
+        // error returned.
+        if (!errSet || err2 == DrError_OK) {
+            err = err2;
+            errSet = true;
+        }
+    }
+
+    return err;
+}
+
+// Picks one host with SelectOneHost (returned through 'host') and resolves just that host,
+// appending its addresses to pAddressses.
+DrError DrHostNameList::ResolveOneHostToAddresses(DrNodeAddressList *pAddressses, bool wantPrimary, DrHostAndPort &host){
+    SelectOneHost(host, wantPrimary);
+    return host.ResolveToAddresses(pAddressses);
+}
+
+
+// Starts with every hash bucket empty.
+DrLastAccessTable::DrLastAccessTable()
+{
+    memset(m_head, 0, sizeof(m_head));
+}
+
+// Returns the entry for nodeAddress, creating it and linking it at the head of its hash bucket if
+// none exists. Takes no lock itself; the Update* methods in this file call it between Lock() and Unlock().
+DrLastAccessEntry* DrLastAccessTable::FindOrCreate(const DrNodeAddress& nodeAddress)
+{
+    DrLastAccessEntry *entry = Find(nodeAddress);
+    if (entry != NULL)
+        return entry;
+
+    UInt32 bucket = nodeAddress.Hash() % k_numLastAccessTableBuckets;
+    entry = new DrLastAccessEntry();
+    entry->m_nodeAddress = nodeAddress;
+    entry->m_nextHash = m_head[bucket];
+    m_head[bucket] = entry;
+    return entry;
+}
+
+// Returns the entry for nodeAddress, or NULL if there is none. Takes no lock itself.
+DrLastAccessEntry* DrLastAccessTable::Find(const DrNodeAddress& nodeAddress)
+{
+    UInt32 bucket = nodeAddress.Hash() % k_numLastAccessTableBuckets;
+
+    // Walk the bucket's chain
+    for (DrLastAccessEntry* search = m_head[bucket]; search != NULL; search = search->m_nextHash)
+    {
+        if (search->m_nodeAddress == nodeAddress)
+            return search;
+    }
+
+    return NULL;
+}
+
+// Send success.
+// Clears any delay and error recorded for nodeAddress, creating its entry if needed.
+void DrLastAccessTable::UpdateSuccess(const DrNodeAddress& nodeAddress)
+{
+    Lock();
+
+    DrLastAccessEntry* entry = FindOrCreate(nodeAddress);
+
+    entry->m_nextAttemptAllowed = DrTimeStamp_LongAgo;
+    entry->m_delayTime = 0;
+    entry->m_lastError = DrError_OK;
+
+    Unlock();
+}
+
+// Send failure.
+// Returns true if we were already at the maximum allowed delay value
+// The delay grows (doubled, plus the initial interval, capped at the maximum) only when the error is
+// DrError_TalkToPrimaryServer. The next allowed attempt time and the last error are always updated.
+bool DrLastAccessTable::UpdateFailure(const DrNodeAddress& nodeAddress, DrError error)
+{
+    Lock();
+
+    DrLastAccessEntry* pEntry = FindOrCreate(nodeAddress);
+    const bool alreadyAtMax = (pEntry->m_delayTime >= k_maxDelayedSendInterval);
+
+    if ((error == DrError_TalkToPrimaryServer) && !alreadyAtMax)
+    {
+        pEntry->m_delayTime = pEntry->m_delayTime*2 + k_initialDelayedSendTimeInterval;
+        if (pEntry->m_delayTime > k_maxDelayedSendInterval)
+        {
+            pEntry->m_delayTime = k_maxDelayedSendInterval;
+        }
+    }
+
+    pEntry->m_lastError = error;
+    pEntry->m_nextAttemptAllowed = DrGetCurrentTimeStamp() + pEntry->m_delayTime;
+
+    Unlock();
+    return alreadyAtMax;
+}
+
+// Reports when the next send to nodeAddress is allowed.
+// *when is DrTimeStamp_LongAgo and the result is DrError_OK if the address has no entry; otherwise
+// *when is the entry's next allowed attempt time and the result is its last recorded error.
+DrError DrLastAccessTable::GetDelay(const DrNodeAddress& nodeAddress, DrTimeStamp* when)
+{
+    DrError lastError = DrError_OK;
+
+    Lock();
+
+    DrLastAccessEntry* pEntry = Find(nodeAddress);
+    if (pEntry != NULL)
+    {
+        *when = pEntry->m_nextAttemptAllowed;
+        lastError = pEntry->m_lastError;
+    }
+    else
+    {
+        *when = DrTimeStamp_LongAgo;
+    }
+
+    Unlock();
+
+    return lastError;
+}
diff --git a/DryadVertex/VertexHost/system/classlib/src/DrRefCounter.cpp b/DryadVertex/VertexHost/system/classlib/src/DrRefCounter.cpp
new file mode 100644
index 0000000..bd4084e
--- /dev/null
+++ b/DryadVertex/VertexHost/system/classlib/src/DrRefCounter.cpp
@@ -0,0 +1,48 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+#include "DrCommon.h"
+// NOTE(review): the header name is missing from the #include below - restore it (TODO confirm which header).
+#include
+
+#pragma unmanaged
+
+// Source of the IDs handed out by GetUniqueObjectID; only modified through interlocked increments.
+static volatile LONGLONG g_DrRefUniqueObjectCounter = 0;
+
+
+#if defined(_AMD64_)
+
+// Returns the next value of the counter (the first call returns 1).
+UInt64 GetUniqueObjectID()
+{
+    LONGLONG value = ::InterlockedIncrement64(&g_DrRefUniqueObjectCounter);
+    return (UInt64) value;
+}
+
+#else
+
+// We don't use non-x64 builds, but for completeness sake we will generate a
+// a 32 bit version. The downside is that unique object IDs would wrap around at 4 Billion if this build were used,
+// but even that isn't a problem because we'll never have so many active object at a given time.
+UInt64 GetUniqueObjectID()
+{
+    // Only the first 32 bits of the counter's storage are incremented here.
+    LONG value = ::InterlockedIncrement((volatile LONG*) &g_DrRefUniqueObjectCounter);
+
+    // Go through ULONG so the result is zero-extended. Converting the signed value directly would
+    // sign-extend once it passes 2^31, instead of wrapping at 4 billion as described above.
+    return (UInt64)(ULONG) value;
+}
+
+#endif
\ No newline at end of file
diff --git a/DryadVertex/VertexHost/system/classlib/src/DrStringUtil.cpp b/DryadVertex/VertexHost/system/classlib/src/DrStringUtil.cpp
new file mode 100644
index 0000000..186e3ea
--- /dev/null
+++ b/DryadVertex/VertexHost/system/classlib/src/DrStringUtil.cpp
@@ -0,0 +1,2490 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+// NOTE(review): the header names are missing from the two #include lines below - restore them (TODO confirm which headers).
+#include
+#include
+
+#pragma unmanaged
+
+//
+// Add contents of provided string to the DrStr
+// psz/len describe the wide-character input; it is converted with WideCharToMultiByte using codePage.
+// Returns *this.
+//
+DrStr& DrStr::Append(const WCHAR *psz, size_t len, UINT codePage)
+{
+    //
+    // Grow DrStr to accommodate extra length
+    //
+    GrowTo(m_stringLen + len);
+
+    //
+    // If there's anything to append
+    //
+    if (len != 0)
+    {
+        LogAssert(m_nbBuffer > m_stringLen + 1);
+        // Space left after the current contents, keeping one byte for the terminator
+        Size_t n = m_nbBuffer - m_stringLen - 1;
+
+        //
+        // Convert WCHAR to multibyte string
+        //
+        int ret = WideCharToMultiByte(
+            codePage,
+            0,
+            psz,
+            (int)len,
+            m_pBuffer + m_stringLen,
+            (int)n,
+            NULL,
+            NULL);
+        if (ret == 0)
+        {
+            //
+            // If failed, check that there wasn't enough room
+            // todo: what if there is some other failure? This would fail the assertion
+            //
+            DrError err = DrGetLastError();
+            LogAssert(err == DrErrorFromWin32(ERROR_INSUFFICIENT_BUFFER));
+
+            //
+            // The result wouldn't fit. Rather than iterating on heap allocation, we'll just ask windows to calculate the needed size:
+            //
+            int ret2 = WideCharToMultiByte(
+                codePage,
+                0,
+                psz,
+                (int)len,
+                NULL,
+                0,
+                NULL,
+                NULL);
+            LogAssert(ret2 > 1);
+
+
+            //
+            // Grow to required size and reconvert
+            //
+            GrowTo(m_stringLen + ret2);
+            n = m_nbBuffer - m_stringLen - 1;
+            ret = WideCharToMultiByte(
+                codePage,
+                0,
+                psz,
+                (int)len,
+                m_pBuffer + m_stringLen,
+                (int)n,
+                NULL,
+                NULL);
+            LogAssert(ret == ret2);
+        }
+        else
+        {
+            LogAssert(ret > 0);
+        }
+
+        //
+        // update length of string
+        //
+        UpdateLength(m_stringLen+ret);
+    }
+
+    return *this;
+}
+
+#pragma warning (disable: 4995) // _vsnprintf deprecated
+// Appends printf-style formatted text, growing the buffer and retrying until the output fits.
+// Returns *this.
+// NOTE(review): 'args' is passed to _vsnprintf again on every retry; that relies on the va_list
+// not being consumed by the call - confirm for the target toolchain.
+DrStr& DrStr::VAppendF(const char *pszFormat, va_list args)
+{
+    // TODO: this could be made more efficient by implementing vprintf ourselves, since we could
+    // TODO: reallocate and continue without starting over
+    GrowTo(64); // Start with something
+
+    while (true) {
+        // See how many bytes are remaining in the buffer.
+        Size_t n = m_nbBuffer - m_stringLen;
+
+        // Don't even bother trying to fit it in less than 10 bytes.
+        if (n > 10) {
+            // BUGBUG: not sure what _vsnprintf does when n-1 >_I32_MAX, since it cannot return a correct length as a result. So, we have to cap it at _I32_MAX-1
+            // BUGBUG: and crash if it does not fit
+            if (n > (Size_t)_I32_MAX) {
+                n = (Size_t)_I32_MAX;
+            }
+            int ret = _vsnprintf(m_pBuffer + m_stringLen, n - 1, pszFormat, args);
+            if (ret >= 0) {
+                m_stringLen += ret;
+                m_pBuffer[m_stringLen] = '\0';
+                break;
+            }
+        }
+
+        // Wouldn't fit.
+
+        // Throw away partial result
+        m_pBuffer[m_stringLen] = '\0';
+
+        // If we have already tried the max for _vsnprintf, we have to crash
+        LogAssert(n < (Size_t)_I32_MAX);
+
+        // grow buffer, and try again.
+        GrowTo(m_nbBuffer +1); // This will force size doubling
+    }
+
+    return *this;
+}
+
+//
+// Append the value in an environment variable onto the end of the current string
+// Returns the error from DrGetEnvironmentVariable; nothing is appended on failure.
+//
+DrError DrStr::AppendFromEnvironmentVariable(const char *pszVarName)
+{
+    const char *pszValue = NULL;
+    DrError err = DrGetEnvironmentVariable(pszVarName, &pszValue);
+    if (err == DrError_OK)
+    {
+        Append(pszValue);
+        // The value buffer is released here with free()
+        free((char *)pszValue);
+    }
+    return err;
+}
+
+// Appends pszUnencoded with the five XML special characters replaced by their entities and
+// control characters (below ' ') replaced by numeric character references.
+// At most len characters are consumed; encoding also stops at the first NUL.
+// When fEscapeNewlines is false, '\n' and '\r' are appended unchanged rather than as references.
+// When len is 0, only EnsureNotNull() is called.
+// Returns *this.
+DrStr& DrStr::AppendXmlEncodedString(const char *pszUnencoded, Size_t len, bool fEscapeNewlines)
+{
+    if (len == 0) {
+        EnsureNotNull();
+    } else {
+        LogAssert(pszUnencoded != NULL);
+
+        // Honor the caller's length so input that is not NUL-terminated at len is never read past it.
+        for (Size_t i = 0; i < len; i++)
+        {
+            char c = pszUnencoded[i];
+            if (c == '\0') {
+                break;
+            }
+
+            if (c == '<') {
+                *this += "&lt;";
+            } else if (c == '>') {
+                *this += "&gt;";
+            } else if (c == '\'') {
+                *this += "&apos;";
+            } else if (c == '"') {
+                *this += "&quot;";
+            } else if (c == '&') {
+                *this += "&amp;";
+            } else if (!fEscapeNewlines && (c == '\n' || c == '\r')) {
+                *this += c;
+            } else if ((unsigned char)c < ' ') {
+                AppendF("&#x%02X;", (unsigned int)(unsigned char)c);
+            } else {
+                *this += c;
+            }
+        }
+    }
+
+    return *this;
+}
+
+// Returns true if the matchLen bytes of this string starting at index equal pszMatch exactly.
+// A NULL pszMatch (matchLen must then be 0) matches only a NULL string at index 0.
+// Returns false if the range [index, index + matchLen) does not lie within the string.
+bool DrStr::SubstrIsEqual(size_t index, const char *pszMatch, size_t matchLen) const
+{
+    if (pszMatch == NULL) {
+        LogAssert(matchLen == 0);
+        return (index == 0 && m_pBuffer == NULL);
+    }
+    // Written without index + matchLen so that huge arguments cannot wrap around and pass the check
+    if (index > m_stringLen || matchLen > m_stringLen - index) {
+        return false;
+    }
+    if (matchLen == 0) {
+        return true;
+    }
+    return (memcmp(m_pBuffer+index, pszMatch, matchLen) == 0);
+}
+
+// Case-insensitive variant of SubstrIsEqual (compares with _strnicmp).
+bool DrStr::SubstrIsEqualNoCase(size_t index, const char *pszMatch, size_t matchLen) const
+{
+    if (pszMatch == NULL) {
+        LogAssert(matchLen == 0);
+        return (index == 0 && m_pBuffer == NULL);
+    }
+    // Written without index + matchLen so that huge arguments cannot wrap around and pass the check
+    if (index > m_stringLen || matchLen > m_stringLen - index) {
+        return false;
+    }
+    if (matchLen == 0) {
+        return true;
+    }
+    return (_strnicmp(m_pBuffer+index, pszMatch, matchLen) == 0);
+}
+
+//
+// Returns first index where a given character can be found starting at a certain index
+// returns DrStr_InvalidIndex if there is no match or the starting
index is out of range +// returns the string length if the null terminator is matched +// Uses strchr - not multibyte aware. +// +size_t DrStr::IndexOfChar(char c, size_t startIndex) const +{ + // + // If start Index invalid, return invalid index + // + if(startIndex > m_stringLen) + { + return DrStr_InvalidIndex; + } + + // + // If character not found, return invalid index + // + char *p = strchr(m_pBuffer+startIndex, c); + if (p == NULL) + { + return DrStr_InvalidIndex; + } + + // + // If character found, return offset from where character found from beginning of string + // + return (size_t)(p - m_pBuffer); +} + +// +// Make string lower case +// Uses _strlwr (not multibyte aware) +// +DrStr& DrStr::ToLowerCase() +{ + if (m_pBuffer != NULL) + { + LogAssert(m_pBuffer[m_stringLen] == '\0'); + _strlwr(m_pBuffer); + } + + return *this; +} + +// +// Make string upper case +// Uses _strupr (not multibyte aware) +// +DrStr& DrStr::ToUpperCase() +{ + if (m_pBuffer != NULL) + { + LogAssert(m_pBuffer[m_stringLen] == '\0'); + _strupr(m_pBuffer); + } + + return *this; +} + +// +// Increases size of string if needed +// Ensures that there are at least maxStringLength+1 bytes available (including null terminator) for the string. +// Does not change the string length. +// If the string was NULL, it becomes an empty string. +// +void DrStr::GrowTo(size_t maxStringLength) +{ + size_t maxBuff = maxStringLength + 1; + + // + // If there is not yet a string, make it an empty string + // + if (m_pBuffer == NULL) + { + // switch from NULL to an empty string in the static buffer + m_pBuffer = m_pStaticBuffer; + m_nbBuffer = m_nbStatic; + if (m_pBuffer != NULL) { + LogAssert(m_nbBuffer != 0); + m_pBuffer[0] = '\0'; + } + } + + // + // If requested length is greater than existing length + // + if (m_nbBuffer < maxBuff) { + // + // Grow the heap string. Minimum = 64 characters, grow by factor of 2 beyond that. 
+ // + size_t nbNew = m_nbBuffer * 2; + if (nbNew < 64) + { + nbNew = 64; + } + + // + // If 2*old_size is still less than required, increase to requested size + // + if (nbNew < maxBuff) + { + nbNew = maxBuff; + } + + // + // Allocate a new buffer of required size + // + char *pNew = AllocateBiggerBuffer( nbNew ); + LogAssert(pNew != NULL); + + // + // Copy old buffer contents into new buffer + // + if (m_stringLen != 0) + { + memcpy(pNew, m_pBuffer, m_stringLen); + } + pNew[m_stringLen] = '\0'; + + // + // If old buffer is valid and not static + // + if (m_pBuffer != NULL && m_pBuffer != m_pStaticBuffer) + { + delete[] m_pBuffer; + } + + // + // New buffer and buffer length becomes current buffer and buffer length + // + m_pBuffer = pNew; + m_nbBuffer = nbNew; + } +} + +DrInternalizedStringPool g_DrInternalizedStrings; + +const char *DrInternalizedStringPool::InternalizeStringLowerCase(const char *pszString) +{ + if (pszString == NULL) { + return NULL; + } + + // Normalize the instance to lower case so it is case insensitive... + DrStr1024 strTemp; + strTemp = pszString; + strTemp.ToLowerCase(); + const char *pszResult = InternalizeString(strTemp); + return pszResult; +} + +const char *DrInternalizedStringPool::InternalizeStringUpperCase(const char *pszString) +{ + if (pszString == NULL) { + return NULL; + } + + // Normalize the instance to lower case so it is case insensitive... 
+ DrStr1024 strTemp; + strTemp = pszString; + strTemp.ToUpperCase(); + const char *pszResult = InternalizeString(strTemp); + return pszResult; +} + + +const char *DrInternalizedStringPool::InternalizeString(const char *pszString) +{ + if (pszString == NULL) { + return NULL; + } + + Size_t length; + UInt32 hash = StringHash(pszString, &length); + const char *pszMatch = NULL; + { + MUTEX_LOCK(lock, m_Mutex); + UInt32 bucket = hash % m_hashTableSize; + InternalizedStringHeader *pHeader = m_pBuckets[bucket]; + + while (pHeader != NULL) { + LogAssert(pHeader->magic == k_internalizedStringMagic); + if (pHeader->hash == hash) { + const char *pszExisting = (const char *)(const void *)(pHeader+1); + if (strcmp(pszExisting, pszString) == 0) { + pszMatch = pszExisting; + break; + } + } + pHeader = pHeader->pNext; + } + + if (pszMatch == NULL) { + pHeader = (InternalizedStringHeader *)allocMem(sizeof(InternalizedStringHeader) + length + 1); + char *pszNew = (char *)(void *)(pHeader+1); + memcpy(pszNew, pszString, length+1); + pHeader->magic = k_internalizedStringMagic; + pHeader->hash = hash; + pHeader->pNext = m_pBuckets[bucket]; + m_pBuckets[bucket] = pHeader; + pszMatch = pszNew; + } + } + + return pszMatch; + +} + +DrError DrStringToPortNumber(const char *psz, DrPortNumber *pResult) +{ + if (psz == NULL) { + return DrError_InvalidParameter; + } + + DrError err = DrError_OK; + + if (_stricmp(psz, "any") == 0 || _stricmp(psz, "*") == 0) { + *pResult = DrAnyPortNumber; + } else if (_stricmp(psz, "invalid") == 0) { + *pResult = DrInvalidPortNumber; + } else { + err = DrStringToUInt16(psz, pResult); + } + return err; +} + +//JC +#if 0 +DrStr& DrStr::AppendNodeAddress(const DrNodeAddress& val) +{ + char buff[64]; + DrError err = val.ToAddressPortString(buff, sizeof(buff)); + LogAssert(err == DrError_OK); + return Append(buff); +} + + + +DrStr& DrStr::AppendHexBytes(const void *pData, size_t numBytes, bool fUppercase) +{ + const BYTE *pbData = (const BYTE *)pData; + for (size_t i 
= 0; i < numBytes; i++) { + AppendF(fUppercase ? "%02X" : "%02x", pbData[i]); + } + return *this; +} + +DrStr& DrStr::AppendCQuoteEncodedString(const char *pszUnencoded, Size_t len) +{ + LogAssert (pszUnencoded != NULL || len == 0); + if (len == 0) { + return *this; + } + + for (Size_t i = 0; i < len; i++) { + char c = pszUnencoded[i]; + if (c < ' ' || c > 126) { + if (c == '\r') { + Append("\\r"); + } else if (c == '\n') { + Append("\\n"); + } else if (c == '\t') { + Append("\\t"); + } else { + Append("\\x"); + AppendHexBytes(&c, 1, false); + } + } else { + switch(c) { + case '\\': + Append("\\\\"); + break; + case '"': + Append("\\\""); + break; + default: + Append(c); + }; + } + } + + return *this; +} + + +// Encodes a string for inclusion in XML text, outside of an element. Escapes "<" and "&", and nonprinting characters. +DrStr& DrStr::AppendXmlTextEncodedString(const char *pszUnencoded, Size_t len, bool fEscapeNewlines) +{ + if (len == 0) { + EnsureNotNull(); + } else { + LogAssert(pszUnencoded != NULL); + + for (Size_t i = 0; i < len; i++) { + char c = pszUnencoded[i]; + if (c == '<') { + *this += "<"; + } else if (c == '&') { + *this += "&"; + } else if (!fEscapeNewlines && (c == '\n' || c == '\r')) { + *this += c; + } else if ((unsigned char)c < ' ') { + AppendF("&#x%02X;", (unsigned int)(unsigned char)c); + } else { + *this += c; + } + } + } + + return *this; +} + +// Encodes a string for inclusion in an XML double-quoted attribute value. Escapes "<", "&", and "\"". 
+DrStr& DrStr::AppendXmlDQuoteEncodedString(const char *pszUnencoded, Size_t len, bool fEscapeNewlines) +{ + if (len == 0) { + EnsureNotNull(); + } else { + LogAssert(pszUnencoded != NULL); + + for (Size_t i = 0; i < len; i++) { + char c = pszUnencoded[i]; + if (c == '<') { + *this += "<"; + } else if (c == '"') { + *this += """; + } else if (c == '&') { + *this += "&"; + } else if (!fEscapeNewlines && (c == '\n' || c == '\r')) { + *this += c; + } else if ((unsigned char)c < ' ') { + AppendF("&#x%02X;", (unsigned int)(unsigned char)c); + } else { + *this += c; + } + } + } + + return *this; +} + +// Encodes a string for inclusion in an XML single-quoted attribute value. Escapes "<", "&", and "'". +DrStr& DrStr::AppendXmlSQuoteEncodedString(const char *pszUnencoded, Size_t len, bool fEscapeNewlines) +{ + if (len == 0) { + EnsureNotNull(); + } else { + LogAssert(pszUnencoded != NULL); + + for (Size_t i = 0; i < len; i++) { + char c = pszUnencoded[i]; + if (c == '<') { + *this += "<"; + } else if (c == '\'') { + *this += "'"; + } else if (c == '&') { + *this += "&"; + } else if (!fEscapeNewlines && (c == '\n' || c == '\r')) { + *this += c; + } else if ((unsigned char)c < ' ') { + AppendF("&#x%02X;", (unsigned int)(unsigned char)c); + } else { + *this += c; + } + } + } + + return *this; +} + +// Encodes a string for inclusion in XML text, outside of an element. Wraps with CDATA. Useful for +// long strings with lots of delimiters. Handles embedded "]]>" by splitting into two CDATA sections. +DrStr& DrStr::AppendXmlTextEncodedStringAsCDATA(const char *pszUnencoded, Size_t len) +{ + if (len == 0) { + EnsureNotNull(); + } else { + LogAssert(pszUnencoded != NULL); + + *this +="= len || pszUnencoded[i+1] != ']' || pszUnencoded[i+2] != '>') { + *this += c; + } else { + // We found an embedded "]]>". To handle this, we will output "]]", then close the current CDATA section, then open + // a new one and output the ">". 
+ *this += "]]]]>"; + i += 2; + } + } + } + + return *this; +} + + +DrStr& DrStr::ParseNextCommandLineArg(const char *pszCommandLine, size_t *pNumCharsConsumedOut) +{ + size_t n = 0; + if (pszCommandLine == NULL) { + // null commandline + SetToNull(); + } else { + while (pszCommandLine[n] == ' ' || pszCommandLine[n] == '\t') { + // skip leading blanks and tabs + n++; + } + if (pszCommandLine[n] == '\0') { + // empty commandline + SetToNull(); + } else { + SetToEmptyString(); + bool fInQuote = false; + bool fTerminateQuoteAfterWord = false; + char c; + for (; (c = pszCommandLine[n]) != '\0'; n++) { + if (c == ' ' || c == '\t') { + if (fInQuote && !fTerminateQuoteAfterWord) { + // quoted blank or tab + Append(c); + } else { + // End of quote or unquoted word + fInQuote = false; + fTerminateQuoteAfterWord = false; + + // Consume trailing blanks/tabs and terminate parsing + do { + n++; + } while (pszCommandLine[n] == ' ' || pszCommandLine[n] == '\t' ); + break; + } + } else if (c == '"') { + if (pszCommandLine[n+1] == '\"' && pszCommandLine[n+2] == '\"') { + // There are 3 quotes in a row -- it is an escaped quote. + Append('"'); + n += 2; + } else { + // Toggle quote mode. For compatibilty, quoting always extends to the begginning/end of a word. + if (!fInQuote) { + fInQuote = true; + fTerminateQuoteAfterWord = false; + } else { + fTerminateQuoteAfterWord = !fTerminateQuoteAfterWord; + } + } + } else if (c == '\\') { + // backslash can escape a backslash preceding a quote, or a quote + int nBackslashes; + for (nBackslashes = 1; pszCommandLine[n+nBackslashes] == '\\'; nBackslashes++) { + // nothing to do + } + // advance to last backslash + n += (nBackslashes - 1); + bool fAddQuote = false; + if (pszCommandLine[n+1] == '"') { + // Advance past the last backslash + n++; + // The backslashes are terminated by a quote. 
There are to be interpreted as an escape sequence + // If there are an odd number of backslashes, then the last one is an escape for the quote + fAddQuote = (nBackslashes & 1) != 0; + // each pair of backslashes becomes a single backslash + nBackslashes = nBackslashes >> 1; + } + for (int i = 0; i < nBackslashes; i++) { + Append('\\'); + } + if (fAddQuote) { + Append('"'); + } + } else { + // normal chars are simply copied + Append(c); + } + } + } + } + + if (pNumCharsConsumedOut != NULL) { + *pNumCharsConsumedOut = n; + } + + return *this; +} + +// Parses a conventionally escaped/quoted command line string into a DrStrList of argument strings, with escaping removed +DrStrList& DrParseCommandLineToList(DrStrList& argListOut, const char *pszCommandLine) +{ + argListOut.Clear(); + + size_t nConsumed = 0; + DrStr64 strArg; + while (strArg.ParseNextCommandLineArg(pszCommandLine, &nConsumed) != NULL) { + argListOut.AddString(strArg); + pszCommandLine += nConsumed; + } + + return argListOut; +} + +void DrFreeParsedCommandLineArgv(char **argv) +{ + delete[] argv; +} + +// Parses a conventionally escaped/quoted command line string into a UTF-8 encoded argc/argv pair. 
+// The returned argv value should be freed with DrFreeParsedCommandLineArgv +void DrParseCommandLineToArgv(const char *pszCommandLine, int *pargc, char ***pargv) +{ + char **newArgv = NULL; + + DrStrList argList; + DrParseCommandLineToList(argList, pszCommandLine); + int argc = (int)argList.GetNumStrings(); + if (argc > 0) { + newArgv = new char *[(size_t)argc]; + LogAssert(newArgv != NULL); + for (int i = 0; i < argc; i++) { + if (argList[i] == NULL) { + newArgv[i] = NULL; + } else { + newArgv[i] = new char[argList[i].GetLength() + 1]; + LogAssert(newArgv[i] != NULL); + memcpy(newArgv[i], argList[i].GetString(), argList[i].GetLength() + 1); + } + } + } + + *pargv = newArgv; + *pargc = argc; +} + +DrStr& DrStr::AppendEncodedCommandLine(const DrStrList& argList) +{ + EnsureNotNull(); + UInt32 n = argList.GetNumStrings(); + for (UInt32 i = 0; i < n; i++) { + if (i != 0) { + Append(' '); + } + AppendCommandLineEncodedString(argList[i]); + } + return *this; +} + +DrStr& DrStr::AppendEncodedCommandLine(int argc, char **argv) +{ + EnsureNotNull(); + for (int i = 0; i < argc; i++) { + if (i != 0) { + Append(' '); + } + AppendCommandLineEncodedString(argv[i]); + } + return *this; +} + + +DrStr& DrStr::AppendCommandLineEncodedString(const char *pszUnencoded, Size_t len) +{ + bool fEncloseInQuotes = false; + Size_t i; + + EnsureNotNull(); + + // first determine if the string must be quoted + for (i = 0;i < len;i++) { + char c = pszUnencoded[i]; + if (c == '\0') { + // no way to encode null character, so truncate the string at the null + len = i; + break; + } else if ((c >= '\0' && c <= ' ') || c == (char)127 || c == '`' || c == '^' || c == '&' || c == '(' || c == ')' || c == '{' || c == '}' || c == '[' || + c == ']' || c == '|' || c == ';' || c == '\'' || c == '"' || c == ',' || c == '<' || c == '>') { + // characters that are interpreted as delimters by the shell should generally be quoted to be safe + fEncloseInQuotes = true; + break; + } + } + + if (fEncloseInQuotes) { + 
Append('"'); + } + + + for (i = 0; i < len; i++) { + char c = pszUnencoded[i]; + switch(c) { + case '\0': + // Premature null byte, can't be escaped, stop here + goto done; + + case '"': + // Escape a double quote with a backslash + Append("\\\"", 2); + break; + + case '\\': + // Backslashes are escaped only if followed by a double quote + { + // skip past contiguous backslashes + Size_t nBackslashes; + for (nBackslashes = 1; i + nBackslashes < len && pszUnencoded[i + nBackslashes] == '\\'; nBackslashes++) { + // nothing to do + } + if (i + nBackslashes < len && pszUnencoded[i + nBackslashes] == '"') { + // backslashes preceeding a quote must be escaped + for (int i = 0; i < nBackslashes; i++) { + Append("\\\\", 2); + } + } else { + // backslashes not preceding a quote need not be escaped + for (int i = 0; i < nBackslashes; i++) { + Append('\\'); + } + } + i += (nBackslashes - 1); + } + + break; + + default: + // All other characters are not escaped + Append(c); + break; + }; + } + +done: + if (fEncloseInQuotes) { + Append('"'); + } + + return *this; + +} + +DrStr& DrStr::AppendDrvEncodedString(const char *pszUnencoded, Size_t len) +{ + LogAssert (pszUnencoded != NULL || len == 0); + if (len == 0) { + return *this; + } + + // determine if quotes are required + bool fQuote = false; + if (pszUnencoded[0] == ' ' || pszUnencoded[0] == '\t' || pszUnencoded[len-1] == ' ' || pszUnencoded[len-1] == '\t') { + // leading or trailing whitespace always requires quotes + fQuote = true; + } else { + for (Size_t i = 0; i < len; i++) { + char c = pszUnencoded[i]; + // newlines, double-quotes, and commas anywhere in the string require quoting + if (c == '\n' || c == '\r' || c == '"' || c == ',') { + fQuote = true; + break; + } + } + } + + // Append the string, with optional quotes + if (fQuote) { + Append('"'); + } + + for (Size_t i = 0; i < len; i++) { + char c = pszUnencoded[i]; + if (c == '"') { + Append("\"\"", 2); + } else { + Append(c); + } + } + + if (fQuote) { + 
Append('"'); + } + + return *this; +} + +DrError DrStr::AppendFromEnvironmentVariable(const char *pszVarName) +{ + const char *pszValue = NULL; + DrError err = DrGetEnvironmentVariable(pszVarName, &pszValue); + if (err == DrError_OK) { + Append(pszValue); + free((char *)pszValue); + } + return err; +} + +DrStr& DrStr::AppendFromOptionalEnvironmentVariable(const char *pszVarName, const char *pszDefault) +{ + DrError err = AppendFromEnvironmentVariable(pszVarName); + if (err != DrError_OK && pszDefault != NULL) { + Append(pszDefault); + } + return *this; +} + +// returns 1 if this string is greater, 0 if they are equal, and -1 if this string is less +// NULL is less than any other value +// Uses memcmp (not multibyte aware) +int DrStr::Compare(const char *pszOther, size_t length) const +{ + if (m_pBuffer == NULL) { + if (pszOther == NULL) { + return 0; + } else { + return -1; + } + } else if (pszOther == NULL) { + return 1; + } + + size_t minlen = length; + if (minlen > m_stringLen) { + minlen = m_stringLen; + } + + int ret = 0; + if (minlen != 0) { + ret = memcmp(m_pBuffer, pszOther, minlen); + } + + if (ret == 0) { + if (m_stringLen > length) { + ret = 1; + } else if (m_stringLen < length) { + ret = -1; + } + } + + return ret; +} + +// returns 1 if this string is greater, 0 if they are equal, and -1 if this string is less +// Uses case insensitive compare +// Uses _stricmp (not multibyte aware) +// NULL is less than any other value +int DrStr::CompareNoCase(const char *pszOther) const +{ + if (m_pBuffer == NULL) { + if (pszOther == NULL) { + return 0; + } else { + return -1; + } + } else if (pszOther == NULL) { + return 1; + } + + int ret = _stricmp(m_pBuffer, pszOther); + return ret; +} + +bool DrStr::SubstrIsEqual(size_t index, const char *pszMatch, size_t matchLen) const +{ + if (pszMatch == NULL) { + LogAssert(matchLen == 0); + return (index == 0 && m_pBuffer == NULL); + } + if (index + matchLen > m_stringLen) { + return false; + } + if (matchLen == 0) { + 
return true; + } + return (memcmp(m_pBuffer+index, pszMatch, matchLen) == 0); +} + +bool DrStr::SubstrIsEqualNoCase(size_t index, const char *pszMatch, size_t matchLen) const +{ + if (pszMatch == NULL) { + LogAssert(matchLen == 0); + return (index == 0 && m_pBuffer == NULL); + } + if (index + matchLen > m_stringLen) { + return false; + } + if (matchLen == 0) { + return true; + } + return (_strnicmp(m_pBuffer+index, pszMatch, matchLen) == 0); +} + + +// returns DrStr_InvalidIndex if there is no match or the starting index is out of range +// returns the string length if the null terminator is matched +// Uses strchr - not multibyte aware. +size_t DrStr::IndexOfChar(char c, size_t startIndex) const +{ + if(startIndex > m_stringLen) { + return DrStr_InvalidIndex; + } + char *p = strchr(m_pBuffer+startIndex, c); + if (p == NULL) { + return DrStr_InvalidIndex; + } + return (size_t)(p - m_pBuffer); +} + +// returns DrStr_InvalidIndex if there is no match or the startLength is out of range +// The startLength should be one greater than the first posible matching index (e.g., the length of the string to search) +size_t DrStr::ReverseIndexOfChar(char c, size_t startLength) const +{ + if(startLength == 0 || startLength > m_stringLen+1 || m_pBuffer == NULL) { + return DrStr_InvalidIndex; + } + char * p = m_pBuffer + startLength; + while (--p >= m_pBuffer) { + if (*p == c) { + return (size_t)(p - m_pBuffer); + } + } + return DrStr_InvalidIndex; +} + +// returns DrStr_InvalidIndex if there is no match or the starting index is out of range +// returns the string length if the null terminator is matched +// not multibyte aware. 
+size_t DrStr::IndexOfString(const char *psz, size_t startIndex) const +{ + if(psz == NULL || m_pBuffer == NULL || startIndex > m_stringLen) { + return DrStr_InvalidIndex; + } + char *p = strstr(m_pBuffer+startIndex, psz); + if (p == NULL) { + return DrStr_InvalidIndex; + } + return (size_t)(p - m_pBuffer); +} + +// returns DrStr_InvalidIndex if there is no match or the startLength is out of range +// The startLength should be one greater than the first posible matching index (e.g., the length of the string to search) +size_t DrStr::ReverseIndexOfString(const char *psz, size_t startLength) const +{ + if(psz == NULL || startLength == 0 || startLength > m_stringLen+1 || m_pBuffer == NULL) { + return DrStr_InvalidIndex; + } + size_t slen = strlen(psz); + if (slen > m_stringLen) { + return DrStr_InvalidIndex; + } + size_t spos = startLength; + if (spos > m_stringLen - slen) { + spos = m_stringLen - slen; + } + char * p = m_pBuffer + spos; + while (p >= m_pBuffer) { + if (strncmp(p, psz, slen) == 0) { + return (size_t)(p - m_pBuffer); + } + p--; + } + return DrStr_InvalidIndex; +} + +// returns DrStr_InvalidIndex if there is no match or the starting index is out of range +// returns the string length if the null terminator is matched +// not multibyte aware. 
+size_t DrStr::IndexOfStringNoCase(const char *psz, size_t startIndex) const +{ + if(psz == NULL || m_pBuffer == NULL || startIndex > m_stringLen) { + return DrStr_InvalidIndex; + } + size_t slen = strlen(psz); + if (slen > m_stringLen - startIndex) { + return DrStr_InvalidIndex; + } + size_t epos = m_stringLen - slen; + for (size_t i = startIndex; i <= epos; i++) { + if (_strnicmp(m_pBuffer + i, psz, slen) == 0) { + return i; + } + } + return DrStr_InvalidIndex; +} + +// returns DrStr_InvalidIndex if there is no match or the startLength is out of range +// The startLength should be one greater than the first posible matching index (e.g., the length of the string to search) +size_t DrStr::ReverseIndexOfStringNoCase(const char *psz, size_t startLength) const +{ + if(psz == NULL || startLength == 0 || startLength > m_stringLen+1 || m_pBuffer == NULL) { + return DrStr_InvalidIndex; + } + size_t slen = strlen(psz); + if (slen > m_stringLen) { + return DrStr_InvalidIndex; + } + size_t spos = startLength; + if (spos > m_stringLen - slen) { + spos = m_stringLen - slen; + } + char * p = m_pBuffer + spos; + while (p >= m_pBuffer) { + if (_strnicmp(p, psz, slen) == 0) { + return (size_t)(p - m_pBuffer); + } + p--; + } + return DrStr_InvalidIndex; +} + + +DrStr& DrStr::DeleteRange(size_t startIndex, size_t numChars) +{ + size_t oldend = startIndex + numChars; + LogAssert(oldend >= startIndex); // overflow check + LogAssert(m_pBuffer != NULL && oldend <= m_stringLen); + if (numChars != 0) { + if (oldend < m_stringLen) { + memmove(m_pBuffer + startIndex, m_pBuffer + oldend, m_stringLen - oldend); + } + m_stringLen -= numChars; + m_pBuffer[m_stringLen] = '\0'; + } + return *this; +} + +// Asserts that the startIndex is valid. If startIndex is 0 and the string is NULL, +// it is converted to an empty string before inserting. 
+DrStr& DrStr::Insert(size_t startIndex, const char *psz, size_t len) +{ + LogAssert(startIndex <= m_stringLen); + size_t newlen = m_stringLen + len; + LogAssert(newlen >= m_stringLen); // overflow check + GrowTo(newlen); + if (len != 0) { + size_t newstart = startIndex + len; + memmove(m_pBuffer + newstart, m_pBuffer + startIndex, m_stringLen -startIndex); + memcpy(m_pBuffer + startIndex, psz, len); + m_stringLen = newlen; + m_pBuffer[m_stringLen] = '\0'; + } + return *this; +} + +// Asserts that the startIndex is valid. If startIndex is 0 and the string is NULL, +// it is converted to an empty string before inserting. +DrStr& DrStr::ReplaceRange(size_t startIndex, size_t oldLen, const char *psz, size_t newLen) +{ + LogAssert(startIndex <= m_stringLen); + size_t oldend = startIndex + oldLen; + LogAssert(oldend >= startIndex); // overflow check + LogAssert(oldend <= m_stringLen); + if (newLen > oldLen) { + size_t nInsert = newLen - oldLen; + size_t ipos = startIndex + oldLen; + GrowTo(m_stringLen + nInsert); + if (ipos < m_stringLen) { + memmove(m_pBuffer + ipos + nInsert, m_pBuffer + ipos, m_stringLen -ipos); + } + m_stringLen += nInsert; + m_pBuffer[m_stringLen] = '\0'; + } else if (newLen < oldLen) { + EnsureNotNull(); + size_t nDelete = oldLen - newLen; + size_t ipos = startIndex + newLen; + size_t epos = ipos + nDelete; + if (epos < m_stringLen) { + memmove(m_pBuffer + ipos , m_pBuffer + epos, m_stringLen -epos); + } + m_stringLen -= nDelete; + m_pBuffer[m_stringLen] = '\0'; + } + + if (newLen != 0) { + memcpy(m_pBuffer+startIndex, psz, newLen); + } + + return *this; +} + +__inline bool ISSPACE(char c) +{ + return isspace((unsigned char)c) != 0; +} + + + +// +// Description: +// +// NextUTF8 - find the next UTF8 character in a UTF8 string +// +// Arguments: +// +// const char *psz - pointer to a valid UTF8 string. 
+// const char *&pszNext - place to point to the next valid UTF8 character in the string psz +// Size_t & cb - set to the size of the first UTF8 character in psz +// +// Return Value: +// +// true - psz was a valid UTF8 character pointer, and pszNext points to the next UTF8 character in the string. +// false - the string pointed to in psz was not valid UTF8 (as defined by the below table). +// +// +// Ref: http://en.wikipedia.org/wiki/UTF-8 +// +// With these restrictions, bytes in a UTF-8 sequence have the following meanings. The ones marked in red can never appear in a legal UTF-8 sequence. +// The ones in green are represented in a single byte. The ones in white must only appear as the first byte in a multi-byte sequence, +// and the ones in orange can only appear as the second or later byte in a multi-byte sequence: +// +// binary hex decimal notes color +// ----------------- ----- ------- ------------------------------------------------------------------------------ ----- +// 00000000-01111111 00-7F 0-127 US-ASCII (single byte) GREEN +// 10000000-10111111 80-BF 128-191 Second, third, or fourth byte of a multi-byte sequence ORANGE +// 11000000-11000001 C0-C1 192-193 Overlong encoding: start of a 2-byte sequence, but code point <= 127 RED +// 11000010-11011111 C2-DF 194-223 Start of 2-byte sequence WHITE +// 11100000-11101111 E0-EF 224-239 Start of 3-byte sequence WHITE +// 11110000-11110100 F0-F4 240-244 Start of 4-byte sequence WHITE +// 11110101-11110111 F5-F7 245-247 Restricted by RFC 3629: start of 4-byte sequence for codepoint above 10FFFF RED +// 11111000-11111011 F8-FB 248-251 Restricted by RFC 3629: start of 5-byte sequence RED +// 11111100-11111101 FC-FD 252-253 Restricted by RFC 3629: start of 6-byte sequence RED +// 11111110-11111111 FE-FF 254-255 Invalid: not defined by original UTF-8 specification RED +// +// The bits of a Unicode character are distributed into the lower bit positions inside the UTF-8 bytes, with the lowest bit going into the last 
bit of the last byte: +// +// Unicode Byte1 Byte2 Byte3 Byte4 example +// ----------------- -------- -------- -------- -------- --------------------------------------------------------------------- +// U+000000-U+00007F 0xxxxxxx U+0024 ? 00100100 ? 0x24 +// U+000080-U+0007FF 110xxxxx 10xxxxxx U+00A2 ? 11000010,10100010 ? 0xC2,0xA2 +// U+000800-U+00FFFF 1110xxxx 10xxxxxx 10xxxxxx U+20AC ? 11100010,10000010,10101100 ? 0xE2,0x82,0xAC +// U+010000-U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx U+10ABCD ? 11110100,10001010,10101111,10001101 ? 0xf4,0x8a,0xaf,0x8d + +__inline bool DrStr::NextUTF8(const char* psz, const char* &pszNext, Size_t &cb) +{ + const unsigned char *pc = (const unsigned char *)psz; + unsigned char c; + unsigned int nExtraBytes=0; + + pszNext = NULL; + + c=*pc; + + if(c <= 127) + { + // US ASCII + cb = 1; + pszNext = (char *)++pc; + return true; + } + + if(c == 0xc0 || c == 0xc1) + { + cb = 0; + return false; + } + else if(c >= 0xc2 && c <= 0xdf) + { + // 2 byte sequence. + cb = 2; + nExtraBytes = 1; + } + else if(c >= 0xe0 && c <= 0xef) + { + // 3 byte sequence. + cb = 3; + nExtraBytes = 2; + } + else if(c >= 0xf0 && c <= 0xf4) + { + // 4 bytes sequence. + cb = 4; + nExtraBytes = 3; + } + else + { + // not a legal UTF8 sequence + cb = 0; + return false; + } + + do + { + ++pc; + c=*pc; + nExtraBytes--; + + if(!(c >= 128 && c <= 191)) + { + // not a valid second, third or forth byte of a multi-byte sequence. + cb = 0; + return false; + } + + } while(c != '\0' && nExtraBytes != 0); + + if(nExtraBytes != 0) + { + cb = 0; + return false; + } + else + { + pszNext = (char *)pc; + return true; + } +} + +// +// Description: +// +// IsValidUTF8Ex - check if this is a valid UTF8 string and does not include any characters in the +// specified exclusion list. +// +// Arguments: +// +// this - a DrStr +// const char *excludeUTF8CharList - a list of UTF8 characters which are not permitted in the string. 
+// Size_t maxLen - the maximum string length in bytes (not characters) allowed. +// bool bVisibleASCII - if set, any ASCII character not in the visible ASCII set is not allowed +// bool bExcludeDoubleSlash - if set, no double slashes are allowed in the string +// bool bExcludeTrailingSlash - if set, the string must not end in a slash +// +// Return Value: +// +// true - the DrStr met the criteria for validity +// false - the DrStr does not meet the criteria for validity +// +bool DrStr::IsValidUTF8Ex(const char *excludeUTF8CharList, Size_t maxLen, bool bVisibleASCII, bool bExcludeDoubleSlash, bool bExcludeTrailingSlash) +{ + if(excludeUTF8CharList != NULL) + { + if(!DrStr64(excludeUTF8CharList).IsValidUTF8()) + { + // the exclude list isn't valid + return false; + } + } + + // variables to walk the exclusion list + const char *pExclude = excludeUTF8CharList; + const char *pNextExclude = pExclude; + + // variables to walk the string being tested + const char *pszUTF8 = m_pBuffer; + const char *pszNextUTF8 = pszUTF8; + + // number of bytes in the string + Size_t count=0; + + // the last ASCII character we parsed, if the last character was not ASCII this is set to '\0' + char lastASCII='\0'; + + bool rc=true; + + while(*pszUTF8 != '\0' && count < maxLen) + { + Size_t cLen; + + // get a pointer to the next character, and the length of the current character + rc = NextUTF8(pszUTF8, pszNextUTF8, cLen); + + if(rc==false) + { + // invalid UTF8 encountered + return false; + } + + // check if this character is in the exclusion list + if(excludeUTF8CharList) + { + Size_t cExcludeLen; + + while(*pExclude != '\0') + { + // walk the exclusion list, and check if the current character matches any excluded character + rc=NextUTF8(pExclude, pNextExclude, cExcludeLen); + + if(rc==true) + { + if(cExcludeLen==cLen && (memcmp(pExclude,pszUTF8,cLen)==0) ) + { + // character is in the exclude list + return false; + } + } + pExclude = pNextExclude; + } + + pExclude = excludeUTF8CharList; 
// back to the beginning for the next pass. + } + + if(cLen==1) + { + if(bVisibleASCII) + { + if(*pszUTF8 >= 1 && *pszUTF8 <=31) + { + // not printable ASCII character found + return false; + } + } + + if(bExcludeDoubleSlash) + { + if(lastASCII == '/' && *pszUTF8 == '/') + { + // double slash not allowed in this string + return false; + } + } + + lastASCII = *pszUTF8; + } + else + { + lastASCII = '\0'; + } + + pszUTF8 = pszNextUTF8; + count+=cLen; + + } + + if(count > maxLen) + { + // too long + return false; + } + + if(bExcludeTrailingSlash && (lastASCII == '/')) + { + // trailing slash not allowed in this string + return false; + } + + return true; + +} + +// Removes whitespace from the start and end of the string +DrStr& DrStr::Trim() +{ + if (m_pBuffer != NULL) { + size_t newlen = m_stringLen; + while (newlen != 0 && ISSPACE(m_pBuffer[newlen-1])) { + newlen--; + } + size_t nleading = 0; + while(nleading < newlen && ISSPACE(m_pBuffer[nleading])) { + nleading++; + } + if (nleading != 0 && nleading != newlen) { + memmove(m_pBuffer, m_pBuffer + nleading, newlen - nleading); + } + m_stringLen = newlen - nleading; + m_pBuffer[m_stringLen] = '\0'; + } + return *this; +} + +DrStr& DrStr::Append(const WCHAR *psz, size_t len, UINT codePage) +{ + GrowTo(m_stringLen + len); + if (len != 0) { + LogAssert(m_nbBuffer > m_stringLen + 1); + Size_t n = m_nbBuffer - m_stringLen - 1; + /* + int WideCharToMultiByte( + UINT CodePage, + DWORD dwFlags, + LPCWSTR lpWideCharStr, + int cchWideChar, + LPSTR lpMultiByteStr, + int cbMultiByte, + LPDRTR lpDefaultChar, + LPBOOL lpUsedDefaultChar + ); + */ + + int ret = WideCharToMultiByte( + codePage, + 0, + psz, + (int)len, + m_pBuffer + m_stringLen, + (int)n, + NULL, + NULL); + + if (ret == 0) { + DrError err = DrGetLastError(); + LogAssert(err == DrErrorFromWin32(ERROR_INSUFFICIENT_BUFFER)); + // The result wouldn't fit. 
Rather than iterating on heap allocation, we'll just ask windows to calculate the needed size: + int ret2 = WideCharToMultiByte( + codePage, + 0, + psz, + (int)len, + NULL, + 0, + NULL, + NULL); + LogAssert(ret2 > 1); + GrowTo(m_stringLen + ret2); + n = m_nbBuffer - m_stringLen - 1; + ret = WideCharToMultiByte( + codePage, + 0, + psz, + (int)len, + m_pBuffer + m_stringLen, + (int)n, + NULL, + NULL); + LogAssert(ret == ret2); + } else { + LogAssert(ret > 0); + } + UpdateLength(m_stringLen+ret); + } + + return *this; +} + +DrWStr& DrWStr::Append(const char *psz, size_t len, UINT codePage) +{ + GrowTo(m_stringLen + len); + if (len != 0) { + LogAssert(m_nbBuffer > m_stringLen + 1); + Size_t n = m_nbBuffer - m_stringLen - 1; + /* + int MultiByteToWideChar( + UINT CodePage, + DWORD dwFlags, + LPDRTR lpMultiByteStr, + int cbMultiByte, + LPWSTR lpWideCharStr, + int cchWideChar + ); + */ + + int ret = MultiByteToWideChar( + codePage, + 0, + psz, + (int)len, + m_pBuffer + m_stringLen, + (int)n); + + if (ret == 0) { + DrError err = DrGetLastError(); + LogAssert(err == DrErrorFromWin32(ERROR_INSUFFICIENT_BUFFER)); + // The result wouldn't fit. Rather than iterating on heap allocation, we'll just ask windows to calculate the needed size: + int ret2 = MultiByteToWideChar( + codePage, + 0, + psz, + (int)len, + NULL, + 0); + LogAssert(ret2 > 1); + GrowTo(m_stringLen + ret2); + n = m_nbBuffer - m_stringLen - 1; + ret = MultiByteToWideChar( + codePage, + 0, + psz, + (int)len, + m_pBuffer + m_stringLen, + (int)n); + LogAssert(ret == ret2); + } else { + LogAssert(ret > 0); + } + UpdateLength(m_stringLen+ret); + } + + return *this; +} + + + + +// Ensures that there are at least maxStringLength+1 bytes available (including null terminator) for the string. +// Does not change the string length. +// If the string was NULL, it becomes an empty string. 
+void DrWStr::GrowTo(size_t maxStringLength) +{ + size_t maxBuff = maxStringLength + 1; + + if (m_pBuffer == NULL) { + // switch from NULL to an empty string in the static buffer + m_pBuffer = m_pStaticBuffer; + m_nbBuffer = m_nbStatic; + if (m_pBuffer != NULL) { + LogAssert(m_nbBuffer != 0); + m_pBuffer[0] = L'\0'; + } + } + + if (m_nbBuffer < maxBuff) { + // Grow the heap string + size_t nbNew =m_nbBuffer * 2; + if (nbNew < 64) { + nbNew = 64; + } + if (nbNew < maxBuff) { + nbNew = maxBuff; + } + WCHAR *pNew = new WCHAR[nbNew]; + LogAssert(pNew != NULL); + if (m_stringLen != 0) { + memcpy(pNew, m_pBuffer, m_stringLen * sizeof(WCHAR)); + } + pNew[m_stringLen] = L'\0'; + if (m_pBuffer != NULL && m_pBuffer != m_pStaticBuffer) { + delete[] m_pBuffer; + } + m_pBuffer = pNew; + m_nbBuffer = nbNew; + } +} + + +DrWStr& DrWStr::AppendNodeAddress(const DrNodeAddress& val) +{ + WCHAR buff[64]; + DrError err = val.ToAddressPortString(buff, ELEMENTCOUNT(buff)); + LogAssert(err == DrError_OK); + return Append(buff); +} + + +DrWStr& DrWStr::VAppendF(const WCHAR *pszFormat, va_list args) +{ + // TODO: this could be made more efficient by implementing vprintf ourselves, since we could + // TODO: reallocate and continue without starting over + GrowTo(64); // Start with something + + while (true) { + // See how many WCHARS are remaining in the buffer. + Size_t n = m_nbBuffer - m_stringLen; + + // Don't even bother trying to fit it in less than 10 WCHARS. + if (n > 10) { + // BUGBUG: not sure what _vsnwprintf does when n-1 >_I32_MAX, since it cannot return a correct length as a result. So, we have to cap it at _I32_MAX-1 + // BUGBUG: and crash if it does not fit + if (n > (Size_t)_I32_MAX) { + n = (Size_t)_I32_MAX; + } + int ret = _vsnwprintf(m_pBuffer + m_stringLen, n - 1, pszFormat, args); + if (ret >= 0) { + m_stringLen += ret; + m_pBuffer[m_stringLen] = L'\0'; + break; + } + } + + // Wouldn't fit. 
+ // Throw away partial result + m_pBuffer[m_stringLen] = '\0'; + + // If we have already tried the max for _vsnwprintf, we have to crash + LogAssert(n < (Size_t)_I32_MAX); + + // grow buffer, and try again. + GrowTo(m_nbBuffer +1); // This will force size doubling + } + + return *this; +} + +DrWStr& DrWStr::AppendHexBytes(const void *pData, size_t numBytes, bool fUppercase) +{ + const BYTE *pbData = (const BYTE *)pData; + for (size_t i = 0; i < numBytes; i++) { + AppendF(fUppercase ? L"%02X" : L"%02x", pbData[i]); + } + return *this; +} + +DrWStr& DrWStr::AppendCQuoteEncodedString(const WCHAR *pszUnencoded, Size_t len) +{ + LogAssert (pszUnencoded != NULL || len == 0); + if (len == 0) { + return *this; + } + + for (Size_t i = 0; i < len; i++) { + WCHAR c = pszUnencoded[i]; + if (c < L' ') { + if (c == L'\r') { + Append(L"\\r"); + } else if (c == L'\n') { + Append(L"\\n"); + } else if (c == L'\t') { + Append(L"\\t"); + } else { + Append(L"\\x"); + AppendHexBytes(&c, 1, false); + } + } else { + switch(c) { + case L'\\': + Append(L"\\\\"); + break; + case L'"': + Append(L"\\\""); + break; + default: + Append(c); + }; + } + } + + return *this; +} + +DrWStr& DrWStr::AppendXmlEncodedString(const WCHAR *pszUnencoded, Size_t len, bool fEscapeNewlines) +{ + if (len == 0) { + EnsureNotNull(); + } else { + LogAssert(pszUnencoded != NULL); + + WCHAR c; + while ((c = *pszUnencoded++) != L'\0') + { + if (c == L'<') { + *this += L"<"; + } else if (c == L'>') { + *this += L">"; + } else if (c == L'\'') { + *this += L"'"; + } else if (c == L'"') { + *this += L"""; + } else if (c == L'&') { + *this += L"&"; + } else if (!fEscapeNewlines && (c == L'\n' || c == L'\r')) { + *this += c; + } else if ((unsigned char)c < L' ') { + AppendF(L"&#x%02X;", (UInt16)c); + } else { + *this += c; + } + } + } + + return *this; +} + + +DrWStr& DrWStr::AppendDrvEncodedString(const WCHAR *pszUnencoded, Size_t len) +{ + LogAssert (pszUnencoded != NULL || len == 0); + if (len == 0) { + return *this; + } 
+ + // determine if quotes are required + bool fQuote = false; + if (pszUnencoded[0] == L' ' || pszUnencoded[0] == L'\t' || pszUnencoded[len-1] == L' ' || pszUnencoded[len-1] == L'\t') { + // leading or trailing whitespace always requires quotes + fQuote = true; + } else { + for (Size_t i = 0; i < len; i++) { + WCHAR c = pszUnencoded[i]; + // newlines, double-quotes, and commas anywhere in the string require quoting + if (c == L'\n' || c == L'\r' || c == L'"' || c == L',') { + fQuote = true; + break; + } + } + } + + // Append the string, with optional quotes + if (fQuote) { + Append(L'"'); + } + + for (Size_t i = 0; i < len; i++) { + WCHAR c = pszUnencoded[i]; + if (c == L'"') { + Append(L"\"\"", 2); + } else { + Append(c); + } + } + + if (fQuote) { + Append(L'"'); + } + + return *this; +} + +DrError DrWStr::AppendFromEnvironmentVariable(const WCHAR *pszVarName) +{ + const WCHAR *pszValue = NULL; + DrError err = DrGetEnvironmentVariable(pszVarName, &pszValue); + if (err == DrError_OK) { + Append(pszValue); + free((WCHAR *)pszValue); + } + return err; +} + +DrWStr& DrWStr::AppendFromOptionalEnvironmentVariable(const WCHAR *pszVarName, const WCHAR *pszDefault) +{ + DrError err = AppendFromEnvironmentVariable(pszVarName); + if (err != DrError_OK && pszDefault != NULL) { + Append(pszDefault); + } + return *this; +} + +// returns 1 if this string is greater, 0 if they are equal, and -1 if this string is less +// NULL is less than any other value +// Uses memcmp (not multibyte aware) +int DrWStr::Compare(const WCHAR *pszOther, size_t length) const +{ + if (m_pBuffer == NULL) { + if (pszOther == NULL) { + return 0; + } else { + return -1; + } + } else if (pszOther == NULL) { + return 1; + } + + size_t minlen = length; + if (minlen > m_stringLen) { + minlen = m_stringLen; + } + + int ret = 0; + if (minlen != 0) { + ret = memcmp(m_pBuffer, pszOther, minlen * sizeof(WCHAR)); + } + + if (ret == 0) { + if (m_stringLen > length) { + ret = 1; + } else if (m_stringLen < length) { 
+ ret = -1; + } + } + + return ret; +} + +// returns 1 if this string is greater, 0 if they are equal, and -1 if this string is less +// Uses case insensitive compare +// Uses _stricmp (not multibyte aware) +// NULL is less than any other value +int DrWStr::CompareNoCase(const WCHAR *pszOther) const +{ + if (m_pBuffer == NULL) { + if (pszOther == NULL) { + return 0; + } else { + return -1; + } + } else if (pszOther == NULL) { + return 1; + } + + int ret = _wcsicmp(m_pBuffer, pszOther); + return ret; +} + +bool DrWStr::SubstrIsEqual(size_t index, const WCHAR *pszMatch, size_t matchLen) const +{ + if (pszMatch == NULL) { + LogAssert(matchLen == 0); + return (index == 0 && m_pBuffer == NULL); + } + if (index + matchLen > m_stringLen) { + return false; + } + if (matchLen == 0) { + return true; + } + return (memcmp(m_pBuffer+index, pszMatch, matchLen * sizeof(WCHAR)) == 0); +} + +bool DrWStr::SubstrIsEqualNoCase(size_t index, const WCHAR *pszMatch, size_t matchLen) const +{ + if (pszMatch == NULL) { + LogAssert(matchLen == 0); + return (index == 0 && m_pBuffer == NULL); + } + if (index + matchLen > m_stringLen) { + return false; + } + if (matchLen == 0) { + return true; + } + return (_wcsnicmp(m_pBuffer+index, pszMatch, matchLen) == 0); +} + + +// returns DrStr_InvalidIndex if there is no match or the starting index is out of range +// returns the string length if the null terminator is matched +// Uses strchr - not multibyte aware. 
+size_t DrWStr::IndexOfChar(WCHAR c, size_t startIndex) const +{ + if(startIndex > m_stringLen) { + return DrStr_InvalidIndex; + } + WCHAR *p = wcschr(m_pBuffer+startIndex, c); + if (p == NULL) { + return DrStr_InvalidIndex; + } + return (size_t)(p - m_pBuffer); +} + +// returns DrStr_InvalidIndex if there is no match or the startLength is out of range +// The startLength should be one greater than the first posible matching index (e.g., the length of the string to search) +size_t DrWStr::ReverseIndexOfChar(WCHAR c, size_t startLength) const +{ + if(startLength == 0 || startLength > m_stringLen+1 || m_pBuffer == NULL) { + return DrStr_InvalidIndex; + } + WCHAR * p = m_pBuffer + startLength; + while (--p >= m_pBuffer) { + if (*p == c) { + return (size_t)(p - m_pBuffer); + } + } + return DrStr_InvalidIndex; +} + +// returns DrStr_InvalidIndex if there is no match or the starting index is out of range +// returns the string length if the null terminator is matched +// not multibyte aware. +size_t DrWStr::IndexOfString(const WCHAR *psz, size_t startIndex) const +{ + if(psz == NULL || m_pBuffer == NULL || startIndex > m_stringLen) { + return DrStr_InvalidIndex; + } + WCHAR *p = wcsstr(m_pBuffer+startIndex, psz); + if (p == NULL) { + return DrStr_InvalidIndex; + } + return (size_t)(p - m_pBuffer); +} + +// returns DrStr_InvalidIndex if there is no match or the startLength is out of range +// The startLength should be one greater than the first posible matching index (e.g., the length of the string to search) +size_t DrWStr::ReverseIndexOfString(const WCHAR *psz, size_t startLength) const +{ + if(psz == NULL || startLength == 0 || startLength > m_stringLen+1 || m_pBuffer == NULL) { + return DrStr_InvalidIndex; + } + size_t slen = wcslen(psz); + if (slen > m_stringLen) { + return DrStr_InvalidIndex; + } + size_t spos = startLength; + if (spos > m_stringLen - slen) { + spos = m_stringLen - slen; + } + WCHAR * p = m_pBuffer + spos; + while (p >= m_pBuffer) { + if (wcsncmp(p, 
psz, slen) == 0) { + return (size_t)(p - m_pBuffer); + } + p--; + } + return DrStr_InvalidIndex; +} + +// returns DrStr_InvalidIndex if there is no match or the starting index is out of range +// returns the string length if the null terminator is matched +// not multibyte aware. +size_t DrWStr::IndexOfStringNoCase(const WCHAR *psz, size_t startIndex) const +{ + if(psz == NULL || m_pBuffer == NULL || startIndex > m_stringLen) { + return DrStr_InvalidIndex; + } + size_t slen = wcslen(psz); + if (slen > m_stringLen - startIndex) { + return DrStr_InvalidIndex; + } + size_t epos = m_stringLen - slen; + for (size_t i = startIndex; i <= epos; i++) { + if (_wcsnicmp(m_pBuffer + i, psz, slen) == 0) { + return i; + } + } + return DrStr_InvalidIndex; +} + +// returns DrStr_InvalidIndex if there is no match or the startLength is out of range +// The startLength should be one greater than the first posible matching index (e.g., the length of the string to search) +size_t DrWStr::ReverseIndexOfStringNoCase(const WCHAR *psz, size_t startLength) const +{ + if(psz == NULL || startLength == 0 || startLength > m_stringLen+1 || m_pBuffer == NULL) { + return DrStr_InvalidIndex; + } + size_t slen = wcslen(psz); + if (slen > m_stringLen) { + return DrStr_InvalidIndex; + } + size_t spos = startLength; + if (spos > m_stringLen - slen) { + spos = m_stringLen - slen; + } + WCHAR * p = m_pBuffer + spos; + while (p >= m_pBuffer) { + if (_wcsnicmp(p, psz, slen) == 0) { + return (size_t)(p - m_pBuffer); + } + p--; + } + return DrStr_InvalidIndex; +} + +// Uses _strlwr (not multibyte aware) +DrWStr& DrWStr::ToLowerCase() +{ + if (m_pBuffer != NULL) { + LogAssert(m_pBuffer[m_stringLen] == L'\0'); + _wcslwr(m_pBuffer); + } + return *this; +} + +// Uses _strupr (not multibyte aware) +DrWStr& DrWStr::ToUpperCase() +{ + if (m_pBuffer != NULL) { + LogAssert(m_pBuffer[m_stringLen] == L'\0'); + _wcsupr(m_pBuffer); + } + return *this; +} + +DrWStr& DrWStr::DeleteRange(size_t startIndex, size_t numChars) 
+{ + size_t oldend = startIndex + numChars; + LogAssert(oldend >= startIndex); // overflow check + LogAssert(m_pBuffer != NULL && oldend <= m_stringLen); + if (numChars != 0) { + if (oldend < m_stringLen) { + memmove(m_pBuffer + startIndex, m_pBuffer + oldend, (m_stringLen - oldend) * sizeof(WCHAR)); + } + m_stringLen -= numChars; + m_pBuffer[m_stringLen] = L'\0'; + } + return *this; +} + +// Asserts that the startIndex is valid. If startIndex is 0 and the string is NULL, +// it is converted to an empty string before inserting. +DrWStr& DrWStr::Insert(size_t startIndex, const WCHAR *psz, size_t len) +{ + LogAssert(startIndex <= m_stringLen); + size_t newlen = m_stringLen + len; + LogAssert(newlen >= m_stringLen); // overflow check + GrowTo(newlen); + if (len != 0) { + size_t newstart = startIndex + len; + memmove(m_pBuffer + newstart, m_pBuffer + startIndex, (m_stringLen -startIndex) * sizeof(WCHAR)); + memcpy(m_pBuffer + startIndex, psz, len * sizeof(WCHAR)); + m_stringLen = newlen; + m_pBuffer[m_stringLen] = L'\0'; + } + return *this; +} + +// Asserts that the startIndex is valid. If startIndex is 0 and the string is NULL, +// it is converted to an empty string before inserting. 
+DrWStr& DrWStr::ReplaceRange(size_t startIndex, size_t oldLen, const WCHAR *psz, size_t newLen) +{ + LogAssert(startIndex <= m_stringLen); + size_t oldend = startIndex + oldLen; + LogAssert(oldend >= startIndex); // overflow check + LogAssert(oldend <= m_stringLen); + if (newLen > oldLen) { + size_t nInsert = newLen - oldLen; + size_t ipos = startIndex + oldLen; + GrowTo(m_stringLen + nInsert); + if (ipos < m_stringLen) { + memmove(m_pBuffer + ipos + nInsert, m_pBuffer + ipos, (m_stringLen -ipos) * sizeof(WCHAR)); + } + m_stringLen += nInsert; + m_pBuffer[m_stringLen] = L'\0'; + } else if (newLen < oldLen) { + EnsureNotNull(); + size_t nDelete = oldLen - newLen; + size_t ipos = startIndex + newLen; + size_t epos = ipos + nDelete; + if (epos < m_stringLen) { + memmove(m_pBuffer + ipos , m_pBuffer + epos, (m_stringLen -epos) * sizeof(WCHAR)); + } + m_stringLen -= nDelete; + m_pBuffer[m_stringLen] = L'\0'; + } + + if (newLen != 0) { + memcpy(m_pBuffer+startIndex, psz, newLen * sizeof(WCHAR)); + } + + return *this; +} + +__inline bool ISWSPACE(WCHAR c) +{ + return iswspace(c) != 0; +} + +// Removes whitespace from the start and end of the string +DrWStr& DrWStr::Trim() +{ + if (m_pBuffer != NULL) { + size_t newlen = m_stringLen; + while (newlen != 0 && ISWSPACE(m_pBuffer[newlen-1])) { + newlen--; + } + size_t nleading = 0; + while(nleading < newlen && ISWSPACE(m_pBuffer[nleading])) { + nleading++; + } + if (nleading != 0 && nleading != newlen) { + memmove(m_pBuffer, m_pBuffer + nleading, (newlen - nleading) * sizeof(WCHAR)); + } + m_stringLen = newlen - nleading; + m_pBuffer[m_stringLen] = L'\0'; + } + return *this; +} + +DrError DrStringToSignedOrUnsignedInt64(const WCHAR *psz, UInt64 *pResult, bool fSigned) +{ + UInt64 v = 0; + UInt64 vnew; + bool neg = false; + bool gotDig = false; + int base = 10; + + while (ISWSPACE(*psz)) { + psz++; + } + + if (*psz == L'+') { + psz++; + } else if (fSigned && *psz == L'-') { + neg = true; + psz++; + } + + if (*psz == L'0' && 
(*(psz+1) == L'x' ||*(psz+1) == L'X') ) { + psz += 2; + base = 16; + } + + if (base == 16 && neg == false) { + // we allow hex constants to set the sign bit. + fSigned = false; + } + + while (*psz != L'\0' && !ISWSPACE(*psz)) { + int dig = -1; + if (*psz >= L'0' && *psz <= L'9') { + dig = *psz - L'0'; + } else if (*psz >= L'a' && *psz <= L'f') { + dig = *psz - L'a' + 10; + } else if (*psz >= L'A' && *psz <= L'F') { + dig = *psz - L'A' + 10; + } + if (dig < 0 || dig >= base) { + return DrError_InvalidParameter; + } + vnew = v * base + dig; + if ((fSigned && (Int64)vnew < 0) || ((vnew - dig)/base != v)) { + // overflow + return DrError_InvalidParameter; + } + v = vnew; + gotDig = true; + psz++; + } + + if (!gotDig) { + return DrError_InvalidParameter; + } + + while (ISWSPACE(*psz)) { + psz++; + } + + if (*psz != L'\0') { + return DrError_InvalidParameter; + } + + if (neg) { + v= (UInt64)(-(Int64)v); + } + + *pResult = v; + return DrError_OK; +} + +DrError DrStringToFloat(const WCHAR *psz, float *pResult) +{ + double v; + DrError err = DrStringToDouble(psz, &v); + if (err == DrError_OK) { + *pResult = (float)v; + } + return err; +} + +DrError DrStringToUInt64(const WCHAR *psz, UInt64 *pResult) +{ + return DrStringToSignedOrUnsignedInt64(psz, pResult, false); +} + +DrError DrStringToInt64(const WCHAR *psz, Int64 *pResult) +{ + return DrStringToSignedOrUnsignedInt64(psz, (UInt64 *)(void *)pResult, true); +} + +DrError DrStringToUInt16(const WCHAR *psz, UInt16 *pResult) +{ + UInt64 v; + DrError err = DrStringToSignedOrUnsignedInt64(psz, &v, false); + if (err != DrError_OK) { + return err; + } + UInt16 v16 = (UInt16)v; + if (v != (UInt64)v16) { + return DrError_InvalidParameter; + } + *pResult = v16; + return DrError_OK; +} + +DrError DrStringToPortNumber(const char *psz, DrPortNumber *pResult) +{ + if (psz == NULL) { + return DrError_InvalidParameter; + } + + DrError err = DrError_OK; + + if (_wcsicmp(psz, L"any") == 0 || _wcsicmp(psz, L"*") == 0) { + *pResult = 
DrAnyPortNumber; + } else if (_wcsicmp(psz, L"invalid") == 0) { + *pResult = DrInvalidPortNumber; + } else { + err = DrStringToUInt16(psz, pResult); + } + return err; +} + +DrError DrStringToUInt32(const WCHAR *psz, UInt32 *pResult) +{ + UInt64 v; + DrError err = DrStringToSignedOrUnsignedInt64(psz, &v, false); + if (err != DrError_OK) { + return err; + } + UInt32 v32 = (UInt32)v; + if (v != (UInt64)v32) { + return DrError_InvalidParameter; + } + *pResult = v32; + return DrError_OK; +} + +DrError DrStringToInt32(const WCHAR *psz, Int32 *pResult) +{ + Int64 v; + DrError err = DrStringToSignedOrUnsignedInt64(psz, (UInt64 *)(void *)&v, true); + if (err != DrError_OK) { + return err; + } + Int32 v32 = (Int32)v; + if (v != (Int64)v32) { + return DrError_InvalidParameter; + } + *pResult = v32; + return DrError_OK; +} + +DrError DrStringToUInt(const WCHAR *psz, unsigned int *pResult) +{ + UInt64 v; + DrError err = DrStringToSignedOrUnsignedInt64(psz, &v, false); + if (err != DrError_OK) { + return err; + } + unsigned int v32 = (unsigned int)v; + if (v != (UInt64)v32) { + return DrError_InvalidParameter; + } + *pResult = v32; + return DrError_OK; +} + +DrError DrStringToInt(const WCHAR *psz, int *pResult) +{ + Int64 v; + DrError err = DrStringToSignedOrUnsignedInt64(psz, (UInt64 *)(void *)&v, true); + if (err != DrError_OK) { + return err; + } + int v32 = (int)v; + if (v != (Int64)v32) { + return DrError_InvalidParameter; + } + *pResult = v32; + return DrError_OK; +} + +DrError DrStringToDouble(const WCHAR *psz, double *pResult) +{ + double v = 0; + int exponent = 0; + bool neg = false; + bool gotDig = false; + bool gotPoint = false; + + while (ISWSPACE(*psz)) { + psz++; + } + + if (*psz == L'+') { + psz++; + } else if (*psz == L'-') { + neg = true; + psz++; + } + + while (*psz != L'\0' && !ISWSPACE(*psz) && *psz != L'e' && *psz != L'E') { + int dig = -1; + if (!gotPoint && *psz == L'.') { + gotPoint = true; + psz++; + continue; + } else if (*psz >= L'0' && *psz <= L'9') { 
+ dig = *psz - L'0'; + } + if (dig < 0 || dig >= 10) { + return DrError_InvalidParameter; + } + v = v * 10.0 + dig; + if (gotPoint) { + exponent--; + } + gotDig = true; + psz++; + } + + if (!gotDig) { + return DrError_InvalidParameter; + } + + if (*psz == L'e' || *psz == L'E') { + psz++; + Int32 exp2; + DrError err = DrStringToInt32(psz, &exp2); + if (err != DrError_OK) { + return err; + } + exponent += exp2; + } else { + while (ISWSPACE(*psz)) { + psz++; + } + + if (*psz != L'\0') { + return DrError_InvalidParameter; + } + } + + if (exponent != 0) { + v = v * pow((double) 10, exponent); + } + + if (neg) { + v= -v; + } + + *pResult = v; + return DrError_OK; +} + +DrError DrStringToBool(const WCHAR *psz, bool *pResult) +{ + bool ret = false; + WCHAR tmp[16]; + size_t length = 0; + + while (ISWSPACE(*psz)) { + psz++; + } + + while (length < 15 && *psz != L'\0' && !ISWSPACE(*psz)) { + tmp[length++] = *(psz++); + } + + tmp[length] = L'\0'; + + while (ISWSPACE(*psz)) { + psz++; + } + + if (*psz != L'\0') { + return DrError_InvalidParameter; + } + + _wcslwr(tmp); + + length++; + if (wcsncmp(tmp, L"true", length) == 0 || + wcsncmp(tmp, L"yes", length) == 0 || + wcsncmp(tmp, L"on", length) == 0 || + wcsncmp(tmp, L"1", length) == 0) { + ret = true; + } else if ( + wcsncmp(tmp, L"false", length) == 0 || + wcsncmp(tmp, L"no", length) == 0 || + wcsncmp(tmp, L"off", length) == 0 || + wcsncmp(tmp, L"0", length) == 0) { + ret = false; + } else { + return DrError_InvalidParameter; + } + + *pResult = ret; + return DrError_OK; +} + + + + +void DrStrList::GrowTo(UInt32 nAlloced) +{ + if (nAlloced > m_numAllocedStrings) { + if (nAlloced < 32) { + nAlloced = 32; + } + if (nAlloced < 2 * m_numAllocedStrings) { + nAlloced = 2 * m_numAllocedStrings; + } + DrStr **ppNew = new DrStr*[nAlloced]; + LogAssert(ppNew != NULL); + if (m_numStrings > 0) { + memcpy(ppNew, m_prgpStrings, sizeof(DrStr *) * m_numStrings); + } + memset(ppNew + m_numStrings, 0, sizeof(DrStr *) * (nAlloced - 
m_numStrings)); + if (m_prgpStrings != NULL) { + delete[] m_prgpStrings; + } + m_prgpStrings = ppNew; + m_numAllocedStrings = nAlloced; + } +} +#endif // if 0 diff --git a/DryadVertex/VertexHost/system/classlib/src/DrThread.cpp b/DryadVertex/VertexHost/system/classlib/src/DrThread.cpp new file mode 100644 index 0000000..4b90c7c --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/src/DrThread.cpp @@ -0,0 +1,769 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "DrCommon.h" + +#pragma unmanaged + +volatile LONG DrJobHash::s_nextHash = 0L; + +DrJobHash DrNoJobHash; + +//Period in msec we wait for all worker threads to close +static const DWORD c_CloseThreadsTimeout=30000; + +DrThreadPool *g_pDrClientThreads = NULL; + +DrTlsPtr *t_ppThread = NULL; + + +//Slop period in msec we allow on timers so we don't wake the timer thread up to much +//e.g. 
If a timer expires at t=100 and we've another timer at t=120 we'll expire +//both of them rather than sleeping for the 20 msec +static const DWORD c_TimerExpirySlopPeriod=500; + +DrThread *DrGenerateCurrentThread() +{ + DrRef pThread; + pThread.Attach(new DrGenericThread()); + LogAssert(pThread != NULL); + pThread->AttachToCurrentThread(); + DrThread *pReturn = pThread; + // DecRef is OK here because current thread now holds a reference + return pReturn; +} + +DrThread::DrThread(const char *pszThreadClass, const char *pszShortClass, DrThreadPool *pPool, int iBucket) +{ + DrLogD( "DrThread contructed. %s (%s), pThread=%p, pool=%p, bucket=%d", pszThreadClass, pszShortClass, this, pPool, iBucket); + m_strClass = pszThreadClass; + m_strShortClass = pszShortClass; + m_pPool = pPool; + m_iBucket = iBucket; + m_pCurrentJob = NULL; + m_hThread = NULL; + m_dwThreadId = 0; + m_hIocp = NULL; + m_strTag = "INVL"; + m_strDescription = "Unattached Thread"; +} + +DrThread::~DrThread() +{ + DrLogD( "DrThread destructed. %s, pThread=%p", m_strClass.GetString(), this); + if (m_hThread != NULL) { + CloseHandle(m_hThread); + } +} + +void DrThread::AttachToCurrentThread() +{ + LogAssert(t_pThread.IsNull()); + t_pThread = this; + m_dwThreadId = GetCurrentThreadId(); + IncRef(); + UpdateTagAndDescription(); + DrLogD( "DrThread::AttachToCurrentThread. %s, pThread=%p", m_strClass.GetString(), this); +} + +void DrThread::DetachFromCurrentThread() +{ + LogAssert(t_pThread == this); + t_pThread = NULL; + DrLogD( "DrThread::AttachToCurrentThread. 
%s, pThread=%p", m_strClass.GetString(), this); + m_dwThreadId = 0; + LogAssert(m_hThread == NULL); + m_strTag = "INVL"; + m_strDescription = "Unattached Thread"; + DecRef(); + // "this" may no longer exist +} + +void DrThread::UpdateTagAndDescription() +{ + m_strDescription.SetF("%s %8u", m_strClass.GetString(), m_dwThreadId); + m_strTag.SetF("%s %8u", m_strShortClass.GetString(), m_dwThreadId); +} + +DrError DrThread::Start( + LPSECURITY_ATTRIBUTES lpThreadAttributes, + SIZE_T dwStackSize, + DWORD dwCreationFlags +) +{ + LogAssert(m_hThread == NULL); + DWORD dwThreadId; + + HANDLE h = CreateThread( + lpThreadAttributes, + dwStackSize, + ThreadEntryStatic, + this, + dwCreationFlags | CREATE_SUSPENDED, + &dwThreadId); + if (h == NULL) { + return DrGetLastError(); + } + m_hThread = h; + m_dwThreadId = dwThreadId; + + UpdateTagAndDescription(); + IncRef(); // The win32 thread owns one reference + if ((dwCreationFlags & CREATE_SUSPENDED) == 0) { + DWORD ret = ResumeThread(h); + LogAssert(ret != (DWORD)-1); + } + return DrError_OK; +} + +// param points to the DrThread, already IncRef'd +DWORD WINAPI DrThread::ThreadEntryStatic(void * param) +{ + DrThread *pThread = (DrThread *)param; + + // We save a pointer to this thread object in thread-local storage. That way, + // You can always find your current thread object with DrGetCurrentThread(). + // + LogAssert(t_pThread.IsNull()); + t_pThread = pThread; // save pointer to ourselves in thread local storage + + DrLogI( "Dryad thread starting. %s, pThread=%p", pThread->m_strDescription.GetString(), pThread); + + DWORD ret = pThread->ThreadEntry(); + + DrLogI( "Dryad thread exiting. %s, pThread=%p, exitcode=%08x", pThread->m_strDescription.GetString(), pThread, ret); + + LogAssert(t_pThread == pThread); + + // no longer under our control... 
+ t_pThread = NULL; + + pThread->DecRef(); // The DrThread will be freed when noone is interested in it anymoe + + return ret; +} + +DrTimerThread::DrTimerThread(DrThreadPool *pPool) : DrThread("DrTimerThread", "TIMR", pPool) +{ + m_timerThreadWakesAt=0; + m_dwTimerThreadSleepPeriod=INFINITE; + m_timerEventHandle=CreateEvent(NULL, TRUE, FALSE, NULL); + LogAssert(m_timerEventHandle != NULL); +} + +DrTimerThread::~DrTimerThread() +{ + if (m_timerEventHandle != NULL) { + CloseHandle(m_timerEventHandle); + } +} + + +DrThreadPool::DrThreadPool() +{ + memset(&m_ov, 0, sizeof(m_ov)); + m_completionPortHandle=NULL; + m_createdThreadCount=0; + m_rgWorkerThreads = NULL; + m_threadsShouldQuit=false; + m_hashedThreadCount = 0; + m_rgHashedThreads = NULL; + m_nextRandomHash = 0L; + m_numHashBuckets = 0; +} + + +DrThreadPool::~DrThreadPool() +{ + LogAssert(m_completionPortHandle==NULL); + LogAssert(m_createdThreadCount == 0); + LogAssert(m_hashedThreadCount == 0); + + if (m_rgWorkerThreads != NULL) + { + delete[] m_rgWorkerThreads; + } + + if (m_rgHashedThreads != NULL) + { + delete[] m_rgHashedThreads; + } +} + + +DrError DrThreadPool::Initialize(int initialThreadCount, int numHashedThreads) +{ + DrError err = DrError_OK; + + LogAssert(m_rgWorkerThreads==NULL); + LogAssert(m_rgHashedThreads==NULL); + LogAssert(m_createdThreadCount==0); + LogAssert(m_hashedThreadCount==0); + LogAssert(m_completionPortHandle==NULL); + LogAssert(m_threadsShouldQuit==false); + +/* JC LogAsserts below will catch + if (initialThreadCount < 0) { + initialThreadCount = (int)g_pDryadConfig->GetNumProcessors(); + } + if (numHashedThreads < 0) { + numHashedThreads= (int)g_pDryadConfig->GetNumProcessors(); + } +*/ + + m_numHashBuckets = numHashedThreads; + + LogAssert(initialThreadCount >=0 && initialThreadCount < 1000000); + LogAssert(numHashedThreads >=0 && numHashedThreads < 1000000); + LogAssert(initialThreadCount + numHashedThreads > 0); + + if (initialThreadCount != 0) { + m_completionPortHandle = 
CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, NULL, 0); + if (m_completionPortHandle==NULL) + { + err = DrGetLastError(); + goto ExitError; + } + } + + if (numHashedThreads != 0) { + m_rgHashedThreads = new DrHashThread *[(Size_t)numHashedThreads]; + LogAssert(m_rgHashedThreads != NULL); + memset(m_rgHashedThreads, 0, sizeof(DrHashThread *) * numHashedThreads); + for (int i = 0; i < numHashedThreads; i++) { + m_rgHashedThreads[i] = new DrHashThread(this, i); + LogAssert(m_rgHashedThreads[i] != NULL); + m_rgHashedThreads[i] ->m_hIocp = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, NULL, 0); + if (m_rgHashedThreads[i]->m_hIocp == NULL) + { + err = DrGetLastError(); + goto ExitError; + } + } + } + + //Spin up the timer thread + m_pTimerThread.Attach(new DrTimerThread(this)); + LogAssert(m_pTimerThread != NULL); + + // TODO: We should be able to write the threadpool so this timer isn't needed. + //Clever version would have a single threadpool thread nominating itself as a timing thread + //if no other threads are and sleeping for the appropriate period before firing timers + //Simpler version would have a single threadpool thread permanently being the timer thread + //and waking up every second or so to fire off timers. 
+ //Come back to this when we've got better usage patterns for timers + err = m_pTimerThread->Start(); + if (err != DrError_OK) + { + goto ExitError; + } + + //Spin up all the worker threads we were asked to + + if (initialThreadCount != 0) { + m_rgWorkerThreads=new DrPoolThread *[(size_t) initialThreadCount]; + LogAssert(m_rgWorkerThreads != NULL); + memset(m_rgWorkerThreads, 0, sizeof(DrPoolThread *) * initialThreadCount); + + for ( ; m_createdThreadCount < initialThreadCount; m_createdThreadCount++) + { + m_rgWorkerThreads[m_createdThreadCount] = new DrPoolThread(this); + LogAssert(m_rgWorkerThreads[m_createdThreadCount] != NULL); + m_rgWorkerThreads[m_createdThreadCount]->m_hIocp = m_completionPortHandle; + err = m_rgWorkerThreads[m_createdThreadCount]->Start(); + if (err != DrError_OK) { + break; + } + } + + //If we couldn't spin up all required threads then treat that as an error + //and tear everything down + if (err != DrError_OK) { + goto ExitError; + } + } + + + if (numHashedThreads!= 0) { + for ( ; m_hashedThreadCount < numHashedThreads; m_hashedThreadCount++) { + err = m_rgHashedThreads[m_hashedThreadCount]->Start(); + if (err != DrError_OK) { + break; + } + } + + //If we couldn't spin up all required threads then treat that as an error + //and tear everything down + if (err != DrError_OK) { + goto ExitError; + } + } + + //Looks like everything is good! 
+ return DrError_OK; + +ExitError: + + //Flag all threads read to test for quit + m_threadsShouldQuit=true; + + //If we managed to spin up any worker threads then tell them to quit + CloseWorkerThreads(); + + if (m_pTimerThread != NULL) { + m_pTimerThread->Signal(); + m_pTimerThread->WaitForTermination(); + m_pTimerThread = NULL; + } + + //Tear down the completition port and any allocated memory + if (m_completionPortHandle != NULL) + { + CloseHandle(m_completionPortHandle); + m_completionPortHandle=NULL; + } + + m_threadsShouldQuit=false; + return err; +} + + +DrError DrThreadPool::Deinitialize() +{ + //Should only be called if initialize succeeded, so assert that + LogAssert(m_createdThreadCount+ m_hashedThreadCount != 0); + LogAssert(m_pTimerThread != NULL); + + m_threadsShouldQuit=true; + + m_pTimerThread->Signal(); + m_pTimerThread->WaitForTermination(); + m_pTimerThread = NULL; + + DrError closeResult=CloseWorkerThreads(); + + if (m_completionPortHandle != NULL) { + CloseHandle(m_completionPortHandle); + m_completionPortHandle=NULL; + } + + m_threadsShouldQuit=false; + + return closeResult; +} + + +DrError DrThreadPool::CloseWorkerThreads() +{ + LogAssert(m_threadsShouldQuit==true); + + DrError closeResult=DrError_OK; + + // To make threads quit we send them a completition with 'this' as the overlapped + // pointer. 
That causes them to inspect the thread pool state which will + // allow them to spot the fact we want them to quit + + if (m_createdThreadCount != 0) { + //Wake all started threads up once each + for (int i=0; iWaitForTermination(c_CloseThreadsTimeout)) { + closeResult=DrError_Fail; + } + m_rgWorkerThreads[i]->DecRef(); + m_rgWorkerThreads[i]=NULL; + } + + m_createdThreadCount = 0; + } + + if (m_rgWorkerThreads != NULL) { + delete[] m_rgWorkerThreads; + m_rgWorkerThreads = NULL; + } + + if (m_hashedThreadCount != 0) { + //Wake all started threads up once each + for (int i=0; iGetThreadIocp(), 0, 0,&m_ov); + } + + for (int i=0; iWaitForTermination(c_CloseThreadsTimeout)) { + closeResult=DrError_Fail; + } + } + + m_hashedThreadCount = 0; + } + + if (m_rgHashedThreads != NULL) { + delete[] m_rgHashedThreads; + m_rgHashedThreads = NULL; + } + + return closeResult; +} + + +bool DrThreadPool::EnqueueJobWithStatus(DrJob *job, DWORD numBytes, ULONG_PTR key, DrError err) +{ + + LogAssert(job); + LogAssert(m_completionPortHandle); + + job->m_postedError = err; + HANDLE h = GetCompletionHandleForBucket(GetBucketOfJob(job)); + if (h == NULL) { + SetLastError(ERROR_NOT_SUPPORTED); + return false; + } + job->SetDefaultThreadPool(this); + return (PostQueuedCompletionStatus( + h, + numBytes, + key, + job->GetOverlapped()) != 0); +} + + +bool DrThreadPool::AssociateHandleWithPool(HANDLE fileHandle, ULONG_PTR key) +{ + LogAssert(fileHandle!=NULL && fileHandle!=INVALID_HANDLE_VALUE); + HANDLE h = GetCompletionHandleForBucket(-1); + if (h == NULL) { + SetLastError(ERROR_NOT_SUPPORTED); + return false; + } + + return (CreateIoCompletionPort(fileHandle, h, key, 0) == h); +} + +bool DrThreadPool::AssociateHandleWithPoolAndHash(HANDLE fileHandle, ULONG_PTR key, const DrJobHash& jobHash) +{ + LogAssert(fileHandle!=NULL && fileHandle!=INVALID_HANDLE_VALUE); + HANDLE h = GetCompletionHandleForBucket(GetBucketOfHash(jobHash)); + if (h == NULL) { + SetLastError(ERROR_NOT_SUPPORTED); + return 
false; + } + + return (CreateIoCompletionPort(fileHandle, h, key, 0) == h); +} + +HANDLE DrThreadPool::GetCompletionHandleForBucket(int iBucket) +{ + HANDLE h = NULL; + if (iBucket < 0) { + if (m_completionPortHandle != NULL) { + h = m_completionPortHandle; + } else { + // This pool has no non-hashed threads. Pick a random hashed thread to run the job in. + DrJobHash jobHash; + jobHash.SetSequentialHash(); + iBucket = GetBucketOfHash(jobHash); + h = m_rgHashedThreads[iBucket]->GetThreadIocp(); + } + } else if (iBucket < m_numHashBuckets) { + h = m_rgHashedThreads[iBucket]->GetThreadIocp(); + } + return h; +} + + //Schedule a timer to run a specific number of msec from now +void DrThreadPool::ScheduleTimerMs(DrJob * timer, DWORD delay) + { + LogAssert(m_pTimerThread != NULL); + m_pTimerThread->ScheduleTimerMs(timer, delay); + } + + //Schedule a timer to run a specific number of msec from now +void DrTimerThread::ScheduleTimerMs(DrJob * timer, DWORD delay) +{ + bool wakeTimerThread=false; + DWORD currentTime=GetTickCount(); + + //Make sure we don't clash with the timer thread + Lock(); + + LogAssert(timer->m_isActiveTimer==false); + + //Set state of timer and insert it into heap + timer->m_isActiveTimer=true; + timer->m_expiryTime=currentTime+delay; + timer->SetDefaultThreadPool(m_pPool); + m_timerHeap.InsertHeapEntry(timer); + + //Now if we just inserted that at the root of the heap (i.e. 
New first timer) AND + //the timer thread is currently sleeping to infinity OR + //the time before it wakes up is too long then give it a kick + if ((DrJob * ) m_timerHeap.PeekHeapRoot()==timer && + (m_dwTimerThreadSleepPeriod==INFINITE || + ((int ) (m_timerThreadWakesAt-currentTime))>c_TimerExpirySlopPeriod)) + { + wakeTimerThread=true; + } + + Unlock(); + + if (wakeTimerThread) + { + SetEvent(m_timerEventHandle); + } +} + + +bool DrTimerThread::CancelTimer(DrJob * timer) +{ + bool cancelledOK=false; + + //Make sure we don't clash with the timer thread + Lock(); + + //If the timer isn't active then its already being processed + //Can't do much about that case + //Otherwise we should pull it from heap + + if (timer->m_isActiveTimer) + { + timer->m_isActiveTimer=false; + m_timerHeap.RemoveHeapEntry(timer->m_heapIndex); + cancelledOK=true; + } + + Unlock(); + + return cancelledOK; + +} + +bool DrThreadPool::CancelTimer(DrJob * timer) +{ + return m_pTimerThread->CancelTimer(timer); +} + + +DWORD DrPoolWorkerThread::PoolWorkerEntry() +{ + int iBucket = GetJobHashBucket(); + HANDLE iocp = GetThreadIocp(); + LogAssert(iocp != NULL); + + while (true) + { + m_pCurrentJob = NULL; + OVERLAPPED *overlapped = NULL; + ULONG_PTR completionKey = NULL; + DWORD bytesTransferred = 0; + + //TODO: We need to integrate timer functionality here. 
+ //Threads should check if there are any timers scheduled and + //if no other thread is going to service them set their + //timeout so they can do it + + BOOL success = GetQueuedCompletionStatus( + iocp, + &bytesTransferred, + &completionKey, + &overlapped, + INFINITE + ); + + //If the overlapped structure is this object then thats our internal + //signal rather than an externally submitted job + if (overlapped == m_pPool->GetCommonOv()) + { + //We should check the thread pool state + if (m_pPool->ShouldQuit()) + { + return 0; + } + } + else + { + if (success) + { + //Looks like we got a valid job completed + LogAssert(overlapped != NULL); + + DrJob *pJob = DrJob::MapOverlappedToJob(overlapped); + int iJobBucket = m_pPool->GetBucketOfJob(pJob); + if (iJobBucket >= 0 && iJobBucket != iBucket) { + // This job completed in the wrong thread (probably because the job specified + // a different hashcode than the hashcode that the I/O handle was bound to). + // Simply resubmit the job to complete in the proper thread. + m_pPool->EnqueueJobWithStatus(pJob, bytesTransferred, completionKey, pJob->m_postedError); + } else { + m_pCurrentJob = pJob; + m_currentJobHash = pJob->GetJobHash(); + if (pJob->m_postedError == DrError_OK) { + pJob->JobReady(bytesTransferred, completionKey); + } else { + pJob->JobFailed(bytesTransferred, completionKey, pJob->m_postedError); + } + m_pCurrentJob = NULL; + } + } + else + { + //Couldn't get a good completition + DrError lastError=DrGetLastError(); + LogAssert(lastError != DrError_OK); + + //If we've got an associated job then tell it + if (overlapped != NULL) + { + DrJob *pJob = DrJob::MapOverlappedToJob(overlapped); + int iJobBucket = m_pPool->GetBucketOfJob(pJob); + if (iJobBucket >= 0 && iJobBucket != iBucket) { + // This job completed in the wrong thread (probably because the job specified + // a different hashcode than the hashcode that the I/O handle was bound to). + // Simply resubmit the job to complete in the proper thread. 
+ m_pPool->EnqueueJobWithStatus(pJob, bytesTransferred, completionKey, lastError); + } else { + m_pCurrentJob = pJob; + m_currentJobHash = pJob->GetJobHash(); + pJob->JobFailed(bytesTransferred, completionKey, lastError); + m_pCurrentJob = NULL; + } + } + else + { + //TODO: What can we do sensibly in this scenario? + //Are there specific cases we should be able to handle? + LogAssert(overlapped != overlapped); + } + } + + } + } +} + +DWORD DrPoolThread::ThreadEntry() +{ + return PoolWorkerEntry(); +} + +void DrHashThread::UpdateTagAndDescription() +{ + m_strDescription.SetF("%s %8u(%d)", m_strClass.GetString(), m_dwThreadId, m_iBucket); + m_strTag.SetF("%s %8u(%d)", m_strShortClass.GetString(), m_dwThreadId, m_iBucket); +} + +DWORD DrHashThread::ThreadEntry() +{ + return PoolWorkerEntry(); +} + +DWORD DrTimerThread::ThreadEntry() +{ + DWORD currentTime; + int timeToExpire; + DrJob * timer; + + while (true) + { + //Analyse the current state of the timer heap + Lock(); + + currentTime=GetTickCount(); + while (true) + { + timer = (DrJob*) m_timerHeap.PeekHeapRoot(); + + //If there are no more timers on the heap then our timeout on our event + //is infinite and we're done checking the heap + if (timer==NULL) + { + m_dwTimerThreadSleepPeriod=INFINITE; + break; + } + + //Looks like we've got a timer. Work out when it'll expire relative to now + LogAssert(timer->m_isActiveTimer); + timeToExpire=(int ) (timer->m_expiryTime-currentTime); + + //If it hasn't expired yet then that should be our timeout and we're done + //checking the heap + if (timeToExpire>0) + { + m_dwTimerThreadSleepPeriod=(DWORD)timeToExpire; + break; + } + + //Looks like we've got an expired timer. Mark is as no longer active, pop it from heap + //and pass it to a worker thread for processing + timer->m_isActiveTimer=false; + m_timerHeap.DequeueHeapRoot(); + BOOL fSuccess = m_pPool->EnqueueJob(timer, currentTime, NULL); + LogAssert(fSuccess); + } + + //Little wrinkle here. 
We don't want to wake up too often and we don't care + //to be all that accurate for expiring timers. Therefore, if timeout is + //too small we'll push it out in the hope of expiring more timers at once + if (m_dwTimerThreadSleepPeriodShouldQuit()) + { + break; + } + + } + + //Treat any timers still scheduled as if they were cancelled + Lock(); + + while (true) + { + timer = (DrJob*) m_timerHeap.DequeueHeapRoot(); + if (timer==NULL) + { + break; + } + LogAssert(timer->m_isActiveTimer==true); + timer->m_isActiveTimer=false; + } + + Unlock(); + + return 0; +} diff --git a/DryadVertex/VertexHost/system/classlib/src/fingerprint.cpp b/DryadVertex/VertexHost/system/classlib/src/fingerprint.cpp new file mode 100644 index 0000000..0b56f61 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/src/fingerprint.cpp @@ -0,0 +1,82 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include "fingerprint.h" + + +#pragma unmanaged + +UInt64 FingerPrint64::FingerPrint64Init::count; +FingerPrint64* FingerPrint64::instance; + +FingerPrint64::FingerPrint64Init::FingerPrint64Init() +{ + if (0 == count) + { + FingerPrint64::Init(); + } + ++count; +} + +FingerPrint64::FingerPrint64Init::~FingerPrint64Init() +{ +//JC assert(count > 0); + --count; + if (0 == count) + { + FingerPrint64::Dispose(); + } +} + +void FingerPrint64::Init() +{ + FingerPrint64::instance = new FingerPrint64(); +} + +void FingerPrint64::Dispose() +{ + delete FingerPrint64::instance; + FingerPrint64::instance = 0; +} + +FingerPrint64* FingerPrint64::GetInstance() +{ + return FingerPrint64::instance; +} + +UInt64 FingerPrint64::GetFingerPrint(const void *data, const size_t length) +{ + return ms_fprint_of(this->fp, (void*) data, (size_t) length); +} + +FingerPrint64::FingerPrint64(UInt64 poly) +{ + this->fp = ::ms_fprint_new(poly); +} + +FingerPrint64::FingerPrint64(void) +{ + this->fp = ::ms_fprint_new(); +} + +FingerPrint64::~FingerPrint64(void) +{ + ::ms_fprint_destroy(this->fp); +} diff --git a/DryadVertex/VertexHost/system/classlib/src/ms_fprint.cpp b/DryadVertex/VertexHost/system/classlib/src/ms_fprint.cpp new file mode 100644 index 0000000..6d013c9 --- /dev/null +++ b/DryadVertex/VertexHost/system/classlib/src/ms_fprint.cpp @@ -0,0 +1,162 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +/* (c) Microsoft Corporation. All rights reserved. */ + +#include +#include "ms_fprint.h" + +#pragma unmanaged + +#if defined(pdp11) || defined(vax) || defined(__alpha) || defined(i386) || defined(__i386) || defined(__i386__) || defined(_M_IX86) || defined(MIPSEL) || defined(_MSC_VER) +#define MS_ENDIAN_LITTLE 1 +#endif +#if defined(__sparc__) || defined(MIPSEB) || defined(__ppc__) +#define MS_ENDIAN_LITTLE 0 +#endif + +#if !defined(MS_ENDIAN_LITTLE) +static short _endian_little = 1; +#define MS_ENDIAN_LITTLE (*(char *)&_endian_little) +#endif + +#define MS_ENDIAN_BIG (!MS_ENDIAN_LITTLE) + +#define BYTESWAP_FP(_x) \ + ( \ + ((_x) << 56) | \ + ((_x) >> 56) | \ + (((_x) & 0x0000ff00UL) << 40) | \ + (((_x) >> 40) & 0x0000ff00UL) | \ + (((_x) & 0x00ff0000UL) << 24) | \ + (((_x) >> 24) & 0x00ff0000UL) | \ + (((_x) & 0xff000000UL) << 8) | \ + (((_x) >> 8) & 0xff000000UL) \ + ) + +static const ms_fprint_t the_poly = (((ms_fprint_t)0xa795d0f2UL) << 32) + | (ms_fprint_t)0x9b4dcdf8UL; + +struct ms_fprint_data_s { + ms_fprint_t poly[2]; /* poly[0] = 0; poly[1] = polynomial */ + ms_fprint_t empty; /* fingerprint of the empty string */ + ms_fprint_t bybyte[8][256]; /* bybyte[b][i] is i*X^(64+8*b) mod poly[1] */ + ms_fprint_t bybyte_r[8][256]; /* bybyte[b][i] is i*X^(64+8*b) mod poly[1], byte-swapped */ +}; + +static void initbybyte (ms_fprint_data_t fp, + ms_fprint_t bybyte[][256], + ms_fprint_t f) { + int b; + for (b = 0; b != 8; b++) { + int i; + bybyte[b][0] = 0; + for (i = 0x80; i != 0; i >>= 1) { + bybyte[b][i] = f; + f = fp->poly[f & 1] ^ (f >> 1); + } + for (i = 1; i != 256; i <<= 1) { + ms_fprint_t xf = bybyte[b][i]; + int k; + for (k = 1; k != i; k++) { + bybyte[b][i+k] = xf ^ bybyte[b][k]; + } + } + } +} + +static void ms_fprint_init (ms_fprint_data_t fp, ms_fprint_t poly) { + int i, j; + fp->poly[0] = 0; + fp->poly[1] = poly; /*This must be 
initialized early on */ + fp->empty = poly; + initbybyte (fp, fp->bybyte, poly); + for (i = 0; i < 8; i++) + for (j = 0; j < 256; j++) + fp->bybyte_r[i][j] = BYTESWAP_FP(fp->bybyte[i][j]); +} + +ms_fprint_data_t ms_fprint_new (ms_fprint_t poly) { + ms_fprint_data_t fp = (ms_fprint_data_t) malloc (sizeof (*fp)); + ms_fprint_init(fp, poly); + return fp; +} + +ms_fprint_data_t ms_fprint_new () { + return ms_fprint_new(the_poly); +} + +ms_fprint_t +ms_fprint_of (ms_fprint_data_t fp, + void *data, + size_t len ) { + unsigned char *p = (unsigned char *)data; + unsigned char *e = p+len; + ms_fprint_t init = fp->empty; + while (p != e && (((ptrdiff_t) p) & 7L) != 0) { + init = (init >> 8) ^ fp->bybyte[0][(init & 0xff) ^ *p++]; + } + if (MS_ENDIAN_LITTLE) { + while (p+8 <= e) { + init ^= *(ms_fprint_t *)p; + init = fp->bybyte[7][init & 0xff] ^ + fp->bybyte[6][(init >> 8) & 0xff] ^ + fp->bybyte[5][(init >> 16) & 0xff] ^ + fp->bybyte[4][(init >> 24) & 0xff] ^ + fp->bybyte[3][(init >> 32) & 0xff] ^ + fp->bybyte[2][(init >> 40) & 0xff] ^ + fp->bybyte[1][(init >> 48) & 0xff] ^ + fp->bybyte[0][init >> 56]; + p += 8; + } + } else if (p+8 <= e) { + init = BYTESWAP_FP (init); + while (p+16 <= e) { + init ^= *(ms_fprint_t *)p; + init = fp->bybyte_r[0][init & 0xff] ^ + fp->bybyte_r[1][(init >> 8) & 0xff] ^ + fp->bybyte_r[2][(init >> 16) & 0xff] ^ + fp->bybyte_r[3][(init >> 24) & 0xff] ^ + fp->bybyte_r[4][(init >> 32) & 0xff] ^ + fp->bybyte_r[5][(init >> 40) & 0xff] ^ + fp->bybyte_r[6][(init >> 48) & 0xff] ^ + fp->bybyte_r[7][init >> 56]; + p += 8; + } + init ^= *(ms_fprint_t *)p; + init = fp->bybyte[0][init & 0xff] ^ + fp->bybyte[1][(init >> 8) & 0xff] ^ + fp->bybyte[2][(init >> 16) & 0xff] ^ + fp->bybyte[3][(init >> 24) & 0xff] ^ + fp->bybyte[4][(init >> 32) & 0xff] ^ + fp->bybyte[5][(init >> 40) & 0xff] ^ + fp->bybyte[6][(init >> 48) & 0xff] ^ + fp->bybyte[7][init >> 56]; + p += 8; + } + + while (p != e) { + init = (init >> 8) ^ fp->bybyte[0][(init & 0xff) ^ *p++]; + } + return 
(init); +} +void ms_fprint_destroy (ms_fprint_data_t fp) { + free (fp); +} diff --git a/DryadVertex/VertexHost/system/common/common.vcxproj b/DryadVertex/VertexHost/system/common/common.vcxproj new file mode 100644 index 0000000..12a222d --- /dev/null +++ b/DryadVertex/VertexHost/system/common/common.vcxproj @@ -0,0 +1,186 @@ + + + + + Debug + Win32 + + + Debug + x64 + + + Release + Win32 + + + Release + x64 + + + + {57663B94-E11B-431E-BE4B-E2C61112DEC5} + common + Win32Proj + + + + StaticLibrary + + + StaticLibrary + + + StaticLibrary + + + StaticLibrary + Unicode + true + + + + + + + + + + + + + + + + + + + <_ProjectFileVersion>10.0.40219.1 + Debug\ + Debug\ + $(Platform)\$(Configuration)\ + $(Platform)\$(Configuration)\ + Release\ + Release\ + $(Platform)\$(Configuration)\ + $(Platform)\$(Configuration)\ + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + + + + Disabled + WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions) + true + EnableFastChecks + MultiThreadedDebugDLL + + + Level3 + EditAndContinue + + + + + X64 + + + Disabled + include;..\common\include;..\classlib\include;%(AdditionalIncludeDirectories) + WIN32;_DEBUG;_LIB;WIN32_LEAN_AND_MEAN;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + false + Default + MultiThreadedDebugDLL + + + Level3 + ProgramDatabase + + + + + WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + + + + + X64 + + + WIN32;NDEBUG;_LIB;WIN32_LEAN_AND_MEAN;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + include;..\common\include;..\classlib\include;%(AdditionalIncludeDirectories) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/DryadVertex/VertexHost/system/common/include/CsEnhancedTimer.h b/DryadVertex/VertexHost/system/common/include/CsEnhancedTimer.h new file mode 100644 index 0000000..a3dfcc0 --- 
/dev/null +++ b/DryadVertex/VertexHost/system/common/include/CsEnhancedTimer.h @@ -0,0 +1,241 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#ifndef __DRYADENHANCEDTIMER_H__ +#define __DRYADENHANCEDTIMER_H__ + +/* + * DrEnhancedTimer + * + * The enhanced timer class improves the standard timer class by guaranteeing + * that it can always be cancelled and that it can always be rescheduled whilst + * its still running. + * + * To use this you must supply an owner object type that supports the methods: + * void IncRef() + * void DecRef() both for reference counting + * void EnhancedTimerFired(DrEnhancedTimer * ,DWORD ) for receiving timer fired events + * + * Locking conventions + * The owner of an enhanced timer must supply a critical section lock that guards + * its timing state. + * This lock must be held by the owner when it schedules/cancels the timer + * This lock will be taken by the timer prior to calling back into the owner + * + * Reference counting + * The timer will take and hold references to the owner while it has outstanding + * threadpool operations in progress. + */ + +//Disable warning about non standard extentions +//We're using nameless structs/unions and compiler gets whiny about them +#pragma warning (push) +#pragma warning (disable:4201) + + +template +class DrEnhancedTimer : public DrTimer +{ +public: + + //Standard c'tor. 
Must call Initialize prior to using timer if you use this + inline DrEnhancedTimer(); + + //Construct and initialize timer + inline DrEnhancedTimer(Owner * pOwner, DrThreadPool * pThreadPool, DrCriticalSection * pLock); + + //Initialize timer state if standard c'tor was used + inline void Initialize(Owner * pOwner, DrThreadPool * pThreadPool, DrCriticalSection * pLock); + + //Schedule the timer to fire in dwPeriod msec from now. + //This will cancel and reschedule the timer if its already running + //Lock supplied to c'tor/Initialize must be held on this call + inline void Schedule(DWORD dwPeriod); + + //Cancel the timer + //Lock supplied to c'tor/Initialize must be held on this call + inline void Cancel(); + + +private: + + inline void TimerFired(DWORD dwFiredTime); + + + //Time that Owner wants the timer to fire at + DWORD m_dwExpiryTime; + + //Thread pool timer is currently scheduled against + DrThreadPool * m_pThreadPool; + + //Object using this timer + Owner * m_pOwner; + + //Lock created by owner to guard its timing state + DrCriticalSection * m_pLock; + + union + { + DWORD m_dwFlags; + struct + { + //Set to TRUE if the timer is currently running + DWORD m_fTimerRunning : 1; + //Set to TRUE if Owner really wants the timer callback + DWORD m_fTimerRequired : 1; + }; + }; + +}; + + +/* + * Inline methods from DrEnhancedTimer + */ + +template +DrEnhancedTimer::DrEnhancedTimer() +{ + m_pThreadPool=NULL; + m_pOwner=NULL; + m_pLock=NULL; + m_dwFlags=0; +} + + +template +DrEnhancedTimer::DrEnhancedTimer(Owner * pOwner, DrThreadPool * pThreadPool, DrCriticalSection * pLock) +{ + m_pThreadPool=pThreadPool; + m_pOwner=pOwner; + m_pLock=pLock; + m_dwFlags=0; +} + +template +void DrEnhancedTimer::Initialize(Owner * pOwner, DrThreadPool * pThreadPool, DrCriticalSection * pLock) +{ + LogAssert(m_pOwner==NULL); + LogAssert(m_pThreadPool==NULL); + LogAssert(m_dwFlags==0); + + m_pThreadPool=pThreadPool; + m_pOwner=pOwner; + m_pLock=pLock; +} + + +template +void 
DrEnhancedTimer::Schedule(DWORD dwPeriod) +{ + DebugLogAssert( m_pLock->Aquired() ); + + //Whatever happens we want a timer to fire + m_fTimerRequired=TRUE; + + //Compute when we want the timer to expire + m_dwExpiryTime=GetTickCount()+dwPeriod; + + //If the timer is already scheduled then try and cancel it + if (m_fTimerRunning) + { + if (m_pThreadPool->CancelTimer(this)==FALSE) + { + //Let it fire and detect the bogus expiry time at that point + return; + } + } + else + { + //Timer wasn't previously running so store fact it now will be + //and that owner needs to stick around + m_fTimerRunning=TRUE; + m_pOwner->IncRef(); + } + + m_pThreadPool->ScheduleTimerMs(this, dwPeriod); +} + + +template +void DrEnhancedTimer::Cancel() +{ + DebugLogAssert( m_pLock->Aquired() ); + + //Whatever happens owner doesn't want a timer running + m_fTimerRequired=FALSE; + + //If its not scheduled then we're done, OR + //If we fail to cancel it then we'll just have to let it fire and + //spot that its no longer requested then + if (m_fTimerRunning==FALSE || m_pThreadPool->CancelTimer(this)==FALSE) + { + return; + } + + //Looks like we cancelled it OK + m_fTimerRunning=FALSE; + m_pOwner->DecRef(); +} + + +template +void DrEnhancedTimer::TimerFired(DWORD dwFiredTime) +{ + m_pLock->Enter(); + + LogAssert(m_fTimerRunning); + + m_fTimerRunning=FALSE; + + //If owner didn't currently want a timer then we're done + if (m_fTimerRequired==FALSE) + { + m_pLock->Leave(); + m_pOwner->DecRef(); + return; + } + + //If this isn't the right time for the timer to fire then reschedule it + //for the right time + int iTimeRemaining=(int ) (m_dwExpiryTime-dwFiredTime); + if (iTimeRemaining>0) + { + m_pThreadPool->ScheduleTimerMs(this, iTimeRemaining); + m_fTimerRunning=TRUE; + m_pLock->Leave(); + return; + } + + //Looks like we've got a good timer that we want to process + m_fTimerRequired=FALSE; + + //Tell the owner a timer has expired + m_pOwner->EnhancedTimerFired(this, dwFiredTime); + + 
m_pLock->Leave(); + m_pOwner->DecRef(); + +} + +#pragma warning (pop) + +#endif //end if not defined __DRYADENHANCEDTIMER_H__ + diff --git a/DryadVertex/VertexHost/system/common/include/DObjPool.h b/DryadVertex/VertexHost/system/common/include/DObjPool.h new file mode 100644 index 0000000..e0977ee --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/DObjPool.h @@ -0,0 +1,260 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include + +class DObjPoolBase; + +class DObjPoolCache; + +/* the DObjPoolThreadPrivateBlock is only ever read or modified by the + thread which owns it, so needs no synchronization on any of its + members */ +class DObjPoolThreadPrivateBlock +{ +public: + DObjPoolThreadPrivateBlock(); + ~DObjPoolThreadPrivateBlock(); + + DObjPoolCache* LookUpPoolCache(DObjPoolBase* pool, LONGLONG poolKey); + void AddPoolCache(DObjPoolBase* pool, LONGLONG poolKey, + DObjPoolCache* cache); + void GarbageCollectCaches(); + +private: + class Entry + { + public: + Entry(); + ~Entry(); + void Initialize(DObjPoolBase* pool, LONGLONG poolKey, + DObjPoolCache* cache); + + DObjPoolBase* m_pool; + LONGLONG m_key; + DObjPoolCache* m_cache; + }; + + UInt32 m_entryArraySize; + UInt32 m_numberOfEntries; + Entry* m_entry; +}; + +/* the DObjPoolCache is a thread-local pool of cached objects from the + pool. 
It is generally only accessed by the local thread, so most + methods need no synchronization. During pool cleanup it can be + accessed by another thread which is holding the pool's lock. + Correctly written code will ensure that pool cleanup will never + occur unless all outstanding pool objects have been handed back, + and this handback must be communicated to the cleanup thread using + correctly synchronized operations. This ensures that no object can + be added to or removed from the cache once cleanup begins. After + cleanup, Abandoned() will return true, and the thread-local code + may subsequently garbage-collect the cache. +*/ +class DObjPoolCache +{ +public: + DObjPoolCache(UInt32 maxEntries, UInt32 keepEntryCount, + DObjPoolBase* pool); + ~DObjPoolCache(); + + bool Abandoned(); + + void InsertObject(void* o); + void* RemoveObject(); + +private: + /* this is the only method which can be called on a thread other + than the local owner thread, and this only happens during + cleanup */ + void ReturnToPool(bool finalCleanup); + + bool m_abandoned; + DObjPoolBase* m_pool; + UInt32 m_maxEntries; + UInt32 m_keepEntryCount; + UInt32 m_numberOfEntries; + void** m_array; + + friend class DObjPoolBase; +}; + +class DObjFactoryBase : public IDrRefCounter +{ +public: + virtual void* AllocateObjectUntyped() = 0; + virtual void FreeObjectUntyped(void* object) = 0; +}; + +typedef DrRef DObjFactoryRef; + +class DObjPoolBase : public DObjFactoryBase +{ +public: + DObjPoolBase(DObjFactoryBase* factory, + UInt32 maxCentralEntries, + UInt32 maxLocalEntries, UInt32 localKeepEntryCount); + ~DObjPoolBase(); + + void RemoveObjects(void** dst, UInt32 countNeeded); + void AcceptObjects(void** src, UInt32 count); + + void* AllocateObjectUntyped(); + void FreeObjectUntyped(void* object); + +protected: + DObjFactoryBase* DetachFactory(); + +private: + DObjPoolCache* MakeCache(); + DObjPoolCache* FetchPrivateCache(); + + DrRef m_factory; + UInt32 m_maxCentralEntries; + UInt32 
m_maxLocalEntries; + UInt32 m_localKeepEntryCount; + + UInt32 m_numberOfCentralEntries; + void** m_array; + + UInt32 m_cacheArraySize; + UInt32 m_numberOfCaches; + DObjPoolCache** m_cache; + + UInt64 m_totalGivenOut; + UInt64 m_totalReturned; + UInt64 m_totalAllocated; + UInt64 m_totalFreed; + + LONGLONG m_key; + CRITSEC m_atomic; +}; + +template< class T_ > class DrRefFactory : public DObjFactoryBase +{ +public: + virtual ~DrRefFactory() {} + virtual void AllocateObject(DrRef* pObject) = 0; + virtual void FreeObject(DrRef& object) + { + object = NULL; + } + +private: + void* AllocateObjectUntyped() + { + DrRef typedObject; + AllocateObject(&typedObject); + return typedObject.Detach(); + } + + void FreeObjectUntyped(void* object) + { + DrRef typedObject; + typedObject.Attach((T_ *) object); + FreeObject(typedObject); + } +}; + +template< class T_ > class StdRefPoolFactory : public DrRefFactory +{ +public: + StdRefPoolFactory() {} + StdRefPoolFactory(DObjPoolBase* pool) + { + InitializePoolFactory(pool); + } + + void InitializePoolFactory(DObjPoolBase* pool) + { + m_pool = pool; + } + + void AllocateObject(DrRef* pObject) + { + pObject->Attach(new T_(m_pool)); + } + + void FreeObject(DrRef& object) + { + object.Detach()->PoolFreeMemory(); + } + +private: + DrRef m_pool; + + DRREFCOUNTIMPL +}; + +template< class T_ > class DrRefPool : public DObjPoolBase +{ +public: + DrRefPool(DrRefFactory* factory, + UInt32 maxCentralEntries, + UInt32 maxLocalEntries, UInt32 localKeepEntryCount) : + DObjPoolBase(factory, maxCentralEntries, + maxLocalEntries, localKeepEntryCount) + { + } + + void InsertObject(DrRef& object) + { + FreeObjectUntyped(object.Detach()); + } + + void RemoveObject(DrRef* pObject) + { + pObject->Attach((T_ *) AllocateObjectUntyped()); + } + + DRREFCOUNTIMPL +}; + +template< class T_ > class StdRefPool : public DObjPoolBase +{ +public: + StdRefPool(UInt32 maxCentralEntries, + UInt32 maxLocalEntries, UInt32 localKeepEntryCount) : + DObjPoolBase(new 
StdRefPoolFactory(), maxCentralEntries, + maxLocalEntries, localKeepEntryCount) + { + } + + ~StdRefPool() + { + DObjFactoryBase* factory = DetachFactory(); + delete factory; + } + + void InsertObject(DrRef& object) + { + FreeObjectUntyped(object.Detach()); + } + + void RemoveObject(DrRef* pObject) + { + pObject->Attach((T_ *) AllocateObjectUntyped()); + } + + DRREFCOUNTIMPL +}; diff --git a/DryadVertex/VertexHost/system/common/include/cosmospropertyblock.h b/DryadVertex/VertexHost/system/common/include/cosmospropertyblock.h new file mode 100644 index 0000000..cfa6ef9 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/cosmospropertyblock.h @@ -0,0 +1,59 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include + +class DrPnSetProcessPropertyRequest; +class DrPnGetProcessPropertyResponse; +class DrProcessPropertyInfo; + +class DryadDryadPnProcessPropertyRequest : /* Exposes a DrPnSetProcessPropertyRequest message through the DryadPnProcessPropertyRequest base-class interface. */ + public DryadPnProcessPropertyRequest +{ +public: + DryadDryadPnProcessPropertyRequest(DrPnSetProcessPropertyRequest* + request); /* request: the message to wrap; presumably stored in m_message and handed back by GetMessage() - TODO confirm in the .cpp */ + + void SetPropertyLabel(const char* label, const char* controlLabel); + void SetPropertyString(const char* string); + DrMemoryBuffer* GetPropertyBlock(); /* NOTE(review): ownership of the returned buffer is not stated in this header - confirm before releasing or retaining it */ + + DrPnSetProcessPropertyRequest* GetMessage(); + +private: + DrRef m_message; /* held through DrRef (template argument is missing from this text; presumably the wrapped request type - TODO confirm) */ +}; + +class DryadDryadPnProcessPropertyResponse : /* Exposes a DrPnGetProcessPropertyResponse message through the DryadPnProcessPropertyResponse base-class interface. */ + public DryadPnProcessPropertyResponse +{ +public: + DryadDryadPnProcessPropertyResponse(DrPnGetProcessPropertyResponse* + response); + + void RetrievePropertyLabel(const char* label); + DrMemoryBuffer* GetPropertyBlock(); + +private: + DrPnGetProcessPropertyResponse* m_message; /* NOTE(review): raw pointers here, whereas the request wrapper holds a DrRef; ownership and lifetime cannot be determined from this header - confirm against the .cpp */ + DrProcessPropertyInfo* m_info; +}; diff --git a/DryadVertex/VertexHost/system/common/include/cosmosstreampropertyupdater.h b/DryadVertex/VertexHost/system/common/include/cosmosstreampropertyupdater.h new file mode 100644 index 0000000..d0dc476 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/cosmosstreampropertyupdater.h @@ -0,0 +1,30 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License.
+ +*/ + +#pragma once + +#include + +class DryadStreamPropertyUpdater : public StreamPropertyUpdater /* Dryad-specific StreamPropertyUpdater. The class adds no public members of its own. */ +{ +private: /* NOTE(review): both methods are private, so callers presumably reach them through virtual functions declared on StreamPropertyUpdater; the base class is not visible here - TODO confirm */ + void SendUpdate(StreamHolder* holder); + void SendWdUpdate(ProcessHolder* holder); +}; diff --git a/DryadVertex/VertexHost/system/common/include/dryadcosmosresources.h b/DryadVertex/VertexHost/system/common/include/dryadcosmosresources.h new file mode 100644 index 0000000..77c94cb --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/dryadcosmosresources.h @@ -0,0 +1,53 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License.
+ +*/ + +#pragma once + +#include + +class DryadDryadProcessIdentifier : public DryadProcessIdentifier +{ +public: + DryadDryadProcessIdentifier(DrGuid* guid); + + DrGuid* GetGuid(); + const char* GetGuidString(); + + void SetWorkingDirectory(const char* workingDirectory); + void MakeURIForRelativeFile(DrStr* dst, + const char* baseDirectory, + const char* relativeFileName); + +private: + DrGuid m_guid; + DrStr32 m_guidString; + DrStr128 m_workingDirectory; +}; + +class DryadDryadMachineIdentifier : public DryadMachineIdentifier +{ +public: + DryadDryadMachineIdentifier(DrServiceDescriptor* desc); + + DrServiceDescriptor* GetServiceDescriptor(); + +private: + DrServiceDescriptor m_serviceDescriptor; +}; diff --git a/DryadVertex/VertexHost/system/common/include/dryaderror.h b/DryadVertex/VertexHost/system/common/include/dryaderror.h new file mode 100644 index 0000000..3b0717b --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/dryaderror.h @@ -0,0 +1,59 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + + +// This file must contain *only* DEFINE_DRYAD_ERROR directives! +// +// +// It is included multiple times with different macro definitions. 
+ +// Dummy error +DEFINE_DRYAD_ERROR (DryadError_BadMetaData, DRYAD_ERROR (0x0001), "Bad MetaData XML") +DEFINE_DRYAD_ERROR (DryadError_InvalidCommand, DRYAD_ERROR (0x0002), "Invalid Command") +DEFINE_DRYAD_ERROR (DryadError_VertexReceivedTermination, DRYAD_ERROR (0x0003), "Vertex Received Termination") +DEFINE_DRYAD_ERROR (DryadError_InvalidChannelURI, DRYAD_ERROR (0x0004), "Invalid Channel URI syntax") +DEFINE_DRYAD_ERROR (DryadError_ChannelOpenError, DRYAD_ERROR (0x0005), "Channel Open Error") +DEFINE_DRYAD_ERROR (DryadError_ChannelRestartError, DRYAD_ERROR (0x0006), "Channel Restart Error") +DEFINE_DRYAD_ERROR (DryadError_ChannelWriteError, DRYAD_ERROR (0x0007), "Channel Write Error") +DEFINE_DRYAD_ERROR (DryadError_ChannelReadError, DRYAD_ERROR (0x0008), "Channel Read Error") +DEFINE_DRYAD_ERROR (DryadError_ItemParseError, DRYAD_ERROR (0x0009), "Item Parse Error") +DEFINE_DRYAD_ERROR (DryadError_ItemMarshalError, DRYAD_ERROR (0x0010), "Item Marshal Error") +DEFINE_DRYAD_ERROR (DryadError_BufferHole, DRYAD_ERROR (0x0011), "Buffer Hole") +DEFINE_DRYAD_ERROR (DryadError_ItemHole, DRYAD_ERROR (0x0012), "Item Hole") +DEFINE_DRYAD_ERROR (DryadError_ChannelRestart, DRYAD_ERROR (0x0013), "Channel Sent Restart") +DEFINE_DRYAD_ERROR (DryadError_ChannelAbort, DRYAD_ERROR (0x0014), "Channel Sent Abort") +DEFINE_DRYAD_ERROR (DryadError_VertexRunning, DRYAD_ERROR (0x0015), "Vertex Is Running") +DEFINE_DRYAD_ERROR (DryadError_VertexCompleted, DRYAD_ERROR (0x0016), "Vertex Has Completed") +DEFINE_DRYAD_ERROR (DryadError_VertexError, DRYAD_ERROR (0x0017), "Vertex Had Errors") +DEFINE_DRYAD_ERROR (DryadError_ProcessingError, DRYAD_ERROR (0x0018), "Error While Processing") +DEFINE_DRYAD_ERROR (DryadError_VertexInitialization, DRYAD_ERROR (0x0019), "Vertex Could Not Initialize") +DEFINE_DRYAD_ERROR (DryadError_ProcessingInterrupted, DRYAD_ERROR (0x001a), "Processing was interrupted before completion") +DEFINE_DRYAD_ERROR (DryadError_VertexChannelClose, DRYAD_ERROR (0x001b), 
"Errors during channel close") +DEFINE_DRYAD_ERROR (DryadError_AssertFailure, DRYAD_ERROR (0x001c), "Assertion Failure") +DEFINE_DRYAD_ERROR (DryadError_ExternalChannel, DRYAD_ERROR (0x001d), "External Channel") +DEFINE_DRYAD_ERROR (DryadError_AlreadyInitialized, DRYAD_ERROR (0x001e), "Dryad Already Initialized") +DEFINE_DRYAD_ERROR (DryadError_DuplicateVertices, DRYAD_ERROR (0x001f), "Duplicate Vertices") +DEFINE_DRYAD_ERROR (DryadError_ComposeRHSNeedsInput, DRYAD_ERROR (0x0020), "RHS of composition must have at least one input") +DEFINE_DRYAD_ERROR (DryadError_ComposeLHSNeedsOutput, DRYAD_ERROR (0x0021), "LHS of composition must have at least one output") +DEFINE_DRYAD_ERROR (DryadError_ComposeStagesMustBeDifferent, DRYAD_ERROR (0x0022), "Stages for composition must be different") +DEFINE_DRYAD_ERROR (DryadError_ComposeStageEmpty, DRYAD_ERROR (0x0023), "Stage for composition is empty") +DEFINE_DRYAD_ERROR (DryadError_VertexNotInGraph, DRYAD_ERROR (0x0024), "Vertex not in graph") +DEFINE_DRYAD_ERROR (DryadError_HardConstraintCannotBeMet, DRYAD_ERROR (0x0025), "Hard constraint cannot be met") +DEFINE_DRYAD_ERROR (DryadError_MustRequeue, DRYAD_ERROR (0x0026), "Must requeue process") diff --git a/DryadVertex/VertexHost/system/common/include/dryaderrordef.h b/DryadVertex/VertexHost/system/common/include/dryaderrordef.h new file mode 100644 index 0000000..ff52e10 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/dryaderrordef.h @@ -0,0 +1,35 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include + +#define FACILITY_DRYAD 778 /* facility code shifted into bit 16 and up of every Dryad HRESULT */ +#define DRYAD_ERROR(n) ((HRESULT)(0x80000000 + (FACILITY_DRYAD << 16) + (n))) /* builds a failure HRESULT: high bit set, FACILITY_DRYAD as the facility, n as the code. n is parenthesised so that an expression argument (a conditional, a bitwise-or, ...) is added as a whole instead of binding to the preceding addition */ + +#ifdef DEFINE_DRYAD_ERROR +#undef DEFINE_DRYAD_ERROR +#endif + +#define DEFINE_DRYAD_ERROR(name, number, description) static const DrError name = number; /* X-macro: the include that follows turns each DEFINE_DRYAD_ERROR entry into a static const DrError constant; the description argument is unused by this expansion */ +#include "DryadError.h" /* NOTE(review): the header is committed as dryaderror.h, so this spelling only resolves on a case-insensitive file system - confirm before building elsewhere */ + +#undef DEFINE_DRYAD_ERROR diff --git a/DryadVertex/VertexHost/system/common/include/dryadeventcache.h b/DryadVertex/VertexHost/system/common/include/dryadeventcache.h new file mode 100644 index 0000000..2a8211c --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/dryadeventcache.h @@ -0,0 +1,50 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License.
+ +*/ + +#pragma once + +#include + +class DryadHandleListEntry /* List node carrying a HANDLE. m_listPtr is the embedded DrBListEntry link that DryadBList reaches by name in CastIn/CastOut, which is why DryadBList is declared a friend. */ +{ +public: + DryadHandleListEntry(HANDLE handle); + HANDLE GetHandle(); + +private: + HANDLE m_handle; + DrBListEntry m_listPtr; + friend class DryadBList; +}; + +class DryadEventCache /* Hands out and takes back DryadHandleListEntry objects. NOTE(review): presumably returned entries are kept on m_eventCache for reuse; what GetEvent does on an empty cache and the meaning of 'reset' are defined in the .cpp - confirm there */ +{ +public: + DryadEventCache(); + ~DryadEventCache(); + + DryadHandleListEntry* GetEvent(bool reset); + void ReturnEvent(DryadHandleListEntry* event); + +private: + typedef DryadBList HandleList; + + HandleList m_eventCache; +}; diff --git a/DryadVertex/VertexHost/system/common/include/dryadlisthelper.h b/DryadVertex/VertexHost/system/common/include/dryadlisthelper.h new file mode 100644 index 0000000..6837a4f --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/dryadlisthelper.h @@ -0,0 +1,72 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include +#include + +template< class T_ > class DryadBList : public DrBList /* Typed view over DrBList for entries that embed a DrBListEntry member named m_listPtr. The parameter is spelled T_ because names starting with an underscore and an uppercase letter are reserved to the implementation. */ +{ +public: + typedef T_ EntryType; + + static EntryType* CastOut(DrBListEntry* item) /* maps a list link back to the entry that contains it; NULL is passed through unchanged */ + { + return (item == NULL) ?
NULL : + (DR_GET_CONTAINER(EntryType, item, m_listPtr)); + } + + static DrBListEntry* CastIn(EntryType* item) /* returns the address of the link embedded in item; item must not be NULL because it is dereferenced */ + { + return &(item->m_listPtr); + } + + EntryType* GetNextTyped(EntryType* item) /* typed wrapper around DrBList GetNext: converts in, steps, converts back */ + { + return CastOut(GetNext(CastIn(item))); + } +}; + +template< class T_, class B_ > class DryadBListDerived : public DrBList /* Same typed view for an entry type T_ whose link lives in another type B_; all conversions are delegated to DryadBList of B_. NOTE(review): assumes T_ derives from B_ - TODO confirm at the use sites */ +{ +public: + typedef T_ EntryType; + typedef B_ BaseType; + typedef DryadBList< BaseType > BaseListType; + + static EntryType* CastOut(DrBListEntry* item) /* recovers the BaseType entry, then converts the pointer to EntryType with an unchecked C-style cast */ + { + return (EntryType *) BaseListType::CastOut(item); + } + + static DrBListEntry* CastIn(EntryType* item) + { + return BaseListType::CastIn(item); + } + + EntryType* GetNextTyped(EntryType* item) + { + return CastOut(GetNext(CastIn(item))); + } +}; + + + diff --git a/DryadVertex/VertexHost/system/common/include/dryadmetadata.h b/DryadVertex/VertexHost/system/common/include/dryadmetadata.h new file mode 100644 index 0000000..9b73cbc --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/dryadmetadata.h @@ -0,0 +1,288 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License.
+ +*/ + +#pragma once +#pragma warning(disable:4512) +#pragma warning(disable:4511) +#pragma warning(disable:4995) + +#include +#include "dryadmetadatatag.h" +#include "dryadmetadatatagtypes.h" +#include +#include + +class DryadMetaDataConst; + +class DryadMetaData; +typedef DrRef DryadMetaDataRef; + +class DryadMetaData : public DrRefCounter +{ +public: + typedef std::list TagList; + typedef TagList::iterator TagListIter; + typedef std::multimap< UInt16,DryadMTag * > TagMap; + typedef TagMap::iterator TagMapIter; + + /* the following method places a reference to a new, empty, + DryadMetaData object in the caller's dstMetaData object. */ + static void Create(DryadMetaDataRef* dstMetaData); + + /* this appends tag to the end of self's tag list and transfers + the caller's reference to tag. If allowDuplicateTags is false + and there is a tag of the same name in self, tag is not + appended to self and its reference count is + decremented. Returns true if and only if tag was appended to + self. */ + bool Append(DryadMTag* tag, bool allowDuplicateTags); + + /* this appends all the tags in metaData to the end of self's tag + list without modifying metaData (incrementing each tag's + refcount). If allowDuplicateTags is false then any tags in + metaData which already exist in self are not appended. */ + void AppendMetaDataTags(DryadMetaData* metaData, bool allowDuplicateTags); + + /* this replaces oldTag with newTag in self's tag list. It is an + error to call Replace with oldTag and newTag which do not have + the same tag value. If oldTag is not in self's tag list this + returns false and does not alter any reference counts, + otherwise it returns true, transfers the caller's reference to + newTag into self and discards self's reference to oldTag. + */ + bool Replace(DryadMTag* newTag, DryadMTag* oldTag); + + /* this removes tag from self's tag list. 
If tag is not in self's + tag list this returns false and does not alter any reference + counts, otherwise it returns true and discards self's reference + to tag. + */ + bool Remove(DryadMTag* tag); + + /* this returns the tag with ID enumID if and only if there is + a unique tag with that name in self's list, otherwise it + returns NULL. This call does not modify the tag's reference + count. */ + DryadMTag* LookUpTag(UInt16 enumID); + + /* the following type-safe lookup methods return non-NULL if and + only if there is a unique tag in self with ID enumID and the + correct type. These calls do not modify the tag's reference + count. */ + DryadMTagUnknown* LookUpUnknownTag(UInt16 enumID); + DryadMTagVoid* LookUpVoidTag(UInt16 enumID); + DryadMTagBoolean* LookUpBooleanTag(UInt16 enumID); + DryadMTagInt16* LookUpInt16Tag(UInt16 enumID); + DryadMTagUInt16* LookUpUInt16Tag(UInt16 enumID); + DryadMTagInt32* LookUpInt32Tag(UInt16 enumID); + DryadMTagUInt32* LookUpUInt32Tag(UInt16 enumID); + DryadMTagInt64* LookUpInt64Tag(UInt16 enumID); + DryadMTagUInt64* LookUpUInt64Tag(UInt16 enumID); + DryadMTagString* LookUpStringTag(UInt16 enumID); + DryadMTagGuid* LookUpGuidTag(UInt16 enumID); + DryadMTagTimeStamp* LookUpTimeStampTag(UInt16 enumID); + DryadMTagTimeInterval* LookUpTimeIntervalTag(UInt16 enumID); + DryadMTagDrError* LookUpDrErrorTag(UInt16 enumID); + DryadMTagMetaData* LookUpMetaDataTag(UInt16 enumID); + DryadMTagVertexCommand* LookUpVertexCommandTag(UInt16 enumID); + DryadMTagInputChannelDescription* + LookUpInputChannelDescriptionTag(UInt16 enumID); + DryadMTagOutputChannelDescription* + LookUpOutputChannelDescriptionTag(UInt16 enumID); + DryadMTagVertexProcessStatus* LookUpVertexProcessStatusTag(UInt16 enumID); + DryadMTagVertexStatus* LookUpVertexStatusTag(UInt16 enumID); + DryadMTagVertexCommandBlock* LookUpVertexCommandBlockTag(UInt16 enumID); + + /* the following type-safe lookup methods return DrError_OK if and + only if there is a unique tag in self with ID 
enumID and the + correct type. */ + DrError LookUpVoid(UInt16 enumID); + DrError LookUpBoolean(UInt16 enumID, bool* pVal /* out */); + DrError LookUpInt16(UInt16 enumID, Int16* pVal /* out */); + DrError LookUpUInt16(UInt16 enumID, UInt16* pVal /* out */); + DrError LookUpInt32(UInt16 enumID, Int32* pVal /* out */); + DrError LookUpUInt32(UInt16 enumID, UInt32* pVal /* out */); + DrError LookUpInt64(UInt16 enumID, Int64* pVal /* out */); + DrError LookUpUInt64(UInt16 enumID, UInt64* pVal /* out */); + DrError LookUpString(UInt16 enumID, const char** pVal /* out */); + DrError LookUpGuid(UInt16 enumID, const DrGuid** pVal /* out */); + DrError LookUpTimeStamp(UInt16 enumID, DrTimeStamp* pVal /* out */); + DrError LookUpTimeInterval(UInt16 enumID, DrTimeInterval* pVal /* out */); + DrError LookUpDrError(UInt16 enumID, DrError* pVal /* out */); + DrError LookUpMetaData(UInt16 enumID, DryadMetaDataRef* pVal /* out */); + DrError LookUpVertexCommand(UInt16 enumID, DVertexCommand* pVal /* out */); + DrError LookUpInputChannelDescription(UInt16 enumID, + DryadInputChannelDescription** + pVal /* out */); + DrError LookUpOutputChannelDescription(UInt16 enumID, + DryadOutputChannelDescription** + pVal /* out */); + DrError LookUpVertexProcessStatus(UInt16 enumID, + DVertexProcessStatus** pVal /* out */); + DrError LookUpVertexStatus(UInt16 enumID, DVertexStatus** pVal /* out */); + DrError LookUpVertexCommandBlock(UInt16 enumID, + DVertexCommandBlock** pVal /* out */); + + /* the following type-safe append methods return true if and only + if the tag was appended. 
*/ + bool AppendVoid(UInt16 enumID, + bool allowDuplicateTags); + bool AppendBoolean(UInt16 enumID, bool value, + bool allowDuplicateTags); + bool AppendInt16(UInt16 enumID, Int16 value, + bool allowDuplicateTags); + bool AppendUInt16(UInt16 enumID, UInt16 value, + bool allowDuplicateTags); + bool AppendInt32(UInt16 enumID, Int32 value, + bool allowDuplicateTags); + bool AppendUInt32(UInt16 enumID, UInt32 value, + bool allowDuplicateTags); + bool AppendInt64(UInt16 enumID, Int64 value, + bool allowDuplicateTags); + bool AppendUInt64(UInt16 enumID, UInt64 value, + bool allowDuplicateTags); + bool AppendString(UInt16 enumID, const char* value, + bool allowDuplicateTags); + bool AppendGuid(UInt16 enumID, const DrGuid* value, + bool allowDuplicateTags); + bool AppendTimeStamp(UInt16 enumID, DrTimeStamp value, + bool allowDuplicateTags); + bool AppendTimeInterval(UInt16 enumID, DrTimeInterval value, + bool allowDuplicateTags); + bool AppendDrError(UInt16 enumID, DrError value, + bool allowDuplicateTags); + bool AppendMetaData(UInt16 enumID, DryadMetaData* value, + bool marshalAsAggregate, + bool allowDuplicateTags); + bool AppendVertexCommand(UInt16 enumID, DVertexCommand value, + bool allowDuplicateTags); + bool AppendInputChannelDescription(UInt16 enumID, + DryadInputChannelDescription* value, + bool allowDuplicateTags); + bool AppendOutputChannelDescription(UInt16 enumID, + DryadOutputChannelDescription* value, + bool allowDuplicateTags); + bool AppendVertexProcessStatus(UInt16 enumID, DVertexProcessStatus* value, + bool allowDuplicateTags); + bool AppendVertexStatus(UInt16 enumID, DVertexStatus* value, + bool allowDuplicateTags); + bool AppendVertexCommandBlock(UInt16 enumID, DVertexCommandBlock* value, + bool allowDuplicateTags); + + /* This returns an iterator which can be used to access all the + tags with ID enumID (if any) within self. *pEndIter is + filled in with an iterator beyond the last tag with ID + enumID. No reference counts are modified by this call. 
The + order of tags returned by this iterator is undefined. */ + DryadMetaData::TagMapIter + LookUpMulti(UInt16 enumID, + DryadMetaData::TagMapIter* pEndIter); + + /* This returns an iterator which can be used to access all the + tags within self in sequence starting from startTag. *pEndIter + is filled in with an iterator beyond the last tag in self. If + startTag is NULL the iterator begins at the first tag (if any) + in self. No reference counts are modified by this call. */ + DryadMetaData::TagListIter + LookUpInSequence(DryadMTag* startTag, + DryadMetaData::TagListIter* pEndIter); + + /* this returns a recursive clone of self. Leaf tags acquire a new + reference within the cloned object, rather than being copied. The + returned metadata has a single reference owned by the + caller. */ + void Clone(DryadMetaDataRef* dstMetaData); + + void Serialize(DrMemoryWriter* writer); + void CacheSerialization(); + DrMemoryBuffer* SerializeToBuffer(); + + /* call 'delete []' on the buffer returned by GetText() */ + char* GetText(); + DrError WriteAsProperty(DrMemoryWriter* writer, + UInt16 propertyTag, + bool writeIfEmpty); + DrError WriteAsAggregate(DrMemoryWriter* writer, + UInt16 propertyTag, + bool writeIfEmpty); + + /* these are convenience methods to add error codes. They each add + a DrError tag with ID Prop_Dryad_ErrorCode, value errorCode if + it doesn't already exist. The second also adds a + Prop_Dryad_ErrorString containing errorDescription. */ + void AddError(DrError errorCode); + void AddErrorWithDescription(DrError errorCode, + const char* errorDescription); + + /* this is a convenience method. If the metadata contains a + DrError tag with ID Prop_Dryad_ErrorCode then the call returns + true and the error in that property is filled in to *pError, + otherwise it returns false and *pError is not modified. */ + bool GetErrorCode(DrError* pError /* out */); + /* returns NULL if no error string exists. 
The returned value is + valid only as long as this instance is not changed */ + const char* GetErrorString(); + +private: + DryadMetaData(); + ~DryadMetaData(); + + TagList m_elementList; + TagMap m_elementMap; + DrRef m_cachedSerialization; + + CRITSEC m_baseDR; +}; + +class DryadMetaDataParser : public DrPropertyParser +{ +public: + typedef DrError (TagFactory)(DrMemoryReader* reader, + UInt16 enumId, UInt32 dataLen, + DryadMetaDataParser* parent); + typedef DrError (AggregateFactory)(DrMemoryReader* reader, + DryadMTagRef* pTag); + + DryadMetaDataParser(); + ~DryadMetaDataParser(); + + DryadMetaData* GetMetaData(); + void AddTag(DryadMTag* tag); + + DrError ParseBuffer(const void* data, UInt32 dataLength); + + DrError DryadMetaDataParser::OnParseProperty(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + void* cookie); + +private: + DryadMetaDataRef m_set; +}; + +extern void DryadInitMetaDataTable(); +extern DrError + DryadAddFactoryToTypeTable(UInt16 typeCode, + DryadMetaDataParser::TagFactory* factory); +extern DrError + DryadAddFactoryToAggregateTable(UInt16 typeCode, + DryadMetaDataParser::AggregateFactory* + factory); +extern DrError DryadAddPropertyToMetaData(UInt16 prop, UInt16 typeCode); diff --git a/DryadVertex/VertexHost/system/common/include/dryadmetadatatag.h b/DryadVertex/VertexHost/system/common/include/dryadmetadatatag.h new file mode 100644 index 0000000..c4f26a1 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/dryadmetadatatag.h @@ -0,0 +1,84 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include + +class DryadMetaData; +class DryadMetaDataConst; +class DrMemoryWriter; + +class DryadMTag; +typedef DrRef DryadMTagRef; /* DrRef handle to a DryadMTag (the template argument is missing from this text - presumably DryadMTag, TODO confirm) */ + +class DryadMTag : public DrRefCounter /* Abstract reference-counted base class for metadata tags: Serialize() is pure virtual, and the constructor and destructor are protected, so instances are created and destroyed only through derived tag classes. */ +{ +public: + UInt16 GetTagValue(); /* presumably returns m_tag - TODO confirm in the .cpp */ + UInt16 GetType(); /* presumably returns m_type, one of the PropertyTagType enumerators declared after this class - TODO confirm */ + + virtual DrError Serialize(DrMemoryWriter* writer) = 0; /* every concrete tag type supplies its own encoding into writer and reports a DrError status */ + + /* the default implementation is for immutable tags and simply + increments the reference count and returns self */ + virtual void Clone(DryadMTagRef* dstTag); + +protected: + DryadMTag(UInt16 tagValue, UInt16 type); + virtual ~DryadMTag(); + +private: + UInt16 m_tag; + UInt16 m_type; +}; + +enum DrPropertyTagEnum { /* X-macro expansion: each DECLARE_DRPROPERTYTYPE(type) entry in DrPropertyType.h becomes an enumerator named DrPropertyTagType_type. Assumes that header contains only such entries - TODO confirm */ + +#ifdef DECLARE_DRPROPERTYTYPE +#undef DECLARE_DRPROPERTYTYPE +#endif + +#define DECLARE_DRPROPERTYTYPE(type) DrPropertyTagType_##type, + +#include "DrPropertyType.h" + +#undef DECLARE_DRPROPERTYTYPE +}; + +enum DryadPropertyTagEnum { /* Dryad-specific tag types: six explicit enumerators numbered from 0x1000, followed by whatever DryadPropertyType.h declares through DECLARE_DRYADPROPERTYTYPE. NOTE(review): the 0x1000 base presumably keeps these clear of DrPropertyTagEnum values - confirm */ + + DryadPropertyTagType_MetaData = 0x1000, + DryadPropertyTagType_InputChannelDescription, + DryadPropertyTagType_OutputChannelDescription, + DryadPropertyTagType_VertexProcessStatus, + DryadPropertyTagType_VertexStatus, + DryadPropertyTagType_VertexCommandBlock, + +#ifdef DECLARE_DRYADPROPERTYTYPE +#undef DECLARE_DRYADPROPERTYTYPE +#endif + +#define DECLARE_DRYADPROPERTYTYPE(type) DryadPropertyTagType_##type, + +#include "DryadPropertyType.h" + +#undef DECLARE_DRYADPROPERTYTYPE +}; diff --git a/DryadVertex/VertexHost/system/common/include/dryadmetadatatagtypes.h b/DryadVertex/VertexHost/system/common/include/dryadmetadatatagtypes.h new file mode
100644 index 0000000..f5374d7 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/dryadmetadatatagtypes.h @@ -0,0 +1,559 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include "dryadmetadatatag.h" +#include "dryadpropertiesdef.h" + +class DryadInputChannelDescription; +class DryadOutputChannelDescription; +class DVertexProcessStatus; +class DVertexStatus; +class DVertexCommandBlock; + +class DryadMTagUnknown; +typedef DrRef DryadMTagUnknownRef; + +class DryadMTagUnknown : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. 
*/ + static DryadMTagUnknown* Create(UInt16 tag, UInt32 dataLen, void* data, + UInt16 originalType); + static DrError ReadFromStream(DrMemoryReader* reader, + UInt16 tag, UInt32 dataLen, + DryadMTagUnknownRef* pTag); + static DrError ReadFromStreamWithType(DrMemoryReader* reader, + UInt16 enumID, UInt32 dataLen, + UInt16 originalType, + DryadMTagUnknownRef* outTag); + DrError Serialize(DrMemoryWriter* writer); + + UInt32 GetDataLength(); + void* GetData(); + UInt16 GetOriginalType(); + +private: + DryadMTagUnknown(UInt16 tag, UInt32 dataLen, UInt16 originalType); + ~DryadMTagUnknown(); + + UInt32 m_dataLength; + void* m_data; + UInt16 m_originalType; +}; + +class DryadMTagVoid; +typedef DrRef DryadMTagVoidRef; + +class DryadMTagVoid : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. */ + static DryadMTagVoid* Create(UInt16 tag); + static DrError ReadFromStream(DrMemoryReader* reader, + UInt16 tag, UInt32 dataLen, + DryadMTagVoidRef* pTag); + DrError Serialize(DrMemoryWriter* writer); + +private: + DryadMTagVoid(UInt16 tag); + ~DryadMTagVoid(); +}; + +class DryadMTagBoolean; +typedef DrRef DryadMTagBooleanRef; + +class DryadMTagBoolean : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. */ + static DryadMTagBoolean* Create(UInt16 tag, bool val); + static DrError ReadFromStream(DrMemoryReader* reader, + UInt16 tag, UInt32 dataLen, + DryadMTagBooleanRef* pTag); + DrError Serialize(DrMemoryWriter* writer); + + bool GetBoolean(); + +private: + DryadMTagBoolean(UInt16 tag, bool val); + ~DryadMTagBoolean(); + + bool m_val; +}; + +class DryadMTagInt16; +typedef DrRef DryadMTagInt16Ref; + +class DryadMTagInt16 : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. 
*/ + static DryadMTagInt16* Create(UInt16 tag, Int16 val); + static DrError ReadFromStream(DrMemoryReader* reader, + UInt16 tag, UInt32 dataLen, + DryadMTagInt16Ref* pTag); + DrError Serialize(DrMemoryWriter* writer); + + Int16 GetInt16(); + +private: + DryadMTagInt16(UInt16 tag, Int16 val); + ~DryadMTagInt16(); + + Int16 m_i16Val; +}; + +class DryadMTagUInt16; +typedef DrRef DryadMTagUInt16Ref; + +class DryadMTagUInt16 : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. */ + static DryadMTagUInt16* Create(UInt16 tag, UInt16 val); + static DrError ReadFromStream(DrMemoryReader* reader, + UInt16 tag, UInt32 dataLen, + DryadMTagUInt16Ref* pTag); + DrError Serialize(DrMemoryWriter* writer); + + UInt16 GetUInt16(); + +private: + DryadMTagUInt16(UInt16 tag, UInt16 val); + ~DryadMTagUInt16(); + + UInt16 m_uI16Val; +}; + +class DryadMTagInt32; +typedef DrRef DryadMTagInt32Ref; + +class DryadMTagInt32 : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. */ + static DryadMTagInt32* Create(UInt16 tag, Int32 val); + static DrError ReadFromStream(DrMemoryReader* reader, + UInt16 tag, UInt32 dataLen, + DryadMTagInt32Ref* pTag); + DrError Serialize(DrMemoryWriter* writer); + + Int32 GetInt32(); + +private: + DryadMTagInt32(UInt16 tag, Int32 val); + ~DryadMTagInt32(); + + Int32 m_i32Val; +}; + +class DryadMTagUInt32; +typedef DrRef DryadMTagUInt32Ref; + +class DryadMTagUInt32 : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. 
*/ + static DryadMTagUInt32* Create(UInt16 tag, UInt32 val); + static DrError ReadFromStream(DrMemoryReader* reader, + UInt16 tag, UInt32 dataLen, + DryadMTagUInt32Ref* pTag); + DrError Serialize(DrMemoryWriter* writer); + + UInt32 GetUInt32(); + +private: + DryadMTagUInt32(UInt16 tag, UInt32 val); + ~DryadMTagUInt32(); + + UInt32 m_uI32Val; +}; + +class DryadMTagInt64; +typedef DrRef DryadMTagInt64Ref; + +class DryadMTagInt64 : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. */ + static DryadMTagInt64* Create(UInt16 tag, Int64 val); + static DrError ReadFromStream(DrMemoryReader* reader, + UInt16 tag, UInt32 dataLen, + DryadMTagInt64Ref* pTag); + DrError Serialize(DrMemoryWriter* writer); + + Int64 GetInt64(); + +private: + DryadMTagInt64(UInt16 tag, Int64 val); + ~DryadMTagInt64(); + + Int64 m_i64Val; +}; + +class DryadMTagUInt64; +typedef DrRef DryadMTagUInt64Ref; + +class DryadMTagUInt64 : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. */ + static DryadMTagUInt64* Create(UInt16 tag, UInt64 val); + static DrError ReadFromStream(DrMemoryReader* reader, + UInt16 tag, UInt32 dataLen, + DryadMTagUInt64Ref* pTag); + DrError Serialize(DrMemoryWriter* writer); + + UInt64 GetUInt64(); + +private: + DryadMTagUInt64(UInt16 tag, UInt64 val); + ~DryadMTagUInt64(); + + UInt64 m_uI64Val; +}; + +class DryadMTagDouble; +typedef DrRef DryadMTagDoubleRef; + +class DryadMTagDouble : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. 
*/ + static DryadMTagDouble* Create(UInt16 tag, double val); + static DrError ReadFromStream(DrMemoryReader* reader, + UInt16 tag, UInt32 dataLen, + DryadMTagDoubleRef* pTag); + DrError Serialize(DrMemoryWriter* writer); + + double GetDouble(); + +private: + DryadMTagDouble(UInt16 tag, double val); + ~DryadMTagDouble(); + + double m_doubleVal; +}; + +class DryadMTagString; +typedef DrRef DryadMTagStringRef; + +class DryadMTagString : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. */ + static DryadMTagString* Create(UInt16 tag, const char* val); + static DrError ReadFromStream(DrMemoryReader* reader, + UInt16 tag, UInt32 dataLen, + DryadMTagStringRef* pTag); + DrError Serialize(DrMemoryWriter* writer); + + const char* GetString(); + +private: + DryadMTagString(UInt16 tag, const char* val); + ~DryadMTagString(); + char* GetWritableString(size_t dataLen); + + DrStr64 m_string; +}; + +class DryadMTagGuid; +typedef DrRef DryadMTagGuidRef; + +class DryadMTagGuid : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. */ + static DryadMTagGuid* Create(UInt16 tag, const DrGuid* val); + static DrError ReadFromStream(DrMemoryReader* reader, + UInt16 tag, UInt32 dataLen, + DryadMTagGuidRef* pTag); + DrError Serialize(DrMemoryWriter* writer); + + const DrGuid* GetGuid(); + +private: + DryadMTagGuid(UInt16 tag, const DrGuid* val); + ~DryadMTagGuid(); + + DrInitializedGuid m_guid; +}; + +class DryadMTagTimeStamp; +typedef DrRef DryadMTagTimeStampRef; + +class DryadMTagTimeStamp : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. 
*/ + static DryadMTagTimeStamp* Create(UInt16 tag, DrTimeStamp val); + static DrError ReadFromStream(DrMemoryReader* reader, + UInt16 tag, UInt32 dataLen, + DryadMTagTimeStampRef* pTag); + DrError Serialize(DrMemoryWriter* writer); + + DrTimeStamp GetTimeStamp(); + +private: + DryadMTagTimeStamp(UInt16 tag, DrTimeStamp val); + ~DryadMTagTimeStamp(); + + DrTimeStamp m_val; +}; + +class DryadMTagTimeInterval; +typedef DrRef DryadMTagTimeIntervalRef; + +class DryadMTagTimeInterval : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. */ + static DryadMTagTimeInterval* Create(UInt16 tag, DrTimeInterval val); + static DrError ReadFromStream(DrMemoryReader* reader, + UInt16 tag, UInt32 dataLen, + DryadMTagTimeIntervalRef* pTag); + DrError Serialize(DrMemoryWriter* writer); + + DrTimeInterval GetTimeInterval(); + +private: + DryadMTagTimeInterval(UInt16 tag, DrTimeInterval val); + ~DryadMTagTimeInterval(); + + DrTimeInterval m_val; +}; + +class DryadMTagDrError; +typedef DrRef DryadMTagDrErrorRef; + +class DryadMTagDrError : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. */ + static DryadMTagDrError* Create(UInt16 tag, DrError val); + static DrError ReadFromStream(DrMemoryReader* reader, + UInt16 tag, UInt32 dataLen, + DryadMTagDrErrorRef* pTag); + DrError Serialize(DrMemoryWriter* writer); + + DrError GetDrError(); + +private: + DryadMTagDrError(UInt16 tag, DrError val); + ~DryadMTagDrError(); + + DrError m_val; +}; + +class DryadMTagMetaData; +typedef DrRef DryadMTagMetaDataRef; + +class DryadMTagMetaData : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. The caller's reference to val is + transferred to the new tag. If marshalAsAggregate is true, this + is serialized as an aggregate with tagValue=tag, otherwise it + is serialized as a single property with enumID=tag. 
*/ + static DryadMTagMetaData* Create(UInt16 tag, DryadMetaData* val, + bool marshalAsAggregate); + static DrError ReadFromStreamInAggregate(DrMemoryReader* reader, + DryadMTagMetaDataRef* pTag); + static DrError ReadFromArray(UInt16 tag, + const void* data, UInt32 dataLen, + DryadMTagMetaDataRef* pTag); + DrError Serialize(DrMemoryWriter* writer); + + /* do a deep copy instead of just increasing the refcount */ + DryadMTag* Clone(); + + /* this call does not modify the returned metadata's reference + count */ + DryadMetaData* GetMetaData(); + +private: + DryadMTagMetaData(UInt16 tag, DryadMetaData* val, + bool marshalAsAggregate); + ~DryadMTagMetaData(); + + DrRef m_val; + bool m_marshalAsAggregate; +}; + +class DryadMTagVertexCommand; +typedef DrRef DryadMTagVertexCommandRef; + +class DryadMTagVertexCommand : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. */ + static DryadMTagVertexCommand* Create(UInt16 tag, DVertexCommand val); + static DrError ReadFromStream(DrMemoryReader* reader, + UInt16 tag, UInt32 dataLen, + DryadMTagVertexCommandRef* pTag); + DrError Serialize(DrMemoryWriter* writer); + + DVertexCommand GetVertexCommand(); + +private: + DryadMTagVertexCommand(UInt16 tag, DVertexCommand val); + ~DryadMTagVertexCommand(); + + DVertexCommand m_val; +}; + +class DryadMTagInputChannelDescription; +typedef DrRef + DryadMTagInputChannelDescriptionRef; + +class DryadMTagInputChannelDescription : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. 
*/ + static DryadMTagInputChannelDescription* + Create(UInt16 tag, DryadInputChannelDescription* val); + static DrError ReadFromStream(DrMemoryReader* reader, + DryadMTagInputChannelDescriptionRef* pTag); + DrError Serialize(DrMemoryWriter* writer); + + DryadInputChannelDescription* GetInputChannelDescription(); + +private: + DryadMTagInputChannelDescription(UInt16 tag, + DryadInputChannelDescription* val); + ~DryadMTagInputChannelDescription(); + + DryadInputChannelDescription* m_val; +}; + +class DryadMTagOutputChannelDescription; +typedef DrRef + DryadMTagOutputChannelDescriptionRef; + +class DryadMTagOutputChannelDescription : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. */ + static DryadMTagOutputChannelDescription* + Create(UInt16 tag, DryadOutputChannelDescription* val); + static DrError ReadFromStream(DrMemoryReader* reader, + DryadMTagOutputChannelDescriptionRef* pTag); + DrError Serialize(DrMemoryWriter* writer); + + DryadOutputChannelDescription* GetOutputChannelDescription(); + +private: + DryadMTagOutputChannelDescription(UInt16 tag, + DryadOutputChannelDescription* val); + ~DryadMTagOutputChannelDescription(); + + DryadOutputChannelDescription* m_val; +}; + +class DryadMTagVertexProcessStatus; +typedef DrRef DryadMTagVertexProcessStatusRef; + +class DryadMTagVertexProcessStatus : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. 
*/ + static DryadMTagVertexProcessStatus* Create(UInt16 tag, + DVertexProcessStatus* val); + static DrError ReadFromStream(DrMemoryReader* reader, + DryadMTagVertexProcessStatusRef* pTag); + DrError Serialize(DrMemoryWriter* writer); + + DVertexProcessStatus* GetVertexProcessStatus(); + +private: + DryadMTagVertexProcessStatus(UInt16 tag, DVertexProcessStatus* val); + ~DryadMTagVertexProcessStatus(); + + DrRef m_val; +}; + +class DryadMTagVertexStatus; +typedef DrRef DryadMTagVertexStatusRef; + +class DryadMTagVertexStatus : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. */ + static DryadMTagVertexStatus* Create(UInt16 tag, + DVertexStatus* val); + static DrError ReadFromStream(DrMemoryReader* reader, + DryadMTagVertexStatusRef* pTag); + DrError Serialize(DrMemoryWriter* writer); + + DVertexStatus* GetVertexStatus(); + +private: + DryadMTagVertexStatus(UInt16 tag, DVertexStatus* val); + ~DryadMTagVertexStatus(); + + DrRef m_val; +}; + +class DryadMTagVertexCommandBlock; +typedef DrRef DryadMTagVertexCommandBlockRef; + +class DryadMTagVertexCommandBlock : public DryadMTag +{ +public: + /* the following call creates a new tag with a single reference + owned by the caller. 
*/ + static DryadMTagVertexCommandBlock* Create(UInt16 tag, + DVertexCommandBlock* val); + static DrError ReadFromStream(DrMemoryReader* reader, + DryadMTagVertexCommandBlockRef* pTag); + DrError Serialize(DrMemoryWriter* writer); + + DVertexCommandBlock* GetVertexCommandBlock(); + +private: + DryadMTagVertexCommandBlock(UInt16 tag, DVertexCommandBlock* val); + ~DryadMTagVertexCommandBlock(); + + DrRef m_val; +}; diff --git a/DryadVertex/VertexHost/system/common/include/dryadnativeport.h b/DryadVertex/VertexHost/system/common/include/dryadnativeport.h new file mode 100644 index 0000000..3b85421 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/dryadnativeport.h @@ -0,0 +1,153 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include +#include +#include +#include + +class DryadNativePort +{ +public: + class HandlerBase : protected OVERLAPPED + { + public: + virtual ~HandlerBase(); + + OVERLAPPED* GetOverlapped(); + + private: + virtual void ProcessIO(DrError retval, UInt32 numBytes) = 0; + + friend class DryadNativePort; + }; + + class Handler : public HandlerBase + { + public: + Handler(); + virtual ~Handler(); + + PSIZE_T GetNumberOfBytesToTransferPtr(); +//JC DR_STREAM_POSITION *GetDryadPositionPtr(); + virtual void* GetData() = 0; + + protected: + void InitializeInternal(UInt32 bytesToTransfer, + UInt64 requestOffset); + + private: + SIZE_T m_bytesToTransfer; +//JC DR_STREAM_POSITION m_cosmosPosition; + }; + + DryadNativePort(DWORD numWorkerThreads, + DWORD numConcurrentThreads); + ~DryadNativePort(); + + void Start(); + + void AssociateHandle(HANDLE fileHandle); + + void QueueNativeRead(HANDLE fileHandle, Handler* request); + void QueueNativeXComputeRead(XCPROCESSFILEHANDLE fileHandle, + Handler* request, + UInt64* streamOffset, + DrError* pendingStatePtr); +/*JC void QueueDryadRead(DRHANDLE streamHandle, + DrError* pendingStatePtr, + UInt64 streamOffset, + Handler* request);*/ + + void QueueNativeWrite(HANDLE fileHandle, Handler* request); +/*JC void QueueDryadWrite(DRHANDLE streamHandle, + DrError* pendingStatePtr, + UInt64 streamOffset, + Handler* request);*/ + +/*JC void QueueDryadSetStreamProperties(const char* uri, + DrError* pendingStatePtr, + Handler* request);*/ + + void IncrementOutstandingRequests(); + void DecrementOutstandingRequests(); + + // + // Return the number of outstanding requests + // + UInt32 GetOutstandingRequests() + { + return m_outstandingRequests; + } + + // + // Return the completion port handle + // + HANDLE GetCompletionPort(); + + void Stop(); + +private: + enum BufferPortState { + BPS_Stopped, + BPS_Running, + BPS_Stopping + }; + + static unsigned __stdcall ThreadFunc(void* arg); + + static unsigned __stdcall 
WriteFileThreadBase(void* arg); + void WriteFileThread(); + + BufferPortState m_state; + + DWORD m_numWorkerThreads; + DWORD m_numConcurrentThreads; + HANDLE m_completionPort; + HANDLE* m_threadHandle; + UInt32 m_outstandingRequests; + + class WriteFileRequest + { + public: + WriteFileRequest(HANDLE h, Handler* hh) + { + m_fileHandle = h; + m_request = hh; + } + + HANDLE m_fileHandle; + Handler* m_request; + DrBListEntry m_listPtr; + }; + typedef DryadBList DryadWriteFileList; + + HANDLE m_writeFileEvent; + HANDLE m_writeFileHandle; + DryadWriteFileList m_writeFileList; + bool m_writeFileFinished; + CRITSEC m_writeFileCS; + + CRITSEC m_baseDR; +}; + + diff --git a/DryadVertex/VertexHost/system/common/include/dryadopaqueresources.h b/DryadVertex/VertexHost/system/common/include/dryadopaqueresources.h new file mode 100644 index 0000000..30d2c88 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/dryadopaqueresources.h @@ -0,0 +1,46 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include + +/* this is an opaque class that identifies a process independent of + cluster type */ +class DryadProcessIdentifier : public DrRefCounter +{ +public: + virtual ~DryadProcessIdentifier(); + + virtual DrGuid* GetGuid() = 0; + virtual const char* GetGuidString() = 0; + + virtual void MakeURIForRelativeFile(DrStr* dst, + const char* baseDirectory, + const char* relativeFileName) = 0; +}; + +/* this is an opaque class that identifies a machine independent of + cluster type */ +class DryadMachineIdentifier : public DrRefCounter +{ +public: + virtual ~DryadMachineIdentifier(); +}; diff --git a/DryadVertex/VertexHost/system/common/include/dryadproperties.h b/DryadVertex/VertexHost/system/common/include/dryadproperties.h new file mode 100644 index 0000000..ea00486 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/dryadproperties.h @@ -0,0 +1,90 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +// There must *only* be DEFINE_DRPROPERTY or DEFINE_DRYADPROPERTY +// directives in this file + +DEFINE_DRPROPERTY(Prop_Dryad_ChannelState, PROP_SHORTATOM(0x4000), DrError, "ChannelState") +DEFINE_DRPROPERTY(Prop_Dryad_ChannelURI, PROP_LONGATOM(0x4003), String, "ChannelURI") +DEFINE_DRPROPERTY(Prop_Dryad_ChannelBufferOffset, PROP_SHORTATOM(0x4004), UInt64, "ChannelBufferOffset") +DEFINE_DRPROPERTY(Prop_Dryad_ChannelTotalLength, PROP_SHORTATOM(0x4005), UInt64, "ChannelTotalLength") +DEFINE_DRPROPERTY(Prop_Dryad_ChannelProcessedLength, PROP_SHORTATOM(0x4006), UInt64, "ChannelProcessedLength") +DEFINE_DRPROPERTY(Prop_Dryad_StreamExpireTimeWhileOpen, PROP_SHORTATOM(0x4007), TimeInterval, "StreamExpireTimeWhileOpen") +DEFINE_DRPROPERTY(Prop_Dryad_StreamExpireTimeWhileClosed, PROP_SHORTATOM(0x4008), TimeInterval, "StreamExpireTimeWhileClosed") + +DEFINE_DRPROPERTY(Prop_Dryad_VertexState, PROP_SHORTATOM(0x4010), DrError, "VertexState") +DEFINE_DRPROPERTY(Prop_Dryad_VertexErrorCode, PROP_SHORTATOM(0x4011), DrError, "VertexErrorCode") +DEFINE_DRPROPERTY(Prop_Dryad_VertexId, PROP_SHORTATOM(0x4012), UInt32, "VertexId") +DEFINE_DRPROPERTY(Prop_Dryad_VertexVersion, PROP_SHORTATOM(0x4013), UInt32, "VertexVersion") +DEFINE_DRPROPERTY(Prop_Dryad_VertexInputChannelCount, PROP_SHORTATOM(0x4015), UInt32, "VertexInputChannelCount") +DEFINE_DRPROPERTY(Prop_Dryad_VertexOutputChannelCount, PROP_SHORTATOM(0x4016), UInt32, "VertexOutputChannelCount") +DEFINE_DRYADPROPERTY(Prop_Dryad_VertexCommand, PROP_SHORTATOM(0x4017), VertexCommand, "VertexCommand") +DEFINE_DRPROPERTY(Prop_Dryad_VertexArgumentCount, PROP_SHORTATOM(0x4018), UInt32, "VertexArgumentCount") +DEFINE_DRPROPERTY(Prop_Dryad_VertexArgument, PROP_LONGATOM(0x4019), String, "VertexArgument") +DEFINE_DRPROPERTY(Prop_Dryad_VertexSerializedBlock, PROP_LONGATOM(0x401a), Blob, "VertexSerializedBlock") +DEFINE_DRPROPERTY(Prop_Dryad_DebugBreak, PROP_SHORTATOM(0x401b), Boolean, "DebugBreak") 
+DEFINE_DRPROPERTY(Prop_Dryad_AssertFailure, PROP_LONGATOM(0x401c), String, "AssertFailure") +DEFINE_DRPROPERTY(Prop_Dryad_CanShareWorkQueue, PROP_SHORTATOM(0x401d), Boolean, "CanShareWorkQueue") +DEFINE_DRPROPERTY(Prop_Dryad_VertexMaxOpenInputChannelCount, PROP_SHORTATOM(0x401e), UInt32, "VertexMaxOpenInputChannelCount") +DEFINE_DRPROPERTY(Prop_Dryad_VertexMaxOpenOutputChannelCount, PROP_SHORTATOM(0x401f), UInt32, "VertexMaxOpenOutputChannelCount") + +DEFINE_DRPROPERTY(Prop_Dryad_ErrorCode, PROP_SHORTATOM(0x4040), DrError, "ErrorCode") +DEFINE_DRPROPERTY(Prop_Dryad_ErrorString, PROP_LONGATOM(0x4041), String, "ErrorString") +DEFINE_DRPROPERTY(Prop_Dryad_ItemBufferStartOffset, PROP_SHORTATOM(0x4042), UInt64, "ItemBufferStartOffset") +DEFINE_DRPROPERTY(Prop_Dryad_ItemBufferEndOffset, PROP_SHORTATOM(0x4043), UInt64, "ItemBufferEndOffset") +DEFINE_DRPROPERTY(Prop_Dryad_BufferLength, PROP_SHORTATOM(0x4044), UInt64, "BufferLength") +DEFINE_DRPROPERTY(Prop_Dryad_ItemStreamStartOffset, PROP_SHORTATOM(0x4045), UInt64, "ItemStreamStartOffset") +DEFINE_DRPROPERTY(Prop_Dryad_ItemStreamEndOffset, PROP_SHORTATOM(0x4046), UInt64, "ItemStreamEndOffset") +DEFINE_DRPROPERTY(Prop_Dryad_ItemDataSequenceNumber, PROP_SHORTATOM(0x4047), UInt64, "ItemDataSequenceNumber") +DEFINE_DRPROPERTY(Prop_Dryad_ItemDeliverySequenceNumber, PROP_SHORTATOM(0x4048), UInt64, "ItemDeliverySequenceNumber") + +DEFINE_DRPROPERTY(Prop_Dryad_InputPortCount, PROP_SHORTATOM(0x4060), UInt32, "InputPortCount") +DEFINE_DRPROPERTY(Prop_Dryad_OutputPortCount, PROP_SHORTATOM(0x4061), UInt32, "OutputPortCount") +DEFINE_DRPROPERTY(Prop_Dryad_NumberOfVertices, PROP_SHORTATOM(0x4062), UInt32, "NumberOfVertices") +DEFINE_DRPROPERTY(Prop_Dryad_SourceVertex, PROP_SHORTATOM(0x4063), UInt32, "SourceVertex") +DEFINE_DRPROPERTY(Prop_Dryad_SourcePort, PROP_SHORTATOM(0x4064), UInt32, "SourcePort") +DEFINE_DRPROPERTY(Prop_Dryad_DestinationVertex, PROP_SHORTATOM(0x4065), UInt32, "DestinationVertex") 
+DEFINE_DRPROPERTY(Prop_Dryad_DestinationPort, PROP_SHORTATOM(0x4066), UInt32, "DestinationPort") +DEFINE_DRPROPERTY(Prop_Dryad_NumberOfEdges, PROP_SHORTATOM(0x4067), UInt32, "NumberOfEdges") +DEFINE_DRYADPROPERTY(Prop_Dryad_TryToCreateChannelPath, PROP_SHORTATOM(0x4068), Void, "TryToCreateChannelPath") +DEFINE_DRPROPERTY(Prop_Dryad_InitialChannelWriteSize, PROP_SHORTATOM(0x4069), UInt64, "InitialChannelWriteSize") +// BUGBUG: this property used to be called Prop_Dryad_MachineName, but it was mistakenly declared with a SHORTATOM. That property +// BUGBUG: is now deprecated, and this one should be used instead. +DEFINE_DRPROPERTY(Prop_Dryad_LongMachineName, PROP_LONGATOM(0x406a), String, "LongMachineName" ) + +DEFINE_DRPROPERTY(Prop_Dryad_RSRootProcessIdentifier, PROP_LONGATOM(0x4070), String, "RSRootProcessIdentifier") +DEFINE_DRPROPERTY(Prop_Dryad_RSMachineName, PROP_LONGATOM(0x4071), String, "RSMachineName") +DEFINE_DRPROPERTY(Prop_Dryad_RSCPUAllowance, PROP_SHORTATOM(0x4072), UInt32, "RSCPUAllowance") +DEFINE_DRPROPERTY(Prop_Dryad_RSDiskAllowance, PROP_SHORTATOM(0x4073), UInt32, "RSDiskAllowance") +DEFINE_DRPROPERTY(Prop_Dryad_RSMemoryAllowance, PROP_SHORTATOM(0x4074), UInt64, "RSMemoryAllowance") +DEFINE_DRPROPERTY(Prop_Dryad_RSProcessGuid, PROP_SHORTATOM(0x4075), Guid, "RSProcessGuid") +DEFINE_DRPROPERTY(Prop_Dryad_RSPodName, PROP_LONGATOM(0x4076), String, "RSPodName") +DEFINE_DRPROPERTY(Prop_Dryad_RSAffinity, PROP_SHORTATOM(0x4077), UInt32, "RSAffinity") +DEFINE_DRPROPERTY(Prop_Dryad_RSFailedMachine, PROP_LONGATOM(0x4078), String, "RSFailedMachine") +DEFINE_DRPROPERTY(Prop_Dryad_RSDiscardedProcess, PROP_SHORTATOM(0x4079), Guid, "RSDiscardedProcess") +DEFINE_DRPROPERTY(Prop_Dryad_RSReturnedProcess, PROP_SHORTATOM(0x407a), Guid, "RSReturnedProcess") +DEFINE_DRPROPERTY(Prop_Dryad_RSReplacementGuid, PROP_SHORTATOM(0x407b), Guid, "RSReplacementGuid") +DEFINE_DRPROPERTY(Prop_Dryad_RSClientStarting, PROP_SHORTATOM(0x407c), Boolean, "RSClientStarting") 
+DEFINE_DRPROPERTY(Prop_Dryad_RSMachineDataSize, PROP_SHORTATOM(0x407d), UInt64, "RSMachineDataSize") +DEFINE_DRPROPERTY(Prop_Dryad_RSPodDataSize, PROP_SHORTATOM(0x407e), UInt64, "RSPodDataSize") +DEFINE_DRPROPERTY(Prop_Dryad_RSRootProcessName, PROP_LONGATOM(0x407f), String, "RSRootProcessName") +DEFINE_DRPROPERTY(Prop_Dryad_RSRootProcessMachine, PROP_LONGATOM(0x4080), String, "RSRootProcessMachine") +DEFINE_DRPROPERTY(Prop_Dryad_RSSendTime, PROP_SHORTATOM(0x4081), TimeStamp, "RSSendTime") +DEFINE_DRPROPERTY(Prop_Dryad_RSProcessingTime, PROP_SHORTATOM(0x4082), TimeInterval, "RSProcessingTime") diff --git a/DryadVertex/VertexHost/system/common/include/dryadpropertiesdef.h b/DryadVertex/VertexHost/system/common/include/dryadpropertiesdef.h new file mode 100644 index 0000000..865dfe1 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/dryadpropertiesdef.h @@ -0,0 +1,64 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +class DrPropertyDumper; + +#ifdef DECLARE_DRYADPROPERTYTYPE +#undef DECLARE_DRYADPROPERTYTYPE +#endif + +#define DECLARE_DRYADPROPERTYTYPE(type) \ + extern DrError DryadPropertyToText_##type(DrPropertyDumper *pDumper, UInt16 enumId, const char *propertyName); + +#include "dryadpropertytype.h" + +#undef DECLARE_DRYADPROPERTYTYPE + +#ifdef DEFINE_DRPROPERTY +#undef DEFINE_DRPROPERTY +#endif + +#ifdef DEFINE_DRYADPROPERTY +#undef DEFINE_DRYADPROPERTY +#endif + +#define DEFINE_DRPROPERTY(var, value, type, propertyName) \ + static const UInt16 var = value; + +#define DEFINE_DRYADPROPERTY(var, value, type, propertyName) \ + static const UInt16 var = value; + +#include "dryadproperties.h" + +#undef DEFINE_DRPROPERTY +#undef DEFINE_DRYADPROPERTY + + + +// Options for VertexCommand in a DVertexCommand message +enum DVertexCommand { + DVertexCommand_Start = 0, + DVertexCommand_ReOpenChannels, + DVertexCommand_Terminate, + DVertexCommand_Max +}; +extern const char* g_dVertexCommandText[DVertexCommand_Max]; diff --git a/DryadVertex/VertexHost/system/common/include/dryadpropertydumper.h b/DryadVertex/VertexHost/system/common/include/dryadpropertydumper.h new file mode 100644 index 0000000..ec69e5f --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/dryadpropertydumper.h @@ -0,0 +1,26 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +extern void DryadInitPropertyTable(); +extern void DryadInitTagTable(); +extern void DryadInitErrorTable(); + diff --git a/DryadVertex/VertexHost/system/common/include/dryadpropertytype.h b/DryadVertex/VertexHost/system/common/include/dryadpropertytype.h new file mode 100644 index 0000000..896b450 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/dryadpropertytype.h @@ -0,0 +1,25 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// This file must consist only of DECLARE_DRYADPROPERTYTYPE statements + +DECLARE_DRYADPROPERTYTYPE(Void) + +DECLARE_DRYADPROPERTYTYPE(VertexCommand) diff --git a/DryadVertex/VertexHost/system/common/include/dryadstandaloneini.h b/DryadVertex/VertexHost/system/common/include/dryadstandaloneini.h new file mode 100644 index 0000000..cc18cb7 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/dryadstandaloneini.h @@ -0,0 +1,40 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include +//JC #include +//JC #include +//JC #include +#include + +DrError DryadInitializeXCompute(const char* netLibName, const char* iniFileName, + int argc, char* argv[], int* pNOpts); +DrError DryadShutdownXCompute(); + +XDRESSIONHANDLE GetSessionHandle(); +XCPROCESSHANDLE GetProcessHandle(); + +void DryadInitialize(); + +class DryadNativePort; + +extern DryadNativePort* g_dryadNativePort; diff --git a/DryadVertex/VertexHost/system/common/include/dryadtags.h b/DryadVertex/VertexHost/system/common/include/dryadtags.h new file mode 100644 index 0000000..fc341de --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/dryadtags.h @@ -0,0 +1,52 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +// Tags used for BeginTag/EndTag + +// There must *only* be DEFINE_DRYADTAG directives in this file + +DEFINE_DRYADTAG(DryadTag_InputChannelDescription, 10000, "InputChannelDescription", InputChannelDescription) +DEFINE_DRYADTAG(DryadTag_OutputChannelDescription, 10001, "OutputChannelDescription", OutputChannelDescription) +DEFINE_DRYADTAG(DryadTag_VertexProcessStatus, 10002, "VertexProcessStatusBlock", VertexProcessStatus) +DEFINE_DRYADTAG(DryadTag_VertexStatus, 10003, "VertexStatusBlock", VertexStatus) +DEFINE_DRYADTAG(DryadTag_VertexCommand, 10004, "VertexCommandBlock", VertexCommandBlock) +DEFINE_DRYADTAG(DryadTag_ItemStart, 10005, "ItemStartBlock", MetaData) +DEFINE_DRYADTAG(DryadTag_ItemEnd, 10006, "ItemEndBlock", MetaData) +DEFINE_DRYADTAG(DryadTag_ChannelMetaData, 10007, "ChannelMetaData", MetaData) +DEFINE_DRYADTAG(DryadTag_VertexMetaData, 10008, "VertexMetaData", MetaData) +DEFINE_DRYADTAG(DryadTag_ArgumentArray, 10009, "ArgumentArray", MetaData) +DEFINE_DRYADTAG(DryadTag_VertexArray, 10010, "VertexArray", MetaData) +DEFINE_DRYADTAG(DryadTag_VertexInfo, 10011, "VertexInfo", MetaData) +DEFINE_DRYADTAG(DryadTag_EdgeArray, 10012, "EdgeArray", MetaData) +DEFINE_DRYADTAG(DryadTag_EdgeInfo, 10013, "EdgeInfo", MetaData) +DEFINE_DRYADTAG(DryadTag_GraphDescription, 10014, "GraphDescription", MetaData) +DEFINE_DRYADTAG(DryadTag_RSCAReturnMachine, 10015, "ReturnMachine", MetaData) +DEFINE_DRYADTAG(DryadTag_RSCAEnqueueProcess, 10016, "EnqueueProcess", MetaData) +DEFINE_DRYADTAG(DryadTag_RSCAReportFailedMachine, 10017, "ReportFailedMachine", MetaData) +DEFINE_DRYADTAG(DryadTag_RSCADiscardProcess, 10018, "DiscardProcess", MetaData) +DEFINE_DRYADTAG(DryadTag_RSClientResponse, 10019, "RSClientResponse", MetaData) +DEFINE_DRYADTAG(DryadTag_RSClientRootProcessRequest, 10020, "RSClientRootProcessRequest", MetaData) +DEFINE_DRYADTAG(DryadTag_RSClientRootProcessResponse, 10021, "RSClientRootProcessResponse", MetaData) 
+DEFINE_DRYADTAG(DryadTag_RSClientInitializeRequest, 10022, "RSClientInitializeRequest", MetaData) +DEFINE_DRYADTAG(DryadTag_RSClientActionRequest, 10023, "RSClientActionRequest", MetaData) +DEFINE_DRYADTAG(DryadTag_RSClientStatusRequest, 10024, "RSClientStatusRequest", MetaData) +DEFINE_DRYADTAG(DryadTag_RSClientMatch, 10025, "RSClientMatch", MetaData) +DEFINE_DRYADTAG(DryadTag_RSRevocation, 10026, "RSRevocation", MetaData) +DEFINE_DRYADTAG(DryadTag_RSClientCommand, 10027, "RSClientCommand", MetaData) diff --git a/DryadVertex/VertexHost/system/common/include/dryadtagsdef.h b/DryadVertex/VertexHost/system/common/include/dryadtagsdef.h new file mode 100644 index 0000000..4bfa65d --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/dryadtagsdef.h @@ -0,0 +1,32 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#ifdef DEFINE_DRYADTAG +#undef DEFINE_DRYADTAG +#endif + +#define DEFINE_DRYADTAG(var, value, tagName, tagType) \ + static const UInt16 var = value; + +#include "dryadtags.h" + +#undef DEFINE_DRYADTAG diff --git a/DryadVertex/VertexHost/system/common/include/dryadxcomputeresources.h b/DryadVertex/VertexHost/system/common/include/dryadxcomputeresources.h new file mode 100644 index 0000000..7fc569c --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/dryadxcomputeresources.h @@ -0,0 +1,55 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include +#include + +class DryadXComputeProcessIdentifier : public DryadProcessIdentifier +{ +public: + DryadXComputeProcessIdentifier(XCPROCESSHANDLE handle); + + DrGuid* GetGuid(); + const char* GetGuidString(); + void MakeURIForRelativeFile(DrStr* dst, + const char* baseDirectory, + const char* relativeFileName); + + XCPROCESSHANDLE GetHandle(); + void ReCacheGUID(); + +private: + XCPROCESSHANDLE m_handle; + DrInitializedGuid m_guid; + DrStr32 m_guidString; +}; + +class DryadXComputeMachineIdentifier : public DryadMachineIdentifier +{ +public: + DryadXComputeMachineIdentifier(XCPROCESSNODEID node); + + XCPROCESSNODEID GetNodeID(); + +private: + XCPROCESSNODEID m_node; +}; diff --git a/DryadVertex/VertexHost/system/common/include/dvertexcommand.h b/DryadVertex/VertexHost/system/common/include/dvertexcommand.h new file mode 100644 index 0000000..512f62b --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/dvertexcommand.h @@ -0,0 +1,227 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include "DrCommon.h" +#include + +/* these are wrapper classes to access properties that abstracts the + different mechanisms for getting and setting properties in xcompute + and potentially other clusters */ +class DryadPnProcessPropertyRequest : public DrRefCounter +{ +public: + virtual ~DryadPnProcessPropertyRequest(); + + virtual void SetPropertyLabel(const char* label, + const char* controlLabel) = 0; + virtual void SetPropertyString(const char* string) = 0; + virtual DrMemoryBuffer* GetPropertyBlock() = 0; +}; + +class DryadPnProcessPropertyResponse : public DrRefCounter +{ +public: + virtual ~DryadPnProcessPropertyResponse(); + + virtual void RetrievePropertyLabel(const char* label) = 0; + virtual DrMemoryBuffer* GetPropertyBlock() = 0; +}; + +class DryadChannelDescription : public DrPropertyParser +{ +public: + DryadChannelDescription(bool isInputChannel); + virtual ~DryadChannelDescription(); + + DrError GetChannelState() const; + void SetChannelState(DrError state); + + const char* GetChannelURI() const; + void SetChannelURI(const char* channelURI); + + DryadMetaData* GetChannelMetaData() const; + void SetChannelMetaData(DryadMetaData* metaData); + + UInt64 GetChannelTotalLength() const; + void SetChannelTotalLength(UInt64 totalLength); + + UInt64 GetChannelProcessedLength() const; + void SetChannelProcessedLength(UInt64 processedLength); + + DrError Serialize(DrMemoryWriter* writer); + DrError OnParseProperty(DrMemoryReader *reader, UInt16 enumID, + UInt32 dataLen, void *cookie); + + void CopyFrom(DryadChannelDescription* src, bool includeLengths); + +private: + DrError m_state; + DrStr64 m_URI; + DryadMetaDataRef m_metaData; + UInt64 m_totalLength; + UInt64 m_processedLength; + bool m_isInputChannel; +}; + +class DryadInputChannelDescription : public DryadChannelDescription +{ +public: + DryadInputChannelDescription(); +}; + +class DryadOutputChannelDescription : public DryadChannelDescription +{ +public: + 
DryadOutputChannelDescription(); +}; + +class DVertexProcessStatus : public DrPropertyParser, public DrRefCounter +{ +public: + DVertexProcessStatus(); + ~DVertexProcessStatus(); + + UInt32 GetVertexId(); + void SetVertexId(UInt32 vertexId); + + UInt32 GetVertexInstanceVersion(); + void SetVertexInstanceVersion(UInt32 instanceVersion); + + DryadMetaData* GetVertexMetaData(); + void SetVertexMetaData(DryadMetaData* metaData); + + UInt32 GetInputChannelCount(); + void SetInputChannelCount(UInt32 channelCount); + + UInt32 GetMaxOpenInputChannelCount(); + void SetMaxOpenInputChannelCount(UInt32 channelCount); + + DryadInputChannelDescription* GetInputChannels(); + + UInt32 GetOutputChannelCount(); + void SetOutputChannelCount(UInt32 channelCount); + + UInt32 GetMaxOpenOutputChannelCount(); + void SetMaxOpenOutputChannelCount(UInt32 channelCount); + + DryadOutputChannelDescription* GetOutputChannels(); + + bool GetCanShareWorkQueue(); + void SetCanShareWorkQueue(bool canShareWorkQueue); + + DrError Serialize(DrMemoryWriter* writer); + DrError OnParseProperty(DrMemoryReader *reader, UInt16 enumID, + UInt32 dataLen, void *cookie); + + void CopyFrom(DVertexProcessStatus* src, bool includeLengths); + +private: + UInt32 m_id; + UInt32 m_version; + DryadMetaDataRef m_metaData; + UInt32 m_nInputChannels; + UInt32 m_maxInputChannels; + DryadInputChannelDescription* m_inputChannel; + UInt32 m_nOutputChannels; + UInt32 m_maxOutputChannels; + DryadOutputChannelDescription* m_outputChannel; + bool m_canShareWorkQueue; + + UInt32 m_nextInputChannelToRead; + UInt32 m_nextOutputChannelToRead; +}; + + +class DVertexStatus : public DrPropertyParser, public DrRefCounter +{ +public: + DVertexStatus(); + + DrError GetVertexState(); + void SetVertexState(DrError state); + + DVertexProcessStatus* GetProcessStatus(); + void SetProcessStatus(DVertexProcessStatus* status); + + DrError Serialize(DrMemoryWriter* writer); + DrError OnParseProperty(DrMemoryReader *reader, UInt16 enumID, + UInt32 
dataLen, void *cookie); + + void StoreInRequestMessage(DryadPnProcessPropertyRequest* request); + DrError ReadFromResponseMessage(DryadPnProcessPropertyResponse* response, + UInt32 vertexId, UInt32 vertexVersion); + + static void GetPnPropertyLabel(DrStr* pDstString, + UInt32 vertexId, UInt32 vertexVersion, + bool notifyWaiters); + +private: + DrError m_state; + DrRef m_processStatus; +}; + + +class DVertexCommandBlock : public DrPropertyParser, public DrRefCounter +{ +public: + DVertexCommandBlock(); + ~DVertexCommandBlock(); + + DVertexCommand GetVertexCommand(); + void SetVertexCommand(DVertexCommand command); + + DVertexProcessStatus* GetProcessStatus(); + void SetProcessStatus(DVertexProcessStatus* status); + + UInt32 GetArgumentCount(); + void SetArgumentCount(UInt32 nArguments); + DrStr64* GetArgumentVector(); + void SetArgument(UInt32 argumentIndex, const char* argument); + + void* GetRawSerializedBlock(); + UInt32 GetRawSerializedBlockLength(); + void SetRawSerializedBlock(UInt32 length, const void* data); + + void SetDebugBreak(bool setBreakpointOnCommandArrival); + bool GetDebugBreak(); + + DrError Serialize(DrMemoryWriter* writer); + DrError OnParseProperty(DrMemoryReader *reader, UInt16 enumID, + UInt32 dataLen, void *cookie); + + void StoreInRequestMessage(DryadPnProcessPropertyRequest* request); + DrError ReadFromResponseMessage(DryadPnProcessPropertyResponse* response, + UInt32 vertexId, UInt32 vertexVersion); + + static void GetPnPropertyLabel(DrStr* pDstString, + UInt32 vertexId, UInt32 vertexVersion); + +private: + DVertexCommand m_command; + DrRef m_processStatus; + UInt32 m_nArguments; + DrStr64* m_argument; + UInt32 m_serializedBlockLength; + char* m_serializedBlock; + bool m_setBreakpointOnCommandArrival; + UInt32 m_nextArgumentToRead; +}; diff --git a/DryadVertex/VertexHost/system/common/include/errorreporter.h b/DryadVertex/VertexHost/system/common/include/errorreporter.h new file mode 100644 index 0000000..5c609c7 --- /dev/null +++ 
b/DryadVertex/VertexHost/system/common/include/errorreporter.h @@ -0,0 +1,49 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include + +class DVErrorReporter +{ +public: + DVErrorReporter(); + + bool NoError(); + DrError GetErrorCode(); + DryadMetaData* GetErrorMetaData(); + + void ReportError(DrError errorStatus); + void ReportError(const char* errorFormat, ...); + void ReportError(DrError errorStatus, + const char* errorFormat, ...); + void ReportError(DrError errorStatus, DryadMetaData* metaData); + + void InterruptProcessing(); + +private: + void ReportFormattedErrorInternal(DrError errorStatus, + const char *formatString, + va_list args); + + DrError m_errorCode; + DryadMetaDataRef m_metaData; +}; diff --git a/DryadVertex/VertexHost/system/common/include/orderedsendlatch.h b/DryadVertex/VertexHost/system/common/include/orderedsendlatch.h new file mode 100644 index 0000000..bfe72d1 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/orderedsendlatch.h @@ -0,0 +1,150 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once +#pragma warning(disable:4512) // KLUDGE -- build for now, fix later. + +#include + +template< class _T > class DryadOrderedSendLatch +{ +public: + typedef _T ListType; + + DryadOrderedSendLatch() + { + m_sendState = SS_Empty; + m_event = ::CreateEvent(NULL, TRUE, FALSE, NULL); + LogAssert(m_event != NULL); + } + + ~DryadOrderedSendLatch() + { + LogAssert(m_sendState == SS_Empty); + LogAssert(m_pendingList.IsEmpty()); + BOOL bRet = ::CloseHandle(m_event); + LogAssert(bRet != 0); + } + + + void Start() + { + LogAssert(m_sendState == SS_Empty); + LogAssert(m_pendingList.IsEmpty()); + } + + void Stop() + { + LogAssert(m_sendState == SS_Empty); + LogAssert(m_pendingList.IsEmpty()); + } + + // + // If the send latch is sending or blocked, add list to pending + // + void AcceptList(ListType* src) + { + if (src->IsEmpty() == false) + { + if (m_sendState == SS_Empty) + { + // + // If send state says latch is empty, verify nothing in the pending list + // and set state to sending + // todo: shouldn't the m_pendingList append the src parameter so that there's something to send? 
+ // + LogAssert(m_pendingList.IsEmpty()); + m_sendState = SS_Sending; + } + else + { + // + // If sending or blocking, add list to pending + // + m_pendingList.TransitionToTail(src); + } + } + } + + // + // If there is a pending list, move it onto the tail of dst; + // otherwise mark the latch empty, signaling the event first if Interrupt() left it blocking + // + void TransferList(ListType* dst) + { + if (m_pendingList.IsEmpty() == false) + { + LogAssert(m_sendState != SS_Empty); + dst->TransitionToTail(&m_pendingList); + } + else + { + // todo: why isn't dst used in this case? + if (m_sendState == SS_Blocking) + { + BOOL bRet = ::SetEvent(m_event); + LogAssert(bRet != 0); + } + m_sendState = SS_Empty; + } + } + + // + // Block any further sends. Returns true if a send is in progress and the caller needs to wait (see Wait), false otherwise. + // + bool Interrupt() + { + bool mustWait = false; + + // + // If currently sending, set to blocking and return true + // + if (m_sendState == SS_Sending) + { + BOOL bRet = ::ResetEvent(m_event); + LogAssert(bRet != 0); + m_sendState = SS_Blocking; + mustWait = true; + } + + return mustWait; + } + + // + // Block until m_event is signaled (TransferList signals it when leaving the blocking state) + // + void Wait() + { + DWORD dRet = ::WaitForSingleObject(m_event, INFINITE); + LogAssert(dRet == WAIT_OBJECT_0); + } + +private: + enum SendState { + SS_Empty, + SS_Sending, + SS_Blocking + }; + + SendState m_sendState; + ListType m_pendingList; + HANDLE m_event; +}; + diff --git a/DryadVertex/VertexHost/system/common/include/portmemorybuffers.h b/DryadVertex/VertexHost/system/common/include/portmemorybuffers.h new file mode 100644 index 0000000..7cd92dd --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/portmemorybuffers.h @@ -0,0 +1,97 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include + +class DryadFixedMemoryBuffer : public DrFixedMemoryBuffer +{ +public: + DryadFixedMemoryBuffer(); + DryadFixedMemoryBuffer(BYTE *pData, + Size_t allocatedSize, Size_t availableSize = 0); + virtual ~DryadFixedMemoryBuffer(); + +private: + DrBListEntry m_listPtr; + friend class DryadBList; +}; + +typedef DryadBList DryadFixedBufferList; + +class DryadFixedMemoryBufferCopy : public DryadFixedMemoryBuffer +{ +public: + DryadFixedMemoryBufferCopy(DryadFixedMemoryBuffer* src); + ~DryadFixedMemoryBufferCopy(); + +private: + BYTE* m_dataCopy; +}; + +class DryadLockedMemoryBuffer : public DrFixedMemoryBuffer +{ +public: + DryadLockedMemoryBuffer(); + DryadLockedMemoryBuffer(BYTE *pData, Size_t allocatedSize); + virtual ~DryadLockedMemoryBuffer(); + + void Init(BYTE *pData, Size_t allocatedSize); + + void SetAvailableSize(Size_t uSize); + +private: + DrBListEntry m_listPtr; + friend class DryadBList; +}; + +typedef DryadBList DryadLockedBufferList; + +class DryadAlignedReadBlock : public DryadLockedMemoryBuffer +{ +public: + DryadAlignedReadBlock(size_t size, size_t alignment); + ~DryadAlignedReadBlock(); + + void Trim(Size_t numBytes); + + void* GetData(); + +private: + void* m_data; + void* m_alignedData; + size_t m_alignment; +}; + +class DryadAlignedWriteBlock : public DryadFixedMemoryBuffer +{ +public: + DryadAlignedWriteBlock(size_t size, size_t alignment); + ~DryadAlignedWriteBlock(); + + void* GetData(); + +private: + void* m_data; + void* m_alignedData; + size_t 
m_alignment; +}; + diff --git a/DryadVertex/VertexHost/system/common/include/workqueue.h b/DryadVertex/VertexHost/system/common/include/workqueue.h new file mode 100644 index 0000000..d1b8ddc --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/workqueue.h @@ -0,0 +1,73 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include + +class WorkQueue; + +class WorkRequest +{ +public: + virtual ~WorkRequest(); + virtual void Process() = 0; + virtual bool ShouldAbort() = 0; + +private: + DrBListEntry m_listPtr; + friend class DryadBList; +}; + +typedef DryadBList WorkRequestList; + +class WorkQueue { + public: + WorkQueue(DWORD numWorkerThreads, + DWORD numConcurrentThreads); + ~WorkQueue(); + + void Start(); + bool EnQueue(WorkRequest* request); + + void Clean(); + void Stop(); + +private: + enum WorkQueueState { + WQS_Stopped, + WQS_Running, + WQS_Stopping + }; + + static unsigned __stdcall ThreadFunc(void* arg); + + WorkRequestList m_list; /* list of WorkRequest items */ + + WorkQueueState m_state; + + DWORD m_numWorkerThreads; + DWORD m_numConcurrentThreads; + HANDLE m_completionPort; + HANDLE* m_threadHandle; + DWORD m_numQueuedWakeUps; + + CRITSEC m_baseDR; +}; diff --git a/DryadVertex/VertexHost/system/common/include/xcomputepropertyblock.h b/DryadVertex/VertexHost/system/common/include/xcomputepropertyblock.h new file 
mode 100644 index 0000000..289be11 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/xcomputepropertyblock.h @@ -0,0 +1,57 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include +#include + +class DryadXComputePnProcessPropertyRequest : + public DryadPnProcessPropertyRequest +{ +public: + DryadXComputePnProcessPropertyRequest(); + + void SetPropertyLabel(const char* label, const char* controlLabel); + void SetPropertyString(const char* string); + DrMemoryBuffer* GetPropertyBlock(); + +protected: + DrStr64 m_label; + DrStr64 m_controlLabel; + DrStr64 m_string; + DrRef m_block; +}; + +class DryadXComputePnProcessPropertyResponse : + public DryadPnProcessPropertyResponse +{ +public: + DryadXComputePnProcessPropertyResponse(PXC_PROCESS_INFO response); + + void RetrievePropertyLabel(const char* label); + DrMemoryBuffer* GetPropertyBlock(); + +private: + PXC_PROCESS_INFO m_processInfo; + PXC_PROCESSPROPERTY_INFO m_propertyInfo; + DrRef m_block; +}; + diff --git a/DryadVertex/VertexHost/system/common/include/yarnpropertyblock.h b/DryadVertex/VertexHost/system/common/include/yarnpropertyblock.h new file mode 100644 index 0000000..8f69ecd --- /dev/null +++ b/DryadVertex/VertexHost/system/common/include/yarnpropertyblock.h @@ -0,0 +1,53 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include + +class DryadYarnPnProcessPropertyRequest : public DryadPnProcessPropertyRequest +{ +public: + DryadYarnPnProcessPropertyRequest(); + + void SetPropertyLabel(const char* label, const char* controlLabel); + void SetPropertyString(const char* string); + DrMemoryBuffer* GetPropertyBlock(); + +protected: + DrStr64 m_label; + DrStr64 m_controlLabel; + DrStr64 m_string; + DrRef m_block; +}; + +class DryadYarnPnProcessPropertyResponse : public DryadPnProcessPropertyResponse +{ +public: + DryadYarnPnProcessPropertyResponse(); + + void RetrievePropertyLabel(const char* label); + DrMemoryBuffer* GetPropertyBlock(); + +private: + + DrRef m_block; +}; + diff --git a/DryadVertex/VertexHost/system/common/src/DObjPool.cpp b/DryadVertex/VertexHost/system/common/src/DObjPool.cpp new file mode 100644 index 0000000..5a721ea --- /dev/null +++ b/DryadVertex/VertexHost/system/common/src/DObjPool.cpp @@ -0,0 +1,444 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "DObjPool.h" + +DObjPoolThreadPrivateBlock::Entry::Entry() +{ + m_pool = NULL; + m_key = 0; + m_cache = NULL; +} + +DObjPoolThreadPrivateBlock::Entry::~Entry() +{ + if (m_cache != NULL) + { + LogAssert(m_cache->Abandoned()); + delete m_cache; + } +} + +void DObjPoolThreadPrivateBlock::Entry::Initialize(DObjPoolBase* pool, + LONGLONG poolKey, + DObjPoolCache* cache) +{ + LogAssert(pool != NULL); + LogAssert(cache != NULL); + + m_pool = pool; + m_key = poolKey; + m_cache = cache; +} + +DObjPoolThreadPrivateBlock::DObjPoolThreadPrivateBlock() +{ + m_entryArraySize = 16; + m_entry = new Entry[m_entryArraySize]; + m_numberOfEntries = 0; +} + +DObjPoolThreadPrivateBlock::~DObjPoolThreadPrivateBlock() +{ + delete [] m_entry; +} + +DObjPoolCache* DObjPoolThreadPrivateBlock::LookUpPoolCache(DObjPoolBase* pool, + LONGLONG key) +{ + UInt32 i; + for (i=0; im_pool == pool && e->m_key == key) + { + return e->m_cache; + } + } + + return NULL; +} + +void DObjPoolThreadPrivateBlock::AddPoolCache(DObjPoolBase* pool, + LONGLONG poolKey, + DObjPoolCache* cache) +{ + LogAssert(LookUpPoolCache(pool, poolKey) == NULL); + + if (m_numberOfEntries == m_entryArraySize) + { + m_entryArraySize *= 2; + DrLogI( "Growing pool cache array. 
Was %u entries, now %u", m_numberOfEntries, m_entryArraySize); + Entry* newArray = new Entry[m_entryArraySize]; + LogAssert(newArray != NULL); + ::memcpy(newArray, m_entry, m_numberOfEntries*sizeof(m_entry[0])); + /* ownership of the caches has moved to newArray; zero the old slots so + the ~Entry calls made by delete [] leave the live caches alone */ + ::memset(m_entry, 0, m_numberOfEntries*sizeof(m_entry[0])); + delete [] m_entry; + m_entry = newArray; + } + + LogAssert(m_numberOfEntries < m_entryArraySize); + + DrLogI( "Inserting pool cache entry. Entry %u pool %p key %I64d", m_numberOfEntries, pool, poolKey); + + m_entry[m_numberOfEntries].Initialize(pool, poolKey, cache); + ++m_numberOfEntries; +} + +void DObjPoolThreadPrivateBlock::GarbageCollectCaches() +{ + DrLogI( "Inspecting pool cache for stale entries"); + + UInt32 newTotal = 0; + UInt32 i; + for (i=0; iAbandoned()) + { + DrLogI( "Discarding pool cache entry. Entry %u pool %p key %I64d", + i, m_entry[i].m_pool, m_entry[i].m_key); + delete m_entry[i].m_cache; + } + else + { + m_entry[newTotal].Initialize(m_entry[i].m_pool, + m_entry[i].m_key, + m_entry[i].m_cache); + ++newTotal; + } + } + + LogAssert(newTotal <= m_numberOfEntries); + + DrLogI( "After pool cache collection. 
Old size %u new size %u", m_numberOfEntries, newTotal); + + m_numberOfEntries = newTotal; +} + + +DObjPoolCache::DObjPoolCache(UInt32 maxEntries, + UInt32 keepEntryCount, + DObjPoolBase* pool) +{ + LogAssert(keepEntryCount < maxEntries); + m_abandoned = false; + m_pool = pool; + m_maxEntries = maxEntries; + m_keepEntryCount = keepEntryCount; + m_array = new void* [m_maxEntries]; + m_numberOfEntries = 0; +} + +DObjPoolCache::~DObjPoolCache() +{ + LogAssert(m_abandoned == true); + delete [] m_array; +} + +bool DObjPoolCache::Abandoned() +{ + return m_abandoned; +} + +void DObjPoolCache::InsertObject(void* e) +{ + LogAssert(m_abandoned == false); + + if (m_numberOfEntries == m_maxEntries) + { + ReturnToPool(false); + } + + LogAssert(m_numberOfEntries < m_maxEntries); + m_array[m_numberOfEntries] = e; + ++m_numberOfEntries; +} + +void* DObjPoolCache::RemoveObject() +{ + LogAssert(m_abandoned == false); + + if (m_numberOfEntries == 0) + { + m_pool->RemoveObjects(m_array, m_keepEntryCount); + m_numberOfEntries = m_keepEntryCount; + } + + LogAssert(m_numberOfEntries > 0); + --m_numberOfEntries; + return m_array[m_numberOfEntries]; +} + +void DObjPoolCache::ReturnToPool(bool finalCleanup) +{ + UInt32 keepCount; + if (m_keepEntryCount == 0 || finalCleanup) + { + keepCount = 0; + } + else + { + /* we're about to insert something, so get rid of one + extra */ + keepCount = m_keepEntryCount - 1; + } + + m_pool->AcceptObjects(m_array + keepCount, + m_numberOfEntries - keepCount); + m_numberOfEntries = keepCount; + + if (finalCleanup) + { + /* make sure we don't reference any member variables after + setting m_abandoned to true, since we may be spontaneously + deleted by another thread at any time after this action */ + m_abandoned = true; + } +} + + +static CRITSEC* s_refPoolGlobalCritSec; +static DrTlsPtr* t_privateCacheBlock; + +class DObjPoolCritSecInitializer +{ +public: + DObjPoolCritSecInitializer(CRITSEC** critSec); +}; + 
+DObjPoolCritSecInitializer::DObjPoolCritSecInitializer(CRITSEC** critSec) +{ + *critSec = new CRITSEC(); +} + +/* make sure s_refPoolGlobalCritSec is initialized by the time all + static constructors have run */ +static DObjPoolCritSecInitializer s_initCritSec(&s_refPoolGlobalCritSec); + +DObjPoolBase::DObjPoolBase(DObjFactoryBase* factory, + UInt32 maxCentralEntries, + UInt32 maxLocalEntries, + UInt32 localKeepEntryCount) +{ + /* make a UID to distinguish us from any other pool with the same + heap address (e.g. if memory gets recycled). We don't have to + worry about a race here, since another pool being created at + exactly the same time on another processor must have a + different address. */ + LARGE_INTEGER keyLI; + ::QueryPerformanceCounter(&keyLI); + m_key = keyLI.QuadPart; + + LogAssert(factory != NULL); + m_factory = factory; + m_maxCentralEntries = maxCentralEntries; + m_maxLocalEntries = maxLocalEntries; + m_localKeepEntryCount = localKeepEntryCount; + + m_array = new void* [m_maxCentralEntries]; + m_numberOfCentralEntries = 0; + + m_cacheArraySize = 32; + m_cache = new DObjPoolCache* [m_cacheArraySize]; + m_numberOfCaches = 0; + + m_totalGivenOut = 0; + m_totalReturned = 0; + m_totalAllocated = 0; + m_totalFreed = 0; + + /* make sure we have exactly one TLS entry for all pools */ + if (s_refPoolGlobalCritSec == NULL) + { + /* a static constructor is calling, so we don't need to worry + about thread safety */ + if (t_privateCacheBlock == NULL) + { + t_privateCacheBlock = new DrTlsPtr; + } + } + else + { + AutoCriticalSection acs(s_refPoolGlobalCritSec); + + if (t_privateCacheBlock == NULL) + { + t_privateCacheBlock = new DrTlsPtr; + } + } +} + +DObjPoolBase::~DObjPoolBase() +{ + UInt32 i; + for (i=0; iReturnToPool(true); + /* we must not delete m_cache[i] here since it is still + referenced in the private block of its thread. 
That thread + will delete the cache later during a garbage collection */ + } + delete [] m_cache; + + LogAssert(m_totalGivenOut == m_totalReturned); + LogAssert(m_numberOfCentralEntries + m_totalFreed == m_totalAllocated); + + for (i=0; iFreeObjectUntyped(m_array[i]); + } + delete [] m_array; +} + +DObjFactoryBase* DObjPoolBase::DetachFactory() +{ + LogAssert(m_factory != NULL); + return m_factory.Detach(); +} + +void DObjPoolBase::AcceptObjects(void** src, UInt32 count) +{ + UInt32 i; + { + AutoCriticalSection acs(&m_atomic); + + UInt32 freeSpace = m_maxCentralEntries - m_numberOfCentralEntries; + if (freeSpace > count) + { + freeSpace = count; + } + + for (i=0; iFreeObjectUntyped(src[i]); + } +} + +void DObjPoolBase::RemoveObjects(void** dst, UInt32 count) +{ + UInt32 i; + { + AutoCriticalSection acs(&m_atomic); + + UInt32 existing = m_numberOfCentralEntries; + if (existing > count) + { + existing = count; + } + m_numberOfCentralEntries -= existing; + + for (i=0; iAllocateObjectUntyped(); + } +} + +DObjPoolCache* DObjPoolBase::MakeCache() +{ + AutoCriticalSection acs(&m_atomic); + + if (m_numberOfCaches == m_cacheArraySize) + { + m_cacheArraySize *= 2; + DObjPoolCache** newArray = new DObjPoolCache* [m_cacheArraySize]; + LogAssert(newArray != NULL); + ::memcpy(newArray, m_cache, m_numberOfCaches * sizeof(m_cache[0])); + delete [] m_cache; /* allocated with new [], so must be released with delete [] */ + m_cache = newArray; + } + + LogAssert(m_numberOfCaches < m_cacheArraySize); + + DObjPoolCache* cache = new DObjPoolCache(m_maxLocalEntries, + m_localKeepEntryCount, + this); + m_cache[m_numberOfCaches] = cache; + ++m_numberOfCaches; + + return cache; +} + +DObjPoolCache* DObjPoolBase::FetchPrivateCache() +{ + DObjPoolThreadPrivateBlock* privateCacheBlock = *t_privateCacheBlock; + + if (privateCacheBlock == NULL) + { + /* this is the first time any pool has been referenced on this + thread, so make a new empty private cache block. This will + never be garbage-collected. 
*/ + privateCacheBlock = new DObjPoolThreadPrivateBlock(); + LogAssert(privateCacheBlock != NULL); + *t_privateCacheBlock = privateCacheBlock; + } + + DObjPoolCache* cache = privateCacheBlock->LookUpPoolCache(this, m_key); + if (cache == NULL) + { + /* this is the first time this particular pool has been + referenced on this thread, so make a new empty entry + cache. This will be garbage-collected by the + privateCacheBlock some time after the pool is freed. */ + cache = MakeCache(); + privateCacheBlock->AddPoolCache(this, m_key, cache); + } + + return cache; +} + +void* DObjPoolBase::AllocateObjectUntyped() +{ + return FetchPrivateCache()->RemoveObject(); +} + +void DObjPoolBase::FreeObjectUntyped(void* item) +{ + FetchPrivateCache()->InsertObject(item); +} diff --git a/DryadVertex/VertexHost/system/common/src/dryadeventcache.cpp b/DryadVertex/VertexHost/system/common/src/dryadeventcache.cpp new file mode 100644 index 0000000..758a080 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/src/dryadeventcache.cpp @@ -0,0 +1,81 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include "dryadeventcache.h" + + +#pragma unmanaged + +DryadHandleListEntry::DryadHandleListEntry(HANDLE handle) +{ + m_handle = handle; +} + +HANDLE DryadHandleListEntry::GetHandle() +{ + return m_handle; +} + +DryadEventCache::DryadEventCache() +{ +} + +DryadEventCache::~DryadEventCache() +{ + BOOL bRet; + DrBListEntry* listEntry = m_eventCache.GetHead(); + while (listEntry != NULL) + { + DryadHandleListEntry* h = m_eventCache.CastOut(listEntry); + listEntry = m_eventCache.GetNext(listEntry); + m_eventCache.Remove(m_eventCache.CastIn(h)); + bRet = ::CloseHandle(h->GetHandle()); + LogAssert(bRet != 0); + delete h; + } +} + +DryadHandleListEntry* DryadEventCache::GetEvent(bool reset) +{ + DryadHandleListEntry* event; + + if (m_eventCache.IsEmpty()) + { + HANDLE h = ::CreateEvent(NULL, TRUE, FALSE, NULL); + LogAssert(h != NULL); + event = new DryadHandleListEntry(h); + } + else + { + event = m_eventCache.CastOut(m_eventCache.RemoveHead()); + if (reset) + { + BOOL bRet = ::ResetEvent(event->GetHandle()); + LogAssert(bRet != 0); + } + } + + return event; +} + +void DryadEventCache::ReturnEvent(DryadHandleListEntry* event) +{ + m_eventCache.InsertAsHead(m_eventCache.CastIn(event)); +} diff --git a/DryadVertex/VertexHost/system/common/src/dryadmetadata.cpp b/DryadVertex/VertexHost/system/common/src/dryadmetadata.cpp new file mode 100644 index 0000000..fc5ad40 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/src/dryadmetadata.cpp @@ -0,0 +1,2456 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include +#include +#include +#include + +#ifdef DECLARE_DRPROPERTYTYPE +#undef DECLARE_DRPROPERTYTYPE +#endif + +#define DECLARE_DRPROPERTYTYPE(type) DrError DrMetaDataTagFactory_##type(DrMemoryReader*,UInt16,UInt32,DryadMetaDataParser*); + +#include "DrPropertyType.h" + +#undef DECLARE_DRPROPERTYTYPE + +#ifdef DECLARE_DRYADPROPERTYTYPE +#undef DECLARE_DRYADPROPERTYTYPE +#endif + +#define DECLARE_DRYADPROPERTYTYPE(type) DrError DryadMetaDataTagFactory_##type(DrMemoryReader*,UInt16,UInt32,DryadMetaDataParser*); + +#include "DryadPropertyType.h" + +#undef DECLARE_DRYADPROPERTYTYPE + +#ifdef DEFINE_DRYADTAG +#undef DEFINE_DRYADTAG +#endif + +#define DEFINE_DRYADTAG(name,id,desc,type) \ + DrError DryadMetaDataAggregateFactory_##type(DrMemoryReader*,DryadMTagRef*); + +#include "dryadtags.h" + +#undef DEFINE_DRYADTAG + + +#pragma unmanaged + +struct DryadTypeFactoryRecord +{ + UInt16 m_typeCode; + DryadMetaDataParser::TagFactory* m_factory; +}; + +// +// todo: Understand the preprocessor magic herein contained +// todo: if cosmos specific, remove +// +static DryadTypeFactoryRecord s_cosmosTypeFactory[] = { + +#ifdef DECLARE_DRPROPERTYTYPE +#undef DECLARE_DRPROPERTYTYPE +#endif + +#define DECLARE_DRPROPERTYTYPE(type) \ + {DrPropertyTagType_##type,DrMetaDataTagFactory_##type}, + +#include "DrPropertyType.h" + +#undef DECLARE_DRPROPERTYTYPE +}; + +// +// todo: Understand the preprocessor magic herein contained +// +static DryadTypeFactoryRecord s_dryadTypeFactory[] = { + +#ifdef DECLARE_DRYADPROPERTYTYPE 
+#undef DECLARE_DRYADPROPERTYTYPE +#endif + +#define DECLARE_DRYADPROPERTYTYPE(type) \ + {DryadPropertyTagType_##type,DryadMetaDataTagFactory_##type}, + +#include "DryadPropertyType.h" + +#undef DECLARE_DRYADPROPERTYTYPE +}; + +struct DryadMetaDataTypeRecord +{ + UInt16 m_enumId; + UInt16 m_typeCode; + DryadMetaDataParser::TagFactory* m_factory; +}; + +static DryadMetaDataTypeRecord s_cosmosTagTypes[] = { + +#ifdef DEFINE_DRPROPERTY +#undef DEFINE_DRPROPERTY +#endif + +#define DEFINE_DRPROPERTY(name,id,type,desc) \ + {name,DrPropertyTagType_##type,DrMetaDataTagFactory_##type}, + +#include "DrProperties.h" + +}; + +static DryadMetaDataTypeRecord s_dryadTagTypes[] = { + +#ifdef DEFINE_DRYADPROPERTY +#undef DEFINE_DRYADPROPERTY +#endif + +#define DEFINE_DRYADPROPERTY(name,id,type,desc) \ + {name,DryadPropertyTagType_##type,DryadMetaDataTagFactory_##type}, + +#include "DryadProperties.h" + +#undef DEFINE_DRYADPROPERTY +#undef DEFINE_DRPROPERTY +}; + +struct DryadMetaDataAggregateFactory +{ + UInt16 m_aggregateId; + DryadMetaDataParser::AggregateFactory* m_factory; +}; + +static DryadMetaDataAggregateFactory s_dryadAggregateFactory[] = { + +#ifdef DEFINE_DRYADTAG +#undef DEFINE_DRYADTAG +#endif + +#define DEFINE_DRYADTAG(name,id,desc,type) \ + {name,DryadMetaDataAggregateFactory_##type}, + +#include "dryadtags.h" + +#undef DEFINE_DRYADTAG +}; + +// +// Define maps between integer ids and various value types +// +typedef std::map TagFactoryMap; +typedef std::map + AggregateFactoryMap; +typedef std::map TagTypeRecordMap; + +// +// allocate these on the heap during the DryadInitMetaDataTable +// routine for purposes of sanitation: we don't really want to have +// stl constructors/destructors being called statically if we don't +// have to +static TagFactoryMap* s_factoryTable; +static AggregateFactoryMap* s_aggregateFactoryTable; +static TagTypeRecordMap* s_propertyTypeTable; + +DryadMetaData::DryadMetaData() +{ +} + +DryadMetaData::~DryadMetaData() +{ + TagListIter iter; 
+ for (iter = m_elementList.begin(); iter != m_elementList.end(); ++iter) + { + (*iter)->DecRef(); + } +} + +void DryadMetaData::Create(DrRef * dstMetaData) +{ + dstMetaData->Attach(new DryadMetaData()); +} + +bool DryadMetaData::Append(DryadMTag* tag, bool allowDuplicateNames) +{ + bool appended = true; + + { + AutoCriticalSection acs(&m_baseDR); + + if (!allowDuplicateNames) + { + TagMapIter iter = m_elementMap.find(tag->GetTagValue()); + if (iter != m_elementMap.end()) + { + appended = false; + } + } + + if (appended) + { + tag->IncRef(); + m_elementMap.insert(std::make_pair(tag->GetTagValue(), tag)); + m_elementList.push_back(tag); + } + } + + return appended; +} + +void DryadMetaData::AppendMetaDataTags(DryadMetaData* metaData, + bool allowDuplicateNames) +{ + TagListIter endIter; + + { + AutoCriticalSection outerAcs(&(metaData->m_baseDR)); + + TagListIter iter = metaData->LookUpInSequence(NULL, &endIter); + + { + AutoCriticalSection acs(&m_baseDR); + + while (iter != endIter) + { + DryadMTagRef tag; + (*iter)->Clone(&tag); + Append(tag, allowDuplicateNames); + ++iter; + } + } + } +} + +/* Swaps oldTag for newTag in both the tag map and the ordered tag list, keeping oldTag's position in the list. Both tags must carry the same tag value (asserted). Returns false, changing nothing, when oldTag is not held by this set. */ +bool DryadMetaData::Replace(DryadMTag* newTag, DryadMTag* oldTag) +{ + bool replaced = false; + + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(oldTag->GetTagValue() == newTag->GetTagValue()); + + TagMapIter iter = m_elementMap.find(oldTag->GetTagValue()); + while (iter != m_elementMap.end() && + iter->second != oldTag) + { + ++iter; + } + + if (iter != m_elementMap.end()) + { + /* the list has to change together with the map: the destructor releases the tags through m_elementList, so an oldTag left behind there would be released a second time and newTag never */ + TagListIter lIter = m_elementList.begin(); + while (lIter != m_elementList.end() && + (*lIter) != oldTag) + { + ++lIter; + } + LogAssert(lIter != m_elementList.end()); + + /* take the set's reference on newTag (as Append does) before dropping the one on oldTag; this order is also safe when both are the same object */ + newTag->IncRef(); + *lIter = newTag; + iter->second = newTag; + oldTag->DecRef(); + replaced = true; + } + } + + return replaced; +} + +bool DryadMetaData::Remove(DryadMTag* tag) +{ + bool removed = false; + + { + AutoCriticalSection acs(&m_baseDR); + + TagMapIter mIter = m_elementMap.find(tag->GetTagValue()); + while (mIter != m_elementMap.end() && + mIter->second != tag) + { + ++mIter; + } + + if (mIter != m_elementMap.end()) + { + TagListIter lIter = m_elementList.begin(); + while (lIter != m_elementList.end() && + (*lIter) !=
tag) + { + ++lIter; + } + LogAssert(lIter != m_elementList.end()); + m_elementList.erase(lIter); + m_elementMap.erase(mIter); + tag->DecRef(); + removed = true; + } + } + + return removed; +} + +DryadMTag* DryadMetaData::LookUpTag(UInt16 enumId) +{ + DryadMTag* tag = NULL; + + { + AutoCriticalSection acs(&m_baseDR); + + TagMapIter iter = m_elementMap.find(enumId); + if (iter != m_elementMap.end()) + { + ++iter; + if (iter == m_elementMap.end() || iter->first != enumId) + { + --iter; + tag = iter->second; + } + } + } + + return tag; +} + +DryadMetaData::TagMapIter + DryadMetaData::LookUpMulti(UInt16 enumId, + DryadMetaData::TagMapIter* pEndIter) +{ + DryadMetaData::TagMapIter startIter = m_elementMap.find(enumId); + DryadMetaData::TagMapIter endIter = startIter; + while (endIter != m_elementMap.end() && enumId == endIter->first) + { + ++endIter; + } + + *pEndIter = endIter; + return startIter; +} + +DryadMetaData::TagListIter + DryadMetaData::LookUpInSequence(DryadMTag* tag, + DryadMetaData::TagListIter* pEndIter) +{ + *pEndIter = m_elementList.end(); + + if (tag == NULL) + { + return m_elementList.begin(); + } + + DryadMetaData::TagListIter iter = m_elementList.begin(); + while (iter != m_elementList.end() && (*iter) != tag) + { + ++iter; + } + + return iter; +} + +void DryadMetaData::Clone(DryadMetaDataRef* dstMetaData) +{ + DryadMetaData* clone = new DryadMetaData(); + + TagListIter lIter; + + { + AutoCriticalSection acs(&m_baseDR); + + for (lIter = m_elementList.begin(); + lIter != m_elementList.end(); + ++lIter) + { + DryadMTagRef tag; + (*lIter)->Clone(&tag); + DryadMTag* freeTag = tag.Detach(); + clone->m_elementList.push_back(freeTag); + clone->m_elementMap.insert(std::make_pair(freeTag->GetTagValue(), + freeTag)); + } + } + + dstMetaData->Attach(clone); +} + +void DryadMetaData::AddError(DrError errorCode) +{ + DryadMTagRef tag; + tag.Attach(DryadMTagDrError::Create(Prop_Dryad_ErrorCode, errorCode)); + Append(tag, false); +} + +void 
DryadMetaData::AddErrorWithDescription(DrError errorCode, + const char* errorDescription) +{ + DryadMTagRef tag; + tag.Attach(DryadMTagDrError::Create(Prop_Dryad_ErrorCode, errorCode)); + Append(tag, false); + tag.Attach(DryadMTagString::Create(Prop_Dryad_ErrorString, + errorDescription)); + Append(tag, false); +} + +bool DryadMetaData::GetErrorCode(DrError* pError) +{ + DryadMTagDrError* tag = LookUpDrErrorTag(Prop_Dryad_ErrorCode); + if (tag != NULL) + { + *pError = tag->GetDrError(); + return true; + } + else + { + return false; + } +} + +const char* DryadMetaData::GetErrorString() +{ + DryadMTagString* tag = LookUpStringTag(Prop_Dryad_ErrorString); + if (tag == NULL) + { + return NULL; + } + + return tag->GetString(); +} + +void DryadMetaData::Serialize(DrMemoryWriter* writer) +{ + { + AutoCriticalSection acs(&m_baseDR); + + if (m_cachedSerialization.Ptr() != NULL) + { + Size_t available; + void* data = + m_cachedSerialization->GetDataAddress(0, &available, NULL); + LogAssert(available < 0x100000000); + LogAssert(available >= m_cachedSerialization->GetAvailableSize()); + + writer->WriteBytes(data, + m_cachedSerialization->GetAvailableSize()); + } + else + { + DryadMetaData::TagListIter iter; + for (iter = m_elementList.begin(); iter != m_elementList.end(); + ++iter) + { + (*iter)->Serialize(writer); + } + } + } +} + +void DryadMetaData::CacheSerialization() +{ + m_cachedSerialization = NULL; + + DrRef buffer; + buffer.Attach(new DrSimpleHeapBuffer()); + { + DrMemoryBufferWriter writer(buffer); + Serialize(&writer); + } + /* don't assign buffer to m_cachedSerialization until after the + call to Serialize to make sure something actually happens + there */ + m_cachedSerialization = buffer; +} + +DrMemoryBuffer* DryadMetaData::SerializeToBuffer() +{ + DrMemoryBuffer* buffer = new DrSimpleHeapBuffer(); + + { + DrMemoryBufferWriter writer(buffer); + Serialize(&writer); + DrError errTmp = writer.CloseMemoryWriter(); + LogAssert(errTmp == DrError_OK); + } + + return 
buffer; +} + +char* DryadMetaData::GetText() +{ + DrMemoryBuffer* buffer = SerializeToBuffer(); + DrMemoryBufferReader reader(buffer); + buffer->DecRef(); + + buffer = new DrSimpleHeapBuffer(); + DrMemoryBufferWriter writer(buffer); + +/* JC + DrPropertyDumper dumper; + + dumper.SetWriter(&writer); + + dumper.PutNestedPropertyList(&reader); +*/ + DrError errTmp = writer.CloseMemoryWriter(); + LogAssert(errTmp == DrError_OK); + + size_t textLength = buffer->GetAvailableSize(); + char* textBuffer = new char[textLength+1]; + LogAssert(textBuffer != NULL); + buffer->Read(0, textBuffer, textLength); + buffer->DecRef(); + textBuffer[textLength] = '\0'; + + return textBuffer; +} + +/* Serializes this metadata and writes it to writer as a single blob property tagged propertyTag. When the serialization is empty, an empty property is written only if writeIfEmpty is set. Returns the error code of the write, or DrError_OK when nothing is written. */ +DrError DryadMetaData::WriteAsProperty(DrMemoryWriter* writer, + UInt16 propertyTag, + bool writeIfEmpty) +{ + DrMemoryBuffer* buffer = SerializeToBuffer(); + + const BYTE* data = NULL; + Size_t length = buffer->GetAvailableSize(); + + DrError err = DrError_OK; + + if (length > 0) + { + LogAssert(length < 0x100000000UL); + DrMemoryBufferReader reader(buffer); + buffer->DecRef(); + + reader.ReadBytes(length, &data); + LogAssert(data != NULL); + + err = writer->WriteAnySizeBlobProperty(propertyTag, length, data); + } + else + { + /* no reader took the buffer over on this path, so the reference returned by SerializeToBuffer has to be released here or the buffer leaks */ + buffer->DecRef(); + + if (writeIfEmpty) + { + err = writer->WriteEmptyProperty(propertyTag); + } + } + + return err; +} + +DrError DryadMetaData::WriteAsAggregate(DrMemoryWriter* writer, + UInt16 propertyTag, + bool writeIfEmpty) +{ + writer->WriteUInt16Property(Prop_Dryad_BeginTag, propertyTag); + Serialize(writer); + return writer->WriteUInt16Property(Prop_Dryad_EndTag, propertyTag); +} + +DryadMTagUnknown* DryadMetaData::LookUpUnknownTag(UInt16 enumID) +{ + DryadMTag* tag = LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() != DrPropertyTagType_Unknown) + { + tag = NULL; + } + } + return (DryadMTagUnknown *) tag; +} + +DryadMTagVoid* DryadMetaData::LookUpVoidTag(UInt16 enumID) +{ + DryadMTag* tag = LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() !=
DryadPropertyTagType_Void) + { + tag = NULL; + } + } + return (DryadMTagVoid *) tag; +} + +DryadMTagBoolean* DryadMetaData::LookUpBooleanTag(UInt16 enumID) +{ + DryadMTag* tag = LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() != DrPropertyTagType_Boolean) + { + tag = NULL; + } + } + return (DryadMTagBoolean *) tag; +} + +DryadMTagInt16* DryadMetaData::LookUpInt16Tag(UInt16 enumID) +{ + DryadMTag* tag = LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() != DrPropertyTagType_Int16) + { + tag = NULL; + } + } + return (DryadMTagInt16 *) tag; +} + +DryadMTagUInt16* DryadMetaData::LookUpUInt16Tag(UInt16 enumID) +{ + DryadMTag* tag = LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() != DrPropertyTagType_UInt16) + { + tag = NULL; + } + } + return (DryadMTagUInt16 *) tag; +} + +DryadMTagInt32* DryadMetaData::LookUpInt32Tag(UInt16 enumID) +{ + DryadMTag* tag = LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() != DrPropertyTagType_Int32) + { + tag = NULL; + } + } + return (DryadMTagInt32 *) tag; +} + +DryadMTagUInt32* DryadMetaData::LookUpUInt32Tag(UInt16 enumID) +{ + DryadMTag* tag = LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() != DrPropertyTagType_UInt32) + { + tag = NULL; + } + } + return (DryadMTagUInt32 *) tag; +} + +DryadMTagInt64* DryadMetaData::LookUpInt64Tag(UInt16 enumID) +{ + DryadMTag* tag = LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() != DrPropertyTagType_Int64) + { + tag = NULL; + } + } + return (DryadMTagInt64 *) tag; +} + +DryadMTagUInt64* DryadMetaData::LookUpUInt64Tag(UInt16 enumID) +{ + DryadMTag* tag = LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() != DrPropertyTagType_UInt64) + { + tag = NULL; + } + } + return (DryadMTagUInt64 *) tag; +} + +DryadMTagString* DryadMetaData::LookUpStringTag(UInt16 enumID) +{ + DryadMTag* tag = LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() != DrPropertyTagType_String) + { + tag = NULL; + } + } + return 
(DryadMTagString *) tag; +} + +DryadMTagGuid* DryadMetaData::LookUpGuidTag(UInt16 enumID) +{ + DryadMTag* tag = LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() != DrPropertyTagType_Guid) + { + tag = NULL; + } + } + return (DryadMTagGuid *) tag; +} + +DryadMTagTimeStamp* DryadMetaData::LookUpTimeStampTag(UInt16 enumID) +{ + DryadMTag* tag = LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() != DrPropertyTagType_TimeStamp) + { + tag = NULL; + } + } + return (DryadMTagTimeStamp *) tag; +} + +DryadMTagTimeInterval* DryadMetaData::LookUpTimeIntervalTag(UInt16 enumID) +{ + DryadMTag* tag = LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() != DrPropertyTagType_TimeInterval) + { + tag = NULL; + } + } + return (DryadMTagTimeInterval *) tag; +} + +DryadMTagDrError* DryadMetaData::LookUpDrErrorTag(UInt16 enumID) +{ + DryadMTag* tag = LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() != DrPropertyTagType_DrError) + { + tag = NULL; + } + } + return (DryadMTagDrError *) tag; +} + +DryadMTagMetaData* DryadMetaData::LookUpMetaDataTag(UInt16 enumID) +{ + DryadMTag* tag = LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() != DryadPropertyTagType_MetaData) + { + tag = NULL; + } + } + return (DryadMTagMetaData *) tag; +} + +DryadMTagVertexCommand* DryadMetaData::LookUpVertexCommandTag(UInt16 enumID) +{ + DryadMTag* tag = LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() != DryadPropertyTagType_VertexCommand) + { + tag = NULL; + } + } + return (DryadMTagVertexCommand *) tag; +} + +DryadMTagInputChannelDescription* + DryadMetaData::LookUpInputChannelDescriptionTag(UInt16 enumID) +{ + DryadMTag* tag = LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() != DryadPropertyTagType_InputChannelDescription) + { + tag = NULL; + } + } + return (DryadMTagInputChannelDescription *) tag; +} + +DryadMTagOutputChannelDescription* + DryadMetaData::LookUpOutputChannelDescriptionTag(UInt16 enumID) +{ + DryadMTag* tag 
= LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() != DryadPropertyTagType_OutputChannelDescription) + { + tag = NULL; + } + } + return (DryadMTagOutputChannelDescription *) tag; +} + +DryadMTagVertexProcessStatus* + DryadMetaData::LookUpVertexProcessStatusTag(UInt16 enumID) +{ + DryadMTag* tag = LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() != DryadPropertyTagType_VertexProcessStatus) + { + tag = NULL; + } + } + return (DryadMTagVertexProcessStatus *) tag; +} + +DryadMTagVertexStatus* DryadMetaData::LookUpVertexStatusTag(UInt16 enumID) +{ + DryadMTag* tag = LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() != DryadPropertyTagType_VertexStatus) + { + tag = NULL; + } + } + return (DryadMTagVertexStatus *) tag; +} + +DryadMTagVertexCommandBlock* + DryadMetaData::LookUpVertexCommandBlockTag(UInt16 enumID) +{ + DryadMTag* tag = LookUpTag(enumID); + if (tag != NULL) + { + if (tag->GetType() != DryadPropertyTagType_VertexCommandBlock) + { + tag = NULL; + } + } + return (DryadMTagVertexCommandBlock *) tag; +} + +DrError DryadMetaData::LookUpVoid(UInt16 enumID) +{ + DryadMTagVoid* tag = LookUpVoidTag(enumID); + return (tag == NULL) ? 
DrError_InvalidProperty : DrError_OK; +} + +DrError DryadMetaData::LookUpBoolean(UInt16 enumID, bool* pVal /* out */) +{ + DryadMTagBoolean* tag = LookUpBooleanTag(enumID); + if (tag == NULL) + { + return DrError_InvalidProperty; + } + else + { + *pVal = tag->GetBoolean(); + return DrError_OK; + } +} + +DrError DryadMetaData::LookUpUInt16(UInt16 enumID, UInt16* pVal /* out */) +{ + DryadMTagUInt16* tag = LookUpUInt16Tag(enumID); + if (tag == NULL) + { + return DrError_InvalidProperty; + } + else + { + *pVal = tag->GetUInt16(); + return DrError_OK; + } +} + +DrError DryadMetaData::LookUpInt32(UInt16 enumID, Int32* pVal /* out */) +{ + DryadMTagInt32* tag = LookUpInt32Tag(enumID); + if (tag == NULL) + { + return DrError_InvalidProperty; + } + else + { + *pVal = tag->GetInt32(); + return DrError_OK; + } +} + +DrError DryadMetaData::LookUpUInt32(UInt16 enumID, UInt32* pVal /* out */) +{ + DryadMTagUInt32* tag = LookUpUInt32Tag(enumID); + if (tag == NULL) + { + return DrError_InvalidProperty; + } + else + { + *pVal = tag->GetUInt32(); + return DrError_OK; + } +} + +DrError DryadMetaData::LookUpInt64(UInt16 enumID, Int64* pVal /* out */) +{ + DryadMTagInt64* tag = LookUpInt64Tag(enumID); + if (tag == NULL) + { + return DrError_InvalidProperty; + } + else + { + *pVal = tag->GetInt64(); + return DrError_OK; + } +} + +DrError DryadMetaData::LookUpUInt64(UInt16 enumID, UInt64* pVal /* out */) +{ + DryadMTagUInt64* tag = LookUpUInt64Tag(enumID); + if (tag == NULL) + { + return DrError_InvalidProperty; + } + else + { + *pVal = tag->GetUInt64(); + return DrError_OK; + } +} + +DrError DryadMetaData::LookUpString(UInt16 enumID, const char** pVal /* out */) +{ + DryadMTagString* tag = LookUpStringTag(enumID); + if (tag == NULL) + { + return DrError_InvalidProperty; + } + else + { + *pVal = tag->GetString(); + return DrError_OK; + } +} + +DrError DryadMetaData::LookUpGuid(UInt16 enumID, const DrGuid** pVal /* out */) +{ + DryadMTagGuid* tag = LookUpGuidTag(enumID); + if (tag == 
NULL) + { + return DrError_InvalidProperty; + } + else + { + *pVal = tag->GetGuid(); + return DrError_OK; + } +} + +DrError DryadMetaData::LookUpTimeStamp(UInt16 enumID, + DrTimeStamp* pVal /* out */) +{ + DryadMTagTimeStamp* tag = LookUpTimeStampTag(enumID); + if (tag == NULL) + { + return DrError_InvalidProperty; + } + else + { + *pVal = tag->GetTimeStamp(); + return DrError_OK; + } +} + +DrError DryadMetaData::LookUpTimeInterval(UInt16 enumID, + DrTimeInterval* pVal /* out */) +{ + DryadMTagTimeInterval* tag = LookUpTimeIntervalTag(enumID); + if (tag == NULL) + { + return DrError_InvalidProperty; + } + else + { + *pVal = tag->GetTimeInterval(); + return DrError_OK; + } +} + +DrError DryadMetaData::LookUpDrError(UInt16 enumID, DrError* pVal /* out */) +{ + DryadMTagDrError* tag = LookUpDrErrorTag(enumID); + if (tag == NULL) + { + return DrError_InvalidProperty; + } + else + { + *pVal = tag->GetDrError(); + return DrError_OK; + } +} + +DrError DryadMetaData::LookUpMetaData(UInt16 enumID, + DryadMetaDataRef* pVal /* out */) +{ + DryadMTagMetaData* tag = LookUpMetaDataTag(enumID); + if (tag == NULL) + { + return DrError_InvalidProperty; + } + else + { + *pVal = tag->GetMetaData(); + return DrError_OK; + } +} + +DrError DryadMetaData::LookUpVertexCommand(UInt16 enumID, + DVertexCommand* pVal /* out */) +{ + DryadMTagVertexCommand* tag = LookUpVertexCommandTag(enumID); + if (tag == NULL) + { + return DrError_InvalidProperty; + } + else + { + *pVal = tag->GetVertexCommand(); + return DrError_OK; + } +} + +DrError DryadMetaData:: + LookUpInputChannelDescription(UInt16 enumID, + DryadInputChannelDescription** pVal) +{ + DryadMTagInputChannelDescription* tag = + LookUpInputChannelDescriptionTag(enumID); + if (tag == NULL) + { + return DrError_InvalidProperty; + } + else + { + *pVal = tag->GetInputChannelDescription(); + return DrError_OK; + } +} + +DrError DryadMetaData:: + LookUpOutputChannelDescription(UInt16 enumID, + DryadOutputChannelDescription** pVal) +{ + 
DryadMTagOutputChannelDescription* tag = + LookUpOutputChannelDescriptionTag(enumID); + if (tag == NULL) + { + return DrError_InvalidProperty; + } + else + { + *pVal = tag->GetOutputChannelDescription(); + return DrError_OK; + } +} + +DrError DryadMetaData::LookUpVertexProcessStatus(UInt16 enumID, + DVertexProcessStatus** pVal) +{ + DryadMTagVertexProcessStatus* tag = LookUpVertexProcessStatusTag(enumID); + if (tag == NULL) + { + return DrError_InvalidProperty; + } + else + { + *pVal = tag->GetVertexProcessStatus(); + return DrError_OK; + } +} + +DrError DryadMetaData::LookUpVertexStatus(UInt16 enumID, DVertexStatus** pVal) +{ + DryadMTagVertexStatus* tag = LookUpVertexStatusTag(enumID); + if (tag == NULL) + { + return DrError_InvalidProperty; + } + else + { + *pVal = tag->GetVertexStatus(); + return DrError_OK; + } +} + +DrError DryadMetaData::LookUpVertexCommandBlock(UInt16 enumID, + DVertexCommandBlock** pVal) +{ + DryadMTagVertexCommandBlock* tag = LookUpVertexCommandBlockTag(enumID); + if (tag == NULL) + { + return DrError_InvalidProperty; + } + else + { + *pVal = tag->GetVertexCommandBlock(); + return DrError_OK; + } +} + +bool DryadMetaData::AppendVoid(UInt16 enumID, + bool allowDuplicateTags) +{ + DryadMTagVoidRef tag; + tag.Attach(DryadMTagVoid::Create(enumID)); + return Append(tag, allowDuplicateTags); +} + +bool DryadMetaData::AppendBoolean(UInt16 enumID, bool value, + bool allowDuplicateTags) +{ + DryadMTagBooleanRef tag; + tag.Attach(DryadMTagBoolean::Create(enumID, value)); + return Append(tag, allowDuplicateTags); +} + +bool DryadMetaData::AppendInt16(UInt16 enumID, Int16 value, + bool allowDuplicateTags) +{ + DryadMTagInt16Ref tag; + tag.Attach(DryadMTagInt16::Create(enumID, value)); + return Append(tag, allowDuplicateTags); +} + +bool DryadMetaData::AppendUInt16(UInt16 enumID, UInt16 value, + bool allowDuplicateTags) +{ + DryadMTagUInt16Ref tag; + tag.Attach(DryadMTagUInt16::Create(enumID, value)); + return Append(tag, allowDuplicateTags); +} + 
+bool DryadMetaData::AppendInt32(UInt16 enumID, Int32 value, + bool allowDuplicateTags) +{ + DryadMTagInt32Ref tag; + tag.Attach(DryadMTagInt32::Create(enumID, value)); + return Append(tag, allowDuplicateTags); +} + +bool DryadMetaData::AppendUInt32(UInt16 enumID, UInt32 value, + bool allowDuplicateTags) +{ + DryadMTagUInt32Ref tag; + tag.Attach(DryadMTagUInt32::Create(enumID, value)); + return Append(tag, allowDuplicateTags); +} + +bool DryadMetaData::AppendInt64(UInt16 enumID, Int64 value, + bool allowDuplicateTags) +{ + DryadMTagInt64Ref tag; + tag.Attach(DryadMTagInt64::Create(enumID, value)); + return Append(tag, allowDuplicateTags); +} + +bool DryadMetaData::AppendUInt64(UInt16 enumID, UInt64 value, + bool allowDuplicateTags) +{ + DryadMTagUInt64Ref tag; + tag.Attach(DryadMTagUInt64::Create(enumID, value)); + return Append(tag, allowDuplicateTags); +} + +bool DryadMetaData::AppendString(UInt16 enumID, const char* value, + bool allowDuplicateTags) +{ + DryadMTagStringRef tag; + tag.Attach(DryadMTagString::Create(enumID, value)); + return Append(tag, allowDuplicateTags); +} + +bool DryadMetaData::AppendGuid(UInt16 enumID, const DrGuid* value, + bool allowDuplicateTags) +{ + DryadMTagGuidRef tag; + tag.Attach(DryadMTagGuid::Create(enumID, value)); + return Append(tag, allowDuplicateTags); +} + +bool DryadMetaData::AppendTimeStamp(UInt16 enumID, DrTimeStamp value, + bool allowDuplicateTags) +{ + DryadMTagTimeStampRef tag; + tag.Attach(DryadMTagTimeStamp::Create(enumID, value)); + return Append(tag, allowDuplicateTags); +} + +bool DryadMetaData::AppendTimeInterval(UInt16 enumID, DrTimeInterval value, + bool allowDuplicateTags) +{ + DryadMTagTimeIntervalRef tag; + tag.Attach(DryadMTagTimeInterval::Create(enumID, value)); + return Append(tag, allowDuplicateTags); +} + +bool DryadMetaData::AppendDrError(UInt16 enumID, DrError value, + bool allowDuplicateTags) +{ + DryadMTagDrErrorRef tag; + tag.Attach(DryadMTagDrError::Create(enumID, value)); + return Append(tag, 
allowDuplicateTags); +} + +bool DryadMetaData::AppendMetaData(UInt16 enumID, DryadMetaData* value, + bool marshalAsAggregate, + bool allowDuplicateTags) +{ + DryadMTagMetaDataRef tag; + tag.Attach(DryadMTagMetaData::Create(enumID, value, marshalAsAggregate)); + return Append(tag, allowDuplicateTags); +} + +bool DryadMetaData::AppendVertexCommand(UInt16 enumID, DVertexCommand value, + bool allowDuplicateTags) +{ + DryadMTagVertexCommandRef tag; + tag.Attach(DryadMTagVertexCommand::Create(enumID, value)); + return Append(tag, allowDuplicateTags); +} + +bool DryadMetaData:: + AppendInputChannelDescription(UInt16 enumID, + DryadInputChannelDescription* value, + bool allowDuplicateTags) +{ + DryadMTagInputChannelDescriptionRef tag; + tag.Attach(DryadMTagInputChannelDescription::Create(enumID, value)); + return Append(tag, allowDuplicateTags); +} + +bool DryadMetaData:: + AppendOutputChannelDescription(UInt16 enumID, + DryadOutputChannelDescription* value, + bool allowDuplicateTags) +{ + DryadMTagOutputChannelDescriptionRef tag; + tag.Attach(DryadMTagOutputChannelDescription::Create(enumID, value)); + return Append(tag, allowDuplicateTags); +} + +bool DryadMetaData::AppendVertexProcessStatus(UInt16 enumID, + DVertexProcessStatus* value, + bool allowDuplicateTags) +{ + DryadMTagVertexProcessStatusRef tag; + tag.Attach(DryadMTagVertexProcessStatus::Create(enumID, value)); + return Append(tag, allowDuplicateTags); +} + +bool DryadMetaData::AppendVertexStatus(UInt16 enumID, + DVertexStatus* value, + bool allowDuplicateTags) +{ + DryadMTagVertexStatusRef tag; + tag.Attach(DryadMTagVertexStatus::Create(enumID, value)); + return Append(tag, allowDuplicateTags); +} + +bool DryadMetaData::AppendVertexCommandBlock(UInt16 enumID, + DVertexCommandBlock* value, + bool allowDuplicateTags) +{ + DryadMTagVertexCommandBlockRef tag; + tag.Attach(DryadMTagVertexCommandBlock::Create(enumID, value)); + return Append(tag, allowDuplicateTags); +} + + +DryadMetaDataParser::DryadMetaDataParser() 
+{ + DryadMetaData::Create(&m_set); +} + +DryadMetaDataParser::~DryadMetaDataParser() +{ +} + +DryadMetaData* DryadMetaDataParser::GetMetaData() +{ + return m_set.Ptr(); +} + +void DryadMetaDataParser::AddTag(DryadMTag* tag) +{ + m_set.Ptr()->Append(tag, true); +} + +DrError DryadMetaDataParser::OnParseProperty(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + void* cookie) +{ + TagTypeRecordMap::iterator iter = s_propertyTypeTable->find(enumId); + TagFactory* factory; + if (iter == s_propertyTypeTable->end()) + { + factory = DrMetaDataTagFactory_Unknown; + } + else + { + factory = iter->second->m_factory; + } + return (*factory)(reader, enumId, dataLen, this); +} + +// +// Parse the properties held in data[0..dataLength), handing each one to +// OnParseProperty. Reaching the end of the input is the normal way to +// finish and returns DrError_OK; any other error stops the parse and is +// returned to the caller. +// +DrError DryadMetaDataParser::ParseBuffer(const void* data, UInt32 dataLength) +{ + DrSingleBlockReader r((void *) data, dataLength); + + DrError err = DrError_OK; + bool finished = false; + + while (!finished) { + UInt16 enumId; + UInt32 length; + + err = r.PeekNextPropertyTag(&enumId, &length); + if (err == DrError_EndOfStream) { + err = DrError_OK; + finished = true; + } else if (err != DrError_OK) { + /* the tag header could not be read, so enumId and length may not have been filled in: stop and report the error instead of dispatching on them */ + finished = true; + } else { + err = OnParseProperty(&r, enumId, length, NULL); + if (err != DrError_OK) + { + finished = true; + } + } + } + + return err; +} + + +// +// Attempt to insert tagfactory into tagfactory map +// If unable to insert, fail because of duplicate key +// +DrError DryadAddFactoryToTypeTable(UInt16 typeCode, + DryadMetaDataParser::TagFactory* factory) +{ + std::pair retval = + s_factoryTable->insert(std::make_pair(typeCode, factory)); + if (retval.second == false) + { + return HRESULT_FROM_WIN32( ERROR_ALREADY_ASSIGNED ); + } + else + { + return DrError_OK; + } +} + +// +// Attempt to add factory to aggregate factory table +// If unable to insert, fail because of duplicate key +// +DrError DryadAddFactoryToAggregateTable(UInt16 typeCode, + DryadMetaDataParser:: + AggregateFactory* factory) +{ + std::pair retval = + s_aggregateFactoryTable->insert(std::make_pair(typeCode, factory)); + if (retval.second == false) + { +
return HRESULT_FROM_WIN32( ERROR_ALREADY_ASSIGNED ); + } + else + { + return DrError_OK; + } +} + +// +// Get factory with specified typecode, create a new metadata record +// with reference to that factory and provided id, and then insert +// the metadata record into the record map +// Returns DrError_InvalidProperty when no factory is registered for +// typeCode, and an ERROR_ALREADY_ASSIGNED HRESULT when enumId already +// has a record +// +DrError DryadAddPropertyToMetaData(UInt16 enumId, UInt16 typeCode) +{ + // + // Get factory associated with typeCode + // + TagFactoryMap::iterator iter = s_factoryTable->find(typeCode); + if (iter == s_factoryTable->end()) + { + return DrError_InvalidProperty; + } + + // + // Build Metadata record + // + DryadMetaDataTypeRecord* r = new DryadMetaDataTypeRecord; + r->m_enumId = enumId; + r->m_typeCode = typeCode; + r->m_factory = iter->second; + + // + // Attempt to insert record into property type table + // If unable to add, report duplicate + // + std::pair retval = + s_propertyTypeTable->insert(std::make_pair(enumId, r)); + if (retval.second == false) + { + // + // The table kept its existing entry and did not take the new + // record, so free it here rather than leaking it + // + delete r; + return HRESULT_FROM_WIN32( ERROR_ALREADY_ASSIGNED ); + } + else + { + return DrError_OK; + } +} + +// +// Build maps with all factories, aggregate factories, and property types +// +void DryadInitMetaDataTable() +{ + // + // Create maps for factories, aggregate factories, and property types + // + s_factoryTable = new TagFactoryMap; + s_aggregateFactoryTable = new AggregateFactoryMap; + s_propertyTypeTable = new TagTypeRecordMap; + + // + // Foreach cosmos factory, add it to factory table + // todo: Is this cosmos specific? If so, remove it.
+ // + UInt32 n = sizeof(s_cosmosTypeFactory) / sizeof(s_cosmosTypeFactory[0]); + UInt32 i; + for (i=0; iAddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_Void(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagVoidRef tag; + DrError err = + DryadMTagVoid::ReadFromStream(reader, enumId, dataLen, &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_Boolean(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagBooleanRef tag; + DrError err = + DryadMTagBoolean::ReadFromStream(reader, enumId, dataLen, &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_Int16(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagInt16Ref tag; + DrError err = + DryadMTagInt16::ReadFromStream(reader, enumId, dataLen, &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_Int32(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagInt32Ref tag; + DrError err = + DryadMTagInt32::ReadFromStream(reader, enumId, dataLen, &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_Int64(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagInt64Ref tag; + DrError err = + DryadMTagInt64::ReadFromStream(reader, enumId, dataLen, &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_Double(DrMemoryReader* 
reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagDoubleRef tag; + DrError err = + DryadMTagDouble::ReadFromStream(reader, enumId, dataLen, &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_UInt16(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUInt16Ref tag; + DrError err = + DryadMTagUInt16::ReadFromStream(reader, enumId, dataLen, &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_TagIdValue(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUInt16Ref tag; + DrError err = + DryadMTagUInt16::ReadFromStream(reader, enumId, dataLen, &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + + +DrError DrMetaDataTagFactory_UInt32(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUInt32Ref tag; + DrError err = + DryadMTagUInt32::ReadFromStream(reader, enumId, dataLen, &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_UInt64(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUInt64Ref tag; + DrError err = + DryadMTagUInt64::ReadFromStream(reader, enumId, dataLen, &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_HexUInt16(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUInt16Ref tag; + DrError err = + DryadMTagUInt16::ReadFromStream(reader, enumId, dataLen, &tag); + if (err == DrError_OK) + { + 
parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_HexUInt32(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUInt32Ref tag; + DrError err = + DryadMTagUInt32::ReadFromStream(reader, enumId, dataLen, &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_HexUInt64(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUInt64Ref tag; + DrError err = + DryadMTagUInt64::ReadFromStream(reader, enumId, dataLen, &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_String(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagStringRef tag; + DrError err = + DryadMTagString::ReadFromStream(reader, enumId, dataLen, &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_Guid(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagGuidRef tag; + DrError err = + DryadMTagGuid::ReadFromStream(reader, enumId, dataLen, &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_TimeStamp(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagTimeStampRef tag; + DrError err = + DryadMTagTimeStamp::ReadFromStream(reader, enumId, dataLen, &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_TimeInterval(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ 
+ DryadMTagTimeIntervalRef tag; + DrError err = + DryadMTagTimeInterval::ReadFromStream(reader, enumId, dataLen, &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_DrExitCode(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUnknownRef tag; + DrError err = + DryadMTagUnknown:: + ReadFromStreamWithType(reader, enumId, dataLen, + DrPropertyTagType_DrExitCode, + &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_DrError(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagDrErrorRef tag; + DrError err = + DryadMTagDrError::ReadFromStream(reader, enumId, dataLen, &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_Blob(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUnknownRef tag; + DrError err = + DryadMTagUnknown::ReadFromStreamWithType(reader, enumId, dataLen, + DrPropertyTagType_Blob, + &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_Payload(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUnknownRef tag; + DrError err = + DryadMTagUnknown::ReadFromStreamWithType(reader, enumId, dataLen, + DrPropertyTagType_Payload, + &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_EnvironmentBlock(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUnknownRef tag; + DrError err = + 
DryadMTagUnknown::ReadFromStreamWithType(reader, enumId, dataLen, + DrPropertyTagType_EnvironmentBlock, + &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError + DrMetaDataTagFactory_AppendExtentOptions(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUnknownRef tag; + DrError err = + DryadMTagUnknown:: + ReadFromStreamWithType(reader, enumId, dataLen, + DrPropertyTagType_AppendExtentOptions, + &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError + DrMetaDataTagFactory_AppendBlockOptions(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUnknownRef tag; + DrError err = + DryadMTagUnknown:: + ReadFromStreamWithType(reader, enumId, dataLen, + DrPropertyTagType_AppendBlockOptions, + &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError + DrMetaDataTagFactory_FailureInjectionOptions(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUnknownRef tag; + DrError err = + DryadMTagUnknown:: + ReadFromStreamWithType(reader, enumId, dataLen, + DrPropertyTagType_FailureInjectionOptions, + &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError + DrMetaDataTagFactory_SyncDirectiveOptions(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUnknownRef tag; + DrError err = + DryadMTagUnknown:: + ReadFromStreamWithType(reader, enumId, dataLen, + DrPropertyTagType_SyncDirectiveOptions, + &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError + DrMetaDataTagFactory_SyncOptions(DrMemoryReader* 
reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUnknownRef tag; + DrError err = + DryadMTagUnknown:: + ReadFromStreamWithType(reader, enumId, dataLen, + DrPropertyTagType_SyncOptions, + &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError + DrMetaDataTagFactory_ReadExtentOptions(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUnknownRef tag; + DrError err = + DryadMTagUnknown:: + ReadFromStreamWithType(reader, enumId, dataLen, + DrPropertyTagType_ReadExtentOptions, + &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError + DrMetaDataTagFactory_AppendStreamOptions(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUnknownRef tag; + DrError err = + DryadMTagUnknown:: + ReadFromStreamWithType(reader, enumId, dataLen, + DrPropertyTagType_AppendStreamOptions, + &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError + DrMetaDataTagFactory_StreamCapabilityBits(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUnknownRef tag; + DrError err = + DryadMTagUnknown:: + ReadFromStreamWithType(reader, enumId, dataLen, + DrPropertyTagType_StreamCapabilityBits, + &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError + DrMetaDataTagFactory_EnumDirectoryOptions(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUnknownRef tag; + DrError err = + DryadMTagUnknown:: + ReadFromStreamWithType(reader, enumId, dataLen, + DrPropertyTagType_EnumDirectoryOptions, + &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { 
+ LogAssert(tag == NULL); + } + return err; +} + +DrError + DrMetaDataTagFactory_EnInfoBits(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUnknownRef tag; + DrError err = + DryadMTagUnknown:: + ReadFromStreamWithType(reader, enumId, dataLen, + DrPropertyTagType_EnInfoBits, + &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError + DrMetaDataTagFactory_UpdateExtentMetadataOptions(DrMemoryReader* + reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* + parent) +{ + DryadMTagUnknownRef tag; + DrError err = + DryadMTagUnknown:: + ReadFromStreamWithType(reader, enumId, dataLen, + DrPropertyTagType_UpdateExtentMetadataOptions, + &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError + DrMetaDataTagFactory_StreamInfoBits(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUnknownRef tag; + DrError err = + DryadMTagUnknown:: + ReadFromStreamWithType(reader, enumId, dataLen, + DrPropertyTagType_StreamInfoBits, + &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError + DrMetaDataTagFactory_ExtentInfoBits(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUnknownRef tag; + DrError err = + DryadMTagUnknown:: + ReadFromStreamWithType(reader, enumId, dataLen, + DrPropertyTagType_ExtentInfoBits, + &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError + DrMetaDataTagFactory_ExtentInstanceInfoBits(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagUnknownRef tag; + DrError err = + DryadMTagUnknown:: + ReadFromStreamWithType(reader, enumId, dataLen, + 
DrPropertyTagType_ExtentInstanceInfoBits, + &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError DrMetaDataTagFactory_BeginTag(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + if (enumId != Prop_Dryad_BeginTag) + { + return DrError_InvalidProperty; + } + + UInt16 tagValue; + DrError err + = reader->PeekNextUInt16Property(Prop_Dryad_BeginTag, &tagValue); + if (err == DrError_OK) + { + AggregateFactoryMap::iterator iter = + s_aggregateFactoryTable->find(tagValue); + if (iter == s_aggregateFactoryTable->end()) + { + err = DrError_InvalidProperty; + } + else + { + DryadMetaDataParser::AggregateFactory* factory = iter->second; + DryadMTagRef tag; + err = (*factory)(reader, &tag); + { + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + } + } + } + + return err; +} + +DrError DrMetaDataTagFactory_EndTag(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + /* this should always be consumed as a side-effect of the + BeginTag */ + return DrError_InvalidProperty; +} + +DrError DrMetaDataTagFactory_PropertyList(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + const void* data; + DrError err = reader->ReadNextProperty(&enumId, &dataLen, &data); + if (err == DrError_OK) + { + DryadMTagMetaDataRef tag; + err = DryadMTagMetaData::ReadFromArray(enumId, data, dataLen, &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + } + return err; +} + +DrError + DryadMetaDataTagFactory_Void(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagVoidRef tag; + DrError err = + DryadMTagVoid::ReadFromStream(reader, enumId, dataLen, &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + 
return err; +} + +DrError + DryadMetaDataTagFactory_VertexCommand(DrMemoryReader* reader, + UInt16 enumId, + UInt32 dataLen, + DryadMetaDataParser* parent) +{ + DryadMTagVertexCommandRef tag; + DrError err = + DryadMTagVertexCommand::ReadFromStream(reader, enumId, dataLen, &tag); + if (err == DrError_OK) + { + parent->AddTag(tag); + } + else + { + LogAssert(tag == NULL); + } + return err; +} + +DrError + DryadMetaDataAggregateFactory_MetaData(DrMemoryReader* reader, + DryadMTagRef* pTag) +{ + DryadMTagMetaDataRef tag; + DrError err = DryadMTagMetaData::ReadFromStreamInAggregate(reader, &tag); + if (err == DrError_OK) + { + *pTag = tag; + } + return err; +} + +DrError + DryadMetaDataAggregateFactory_InputChannelDescription(DrMemoryReader* reader, + DryadMTagRef* pTag) +{ + DryadMTagInputChannelDescriptionRef tag; + DrError err = + DryadMTagInputChannelDescription::ReadFromStream(reader, &tag); + if (err == DrError_OK) + { + *pTag = tag; + } + return err; +} + +DrError + DryadMetaDataAggregateFactory_OutputChannelDescription(DrMemoryReader* reader, + DryadMTagRef* pTag) +{ + DryadMTagOutputChannelDescriptionRef tag; + DrError err = + DryadMTagOutputChannelDescription::ReadFromStream(reader, &tag); + if (err == DrError_OK) + { + *pTag = tag; + } + return err; +} + +DrError + DryadMetaDataAggregateFactory_VertexProcessStatus(DrMemoryReader* reader, + DryadMTagRef* pTag) +{ + DryadMTagVertexProcessStatusRef tag; + DrError err = DryadMTagVertexProcessStatus::ReadFromStream(reader, &tag); + if (err == DrError_OK) + { + *pTag = tag; + } + return err; +} + +DrError + DryadMetaDataAggregateFactory_VertexStatus(DrMemoryReader* reader, + DryadMTagRef* pTag) +{ + DryadMTagVertexStatusRef tag; + DrError err = DryadMTagVertexStatus::ReadFromStream(reader, &tag); + if (err == DrError_OK) + { + *pTag = tag; + } + return err; +} + +DrError + DryadMetaDataAggregateFactory_VertexCommandBlock(DrMemoryReader* reader, + DryadMTagRef* pTag) +{ + DryadMTagVertexCommandBlockRef tag; + 
DrError err = DryadMTagVertexCommandBlock::ReadFromStream(reader, &tag); + if (err == DrError_OK) + { + *pTag = tag; + } + return err; +} diff --git a/DryadVertex/VertexHost/system/common/src/dryadmetadatatag.cpp b/DryadVertex/VertexHost/system/common/src/dryadmetadatatag.cpp new file mode 100644 index 0000000..333ae71 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/src/dryadmetadatatag.cpp @@ -0,0 +1,52 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include +#include + +#include + + +#pragma unmanaged + +DryadMTag::DryadMTag(UInt16 tagValue, UInt16 type) +{ + m_tag = tagValue; + m_type = type; +} + +DryadMTag::~DryadMTag() +{ +} + +UInt16 DryadMTag::GetTagValue() +{ + return m_tag; +} + +UInt16 DryadMTag::GetType() +{ + return m_type; +} + +void DryadMTag::Clone(DryadMTagRef* dstTag) +{ + *dstTag = this; +} diff --git a/DryadVertex/VertexHost/system/common/src/dryadmetadatatagtypes.cpp b/DryadVertex/VertexHost/system/common/src/dryadmetadatatagtypes.cpp new file mode 100644 index 0000000..c4054ad --- /dev/null +++ b/DryadVertex/VertexHost/system/common/src/dryadmetadatatagtypes.cpp @@ -0,0 +1,1138 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include +#include + +#pragma unmanaged + +DryadMTagUnknown::DryadMTagUnknown(UInt16 tag, UInt32 dataLength, + UInt16 originalType) : + DryadMTag(tag, DrPropertyTagType_Unknown) +{ + m_dataLength = dataLength; + if ((tag & PropLengthMask) == PropLength_Short) + { + LogAssert(m_dataLength < 256); + } + m_data = new char[m_dataLength]; + m_originalType = originalType; +} + +DryadMTagUnknown::~DryadMTagUnknown() +{ + delete [] m_data; +} + +DryadMTagUnknown* DryadMTagUnknown::Create(UInt16 enumID, UInt32 dataLen, + void* data, UInt16 originalType) +{ + DryadMTagUnknown* tag = + new DryadMTagUnknown(enumID, dataLen, originalType); + ::memcpy(tag->GetData(), data, dataLen); + return tag; +} + +DrError DryadMTagUnknown::ReadFromStreamWithType(DrMemoryReader* reader, + UInt16 enumID, UInt32 dataLen, + UInt16 originalType, + DryadMTagUnknownRef* outTag) +{ + DryadMTagUnknown* tag = + new DryadMTagUnknown(enumID, dataLen, originalType); + DrError cse = reader->ReadNextKnownProperty(enumID, dataLen, + tag->GetData()); + if (cse == DrError_OK) + { + outTag->Attach(tag); + } + else + { + tag->DecRef(); + (*outTag) = NULL; + } + return cse; +} + +DrError DryadMTagUnknown::ReadFromStream(DrMemoryReader* reader, + UInt16 enumID, UInt32 dataLen, + DryadMTagUnknownRef* outTag) +{ + return ReadFromStreamWithType(reader, enumID, dataLen, + DrPropertyTagType_Unknown, outTag); +} + +UInt32 DryadMTagUnknown::GetDataLength() +{ + return m_dataLength; +} + +void* DryadMTagUnknown::GetData() +{ + return m_data; +} + 
+UInt16 DryadMTagUnknown::GetOriginalType() +{ + return m_originalType; +} + +DrError DryadMTagUnknown::Serialize(DrMemoryWriter* writer) +{ + return writer->WriteAnySizeBlobProperty(GetTagValue(), m_dataLength, m_data); +} + + +DryadMTagVoid::DryadMTagVoid(UInt16 tag) : + DryadMTag(tag, DryadPropertyTagType_Void) +{ +} + +DryadMTagVoid::~DryadMTagVoid() +{ +} + +DryadMTagVoid* DryadMTagVoid::Create(UInt16 tag) +{ + return new DryadMTagVoid(tag); +} + +DrError DryadMTagVoid::ReadFromStream(DrMemoryReader* reader, + UInt16 enumID, UInt32 dataLen, + DryadMTagVoidRef* outTag) +{ + if (dataLen > 0) + { + return DrError_InvalidProperty; + } + + const void* dummyData; + DrError cse = reader->ReadNextProperty(&enumID, &dataLen, &dummyData); + if (cse == DrError_OK) + { + outTag->Attach(Create(enumID)); + } + else + { + (*outTag) = NULL; + } + return cse; +} + +DrError DryadMTagVoid::Serialize(DrMemoryWriter* writer) +{ + return writer->WriteAnySizeBlobProperty(GetTagValue(), 0, NULL); +} + + +DryadMTagBoolean::DryadMTagBoolean(UInt16 tag, bool val) : + DryadMTag(tag, DrPropertyTagType_Boolean) +{ + m_val = val; +} + +DryadMTagBoolean::~DryadMTagBoolean() +{ +} + +DryadMTagBoolean* DryadMTagBoolean::Create(UInt16 tag, bool val) +{ + return new DryadMTagBoolean(tag, val); +} + +DrError DryadMTagBoolean::ReadFromStream(DrMemoryReader* reader, + UInt16 enumID, UInt32 dataLen, + DryadMTagBooleanRef* outTag) +{ + bool val; + DrError cse = reader->ReadNextBoolProperty(enumID, &val); + if (cse == DrError_OK) + { + outTag->Attach(Create(enumID, val)); + } + else + { + (*outTag) = NULL; + } + return cse; +} + +bool DryadMTagBoolean::GetBoolean() +{ + return m_val; +} + +DrError DryadMTagBoolean::Serialize(DrMemoryWriter* writer) +{ + return writer->WriteBoolProperty(GetTagValue(), m_val); +} + + +DryadMTagInt16::DryadMTagInt16(UInt16 tag, Int16 val) : + DryadMTag(tag, DrPropertyTagType_Int16) +{ + m_i16Val = val; +} + +DryadMTagInt16::~DryadMTagInt16() +{ +} + +DryadMTagInt16* 
DryadMTagInt16::Create(UInt16 tag, Int16 val) +{ + return new DryadMTagInt16(tag, val); +} + +DrError DryadMTagInt16::ReadFromStream(DrMemoryReader* reader, + UInt16 enumID, UInt32 dataLen, + DryadMTagInt16Ref* outTag) +{ + Int16 val; + DrError cse = reader->ReadNextInt16Property(enumID, &val); + if (cse == DrError_OK) + { + outTag->Attach(Create(enumID, val)); + } + else + { + (*outTag) = NULL; + } + return cse; +} + +Int16 DryadMTagInt16::GetInt16() +{ + return m_i16Val; +} + +DrError DryadMTagInt16::Serialize(DrMemoryWriter* writer) +{ + return writer->WriteInt16Property(GetTagValue(), m_i16Val); +} + + +DryadMTagUInt16::DryadMTagUInt16(UInt16 tag, UInt16 val) : + DryadMTag(tag, DrPropertyTagType_UInt16) +{ + m_uI16Val = val; +} + +DryadMTagUInt16::~DryadMTagUInt16() +{ +} + +DryadMTagUInt16* DryadMTagUInt16::Create(UInt16 tag, UInt16 val) +{ + return new DryadMTagUInt16(tag, val); +} + +DrError DryadMTagUInt16::ReadFromStream(DrMemoryReader* reader, + UInt16 enumID, UInt32 dataLen, + DryadMTagUInt16Ref* outTag) +{ + UInt16 val; + DrError cse = reader->ReadNextUInt16Property(enumID, &val); + if (cse == DrError_OK) + { + outTag->Attach(Create(enumID, val)); + } + else + { + (*outTag) = NULL; + } + return cse; +} + +UInt16 DryadMTagUInt16::GetUInt16() +{ + return m_uI16Val; +} + +DrError DryadMTagUInt16::Serialize(DrMemoryWriter* writer) +{ + return writer->WriteUInt16Property(GetTagValue(), m_uI16Val); +} + + +DryadMTagInt32::DryadMTagInt32(UInt16 tag, Int32 val) : + DryadMTag(tag, DrPropertyTagType_Int32) +{ + m_i32Val = val; +} + +DryadMTagInt32::~DryadMTagInt32() +{ +} + +DryadMTagInt32* DryadMTagInt32::Create(UInt16 tag, Int32 val) +{ + return new DryadMTagInt32(tag, val); +} + +DrError DryadMTagInt32::ReadFromStream(DrMemoryReader* reader, + UInt16 enumID, UInt32 dataLen, + DryadMTagInt32Ref* outTag) +{ + Int32 val; + DrError cse = reader->ReadNextInt32Property(enumID, &val); + if (cse == DrError_OK) + { + outTag->Attach(Create(enumID, val)); + } + else + { 
+ (*outTag) = NULL; + } + return cse; +} + +Int32 DryadMTagInt32::GetInt32() +{ + return m_i32Val; +} + +DrError DryadMTagInt32::Serialize(DrMemoryWriter* writer) +{ + return writer->WriteInt32Property(GetTagValue(), m_i32Val); +} + + +DryadMTagUInt32::DryadMTagUInt32(UInt16 tag, UInt32 val) : + DryadMTag(tag, DrPropertyTagType_UInt32) +{ + m_uI32Val = val; +} + +DryadMTagUInt32::~DryadMTagUInt32() +{ +} + +DryadMTagUInt32* DryadMTagUInt32::Create(UInt16 tag, UInt32 val) +{ + return new DryadMTagUInt32(tag, val); +} + +DrError DryadMTagUInt32::ReadFromStream(DrMemoryReader* reader, + UInt16 enumID, UInt32 dataLen, + DryadMTagUInt32Ref* outTag) +{ + UInt32 val; + DrError cse = reader->ReadNextUInt32Property(enumID, &val); + if (cse == DrError_OK) + { + outTag->Attach(Create(enumID, val)); + } + else + { + (*outTag) = NULL; + } + return cse; +} + +UInt32 DryadMTagUInt32::GetUInt32() +{ + return m_uI32Val; +} + +DrError DryadMTagUInt32::Serialize(DrMemoryWriter* writer) +{ + return writer->WriteUInt32Property(GetTagValue(), m_uI32Val); +} + + +DryadMTagInt64::DryadMTagInt64(UInt16 tag, Int64 val) : + DryadMTag(tag, DrPropertyTagType_Int64) +{ + m_i64Val = val; +} + +DryadMTagInt64::~DryadMTagInt64() +{ +} + +DryadMTagInt64* DryadMTagInt64::Create(UInt16 tag, Int64 val) +{ + return new DryadMTagInt64(tag, val); +} + +DrError DryadMTagInt64::ReadFromStream(DrMemoryReader* reader, + UInt16 enumID, UInt32 dataLen, + DryadMTagInt64Ref* outTag) +{ + Int64 val; + DrError cse = reader->ReadNextInt64Property(enumID, &val); + if (cse == DrError_OK) + { + outTag->Attach(Create(enumID, val)); + } + else + { + (*outTag) = NULL; + } + return cse; +} + +Int64 DryadMTagInt64::GetInt64() +{ + return m_i64Val; +} + +DrError DryadMTagInt64::Serialize(DrMemoryWriter* writer) +{ + return writer->WriteInt64Property(GetTagValue(), m_i64Val); +} + + +DryadMTagUInt64::DryadMTagUInt64(UInt16 tag, UInt64 val) : + DryadMTag(tag, DrPropertyTagType_UInt64) +{ + m_uI64Val = val; +} + 
+DryadMTagUInt64::~DryadMTagUInt64() +{ +} + +DryadMTagUInt64* DryadMTagUInt64::Create(UInt16 tag, UInt64 val) +{ + return new DryadMTagUInt64(tag, val); +} + +DrError DryadMTagUInt64::ReadFromStream(DrMemoryReader* reader, + UInt16 enumID, UInt32 dataLen, + DryadMTagUInt64Ref* outTag) +{ + UInt64 val; + DrError cse = reader->ReadNextUInt64Property(enumID, &val); + if (cse == DrError_OK) + { + outTag->Attach(Create(enumID, val)); + } + else + { + (*outTag) = NULL; + } + return cse; +} + +UInt64 DryadMTagUInt64::GetUInt64() +{ + return m_uI64Val; +} + +DrError DryadMTagUInt64::Serialize(DrMemoryWriter* writer) +{ + return writer->WriteUInt64Property(GetTagValue(), m_uI64Val); +} + +DryadMTagDouble::DryadMTagDouble(UInt16 tag, double val) : + DryadMTag(tag, DrPropertyTagType_Double) +{ + m_doubleVal = val; +} + +DryadMTagDouble::~DryadMTagDouble() +{ +} + +DryadMTagDouble* DryadMTagDouble::Create(UInt16 tag, double val) +{ + return new DryadMTagDouble(tag, val); +} + +DrError DryadMTagDouble::ReadFromStream(DrMemoryReader* reader, + UInt16 enumID, UInt32 dataLen, + DryadMTagDoubleRef* outTag) +{ + double val; + DrError cse = reader->ReadNextDoubleProperty(enumID, &val); + if (cse == DrError_OK) + { + outTag->Attach(Create(enumID, val)); + } + else + { + (*outTag) = NULL; + } + return cse; +} + +double DryadMTagDouble::GetDouble() +{ + return m_doubleVal; +} + +DrError DryadMTagDouble::Serialize(DrMemoryWriter* writer) +{ + return writer->WriteDoubleProperty(GetTagValue(), m_doubleVal); +} + +DryadMTagString::DryadMTagString(UInt16 enumID, const char* val) : + DryadMTag(enumID, DrPropertyTagType_String) +{ + m_string.Set(val); +} + +DryadMTagString::~DryadMTagString() +{ +} + +DryadMTagString* DryadMTagString::Create(UInt16 enumID, const char* val) +{ + return new DryadMTagString(enumID, val); +} + +DrError DryadMTagString::ReadFromStream(DrMemoryReader* reader, + UInt16 enumID, UInt32 dataLen, + DryadMTagStringRef* outTag) +{ + DryadMTagString* tag = Create(enumID, 
NULL); + DrError cse = + reader->ReadNextStringProperty(enumID, + tag->GetWritableString(dataLen), + dataLen); + if (cse == DrError_OK) + { + outTag->Attach(tag); + } + else + { + tag->DecRef(); + (*outTag) = NULL; + } + return cse; +} + +const char* DryadMTagString::GetString() +{ + return m_string.GetString(); +} + +char* DryadMTagString::GetWritableString(size_t dataLen) +{ + return m_string.GetWritableBuffer(dataLen); +} + +DrError DryadMTagString::Serialize(DrMemoryWriter* writer) +{ + return writer->WriteLongStringProperty(GetTagValue(), m_string.GetString()); +} + + +DryadMTagGuid::DryadMTagGuid(UInt16 enumID, const DrGuid* val) : + DryadMTag(enumID, DrPropertyTagType_Guid) +{ + m_guid.Set(*val); +} + +DryadMTagGuid::~DryadMTagGuid() +{ +} + +DryadMTagGuid* DryadMTagGuid::Create(UInt16 enumID, const DrGuid* val) +{ + return new DryadMTagGuid(enumID, val); +} + +DrError DryadMTagGuid::ReadFromStream(DrMemoryReader* reader, + UInt16 enumID, UInt32 dataLen, + DryadMTagGuidRef* outTag) +{ + DrGuid guid; + DrError cse = + reader->ReadNextGuidProperty(enumID, &guid); + if (cse == DrError_OK) + { + outTag->Attach(Create(enumID, &guid)); + } + else + { + (*outTag) = NULL; + } + return cse; +} + +const DrGuid* DryadMTagGuid::GetGuid() +{ + return &m_guid; +} + +DrError DryadMTagGuid::Serialize(DrMemoryWriter* writer) +{ + return writer->WriteGuidProperty(GetTagValue(), m_guid); +} + + +DryadMTagTimeStamp::DryadMTagTimeStamp(UInt16 tag, DrTimeStamp val) : + DryadMTag(tag, DrPropertyTagType_TimeStamp) +{ + m_val = val; +} + +DryadMTagTimeStamp::~DryadMTagTimeStamp() +{ +} + +DryadMTagTimeStamp* DryadMTagTimeStamp::Create(UInt16 tag, DrTimeStamp val) +{ + return new DryadMTagTimeStamp(tag, val); +} + +DrError DryadMTagTimeStamp::ReadFromStream(DrMemoryReader* reader, + UInt16 enumID, UInt32 dataLen, + DryadMTagTimeStampRef* outTag) +{ + DrTimeStamp val; + DrError cse = reader->ReadNextTimeStampProperty(enumID, &val); + if (cse == DrError_OK) + { + 
outTag->Attach(Create(enumID, val)); + } + else + { + (*outTag) = NULL; + } + return cse; +} + +DrTimeStamp DryadMTagTimeStamp::GetTimeStamp() +{ + return m_val; +} + +DrError DryadMTagTimeStamp::Serialize(DrMemoryWriter* writer) +{ + return writer->WriteTimeStampProperty(GetTagValue(), m_val); +} + + +DryadMTagTimeInterval::DryadMTagTimeInterval(UInt16 tag, DrTimeInterval val) : + DryadMTag(tag, DrPropertyTagType_TimeInterval) +{ + m_val = val; +} + +DryadMTagTimeInterval::~DryadMTagTimeInterval() +{ +} + +DryadMTagTimeInterval* DryadMTagTimeInterval::Create(UInt16 tag, + DrTimeInterval val) +{ + return new DryadMTagTimeInterval(tag, val); +} + +DrError DryadMTagTimeInterval::ReadFromStream(DrMemoryReader* reader, + UInt16 enumID, UInt32 dataLen, + DryadMTagTimeIntervalRef* outTag) +{ + DrTimeInterval val; + DrError cse = reader->ReadNextTimeIntervalProperty(enumID, &val); + if (cse == DrError_OK) + { + outTag->Attach(Create(enumID, val)); + } + else + { + (*outTag) = NULL; + } + return cse; +} + +DrTimeInterval DryadMTagTimeInterval::GetTimeInterval() +{ + return m_val; +} + +DrError DryadMTagTimeInterval::Serialize(DrMemoryWriter* writer) +{ + return writer->WriteTimeIntervalProperty(GetTagValue(), m_val); +} + + +DryadMTagDrError::DryadMTagDrError(UInt16 tag, DrError val) : + DryadMTag(tag, DrPropertyTagType_DrError) +{ + m_val = val; +} + +DryadMTagDrError::~DryadMTagDrError() +{ +} + +DryadMTagDrError* DryadMTagDrError::Create(UInt16 tag, DrError val) +{ + return new DryadMTagDrError(tag, val); +} + +DrError DryadMTagDrError::ReadFromStream(DrMemoryReader* reader, + UInt16 enumID, UInt32 dataLen, + DryadMTagDrErrorRef* outTag) +{ + DrError val; + DrError cse = reader->ReadNextDrErrorProperty(enumID, &val); + if (cse == DrError_OK) + { + outTag->Attach(Create(enumID, val)); + } + else + { + (*outTag) = NULL; + } + return cse; +} + +DrError DryadMTagDrError::GetDrError() +{ + return m_val; +} + +DrError DryadMTagDrError::Serialize(DrMemoryWriter* writer) +{ + 
return writer->WriteDrErrorProperty(GetTagValue(), m_val); +} + + +DryadMTagMetaData::DryadMTagMetaData(UInt16 tag, DryadMetaData* val, + bool marshalAsAggregate) : + DryadMTag(tag, DryadPropertyTagType_MetaData) +{ + m_val = val; + m_marshalAsAggregate = marshalAsAggregate; +} + +DryadMTagMetaData::~DryadMTagMetaData() +{ +} + +DryadMTagMetaData* DryadMTagMetaData::Create(UInt16 tag, DryadMetaData* val, + bool marshalAsAggregate) +{ + return new DryadMTagMetaData(tag, val, marshalAsAggregate); +} + +DrError DryadMTagMetaData:: + ReadFromStreamInAggregate(DrMemoryReader* reader, + DryadMTagMetaDataRef* outTag) +{ + UInt16 tagValue; + DrError err + = reader->PeekNextUInt16Property(Prop_Dryad_BeginTag, &tagValue); + if (err == DrError_OK) + { + DryadMetaDataParser nestedParser; + err = reader->ReadAggregate(tagValue, &nestedParser, NULL); + if (err == DrError_OK) + { + DryadMetaData* parsed = nestedParser.GetMetaData(); + outTag->Attach(Create(tagValue, parsed, true)); + } + else + { + (*outTag) = NULL; + } + } + else + { + (*outTag) = NULL; + } + return err; +} + +DrError DryadMTagMetaData::ReadFromArray(UInt16 tagValue, + const void* data, UInt32 dataLen, + DryadMTagMetaDataRef* outTag) +{ + DryadMetaDataParser parser; + DrError err = parser.ParseBuffer(data, dataLen); + if (err == DrError_OK) + { + DryadMetaData* parsed = parser.GetMetaData(); + outTag->Attach(Create(tagValue, parsed, false)); + } + else + { + (*outTag) = NULL; + } + return err; +} + +DryadMetaData* DryadMTagMetaData::GetMetaData() +{ + return m_val; +} + +DrError DryadMTagMetaData::Serialize(DrMemoryWriter* writer) +{ + return m_val->WriteAsAggregate(writer, GetTagValue(), true); +} + +DryadMTag* DryadMTagMetaData::Clone() +{ + DryadMetaDataRef clonedValue; + m_val->Clone(&clonedValue); + return Create(GetTagValue(), clonedValue, m_marshalAsAggregate); +} + + +DryadMTagVertexCommand::DryadMTagVertexCommand(UInt16 tag, + DVertexCommand val) : + DryadMTag(tag, DryadPropertyTagType_VertexCommand) +{ 
+ m_val = val; +} + +DryadMTagVertexCommand::~DryadMTagVertexCommand() +{ +} + +DryadMTagVertexCommand* DryadMTagVertexCommand::Create(UInt16 tag, + DVertexCommand val) +{ + return new DryadMTagVertexCommand(tag, val); +} + +DrError DryadMTagVertexCommand:: + ReadFromStream(DrMemoryReader* reader, + UInt16 enumID, UInt32 dataLen, + DryadMTagVertexCommandRef* outTag) +{ + UInt32 val; + DrError cse = reader->ReadNextUInt32Property(enumID, &val); + if (cse == DrError_OK) + { + if (val < DVertexCommand_Max) + { + outTag->Attach(Create(enumID, (DVertexCommand) val)); + } + else + { + (*outTag) = NULL; + cse = DrError_InvalidProperty; + } + } + else + { + (*outTag) = NULL; + } + return cse; +} + +DVertexCommand DryadMTagVertexCommand::GetVertexCommand() +{ + return m_val; +} + +DrError DryadMTagVertexCommand::Serialize(DrMemoryWriter* writer) +{ + return writer->WriteUInt32Property(GetTagValue(), (UInt32) m_val); +} + + +DryadMTagInputChannelDescription:: + DryadMTagInputChannelDescription(UInt16 tag, + DryadInputChannelDescription* val) : + DryadMTag(tag, DryadPropertyTagType_InputChannelDescription) +{ + m_val = val; +} + +DryadMTagInputChannelDescription::~DryadMTagInputChannelDescription() +{ + delete m_val; +} + +DryadMTagInputChannelDescription* DryadMTagInputChannelDescription:: + Create(UInt16 tag, DryadInputChannelDescription* val) +{ + return new DryadMTagInputChannelDescription(tag, val); +} + +DrError DryadMTagInputChannelDescription:: + ReadFromStream(DrMemoryReader* reader, + DryadMTagInputChannelDescriptionRef* outTag) +{ + UInt16 tagValue; + DrError err + = reader->PeekNextUInt16Property(Prop_Dryad_BeginTag, &tagValue); + if (err == DrError_OK) + { + DryadInputChannelDescription* cDescription = + new DryadInputChannelDescription(); + err = reader->ReadAggregate(tagValue, cDescription, NULL); + if (err == DrError_OK) + { + outTag->Attach(Create(tagValue, cDescription)); + } + else + { + delete cDescription; + (*outTag) = NULL; + } + } + else + { + 
(*outTag) = NULL; + } + return err; +} + +DryadInputChannelDescription* DryadMTagInputChannelDescription:: + GetInputChannelDescription() +{ + return m_val; +} + +DrError DryadMTagInputChannelDescription:: + Serialize(DrMemoryWriter* writer) +{ + return m_val->Serialize(writer); +} + + +DryadMTagOutputChannelDescription:: + DryadMTagOutputChannelDescription(UInt16 tag, + DryadOutputChannelDescription* val) : + DryadMTag(tag, DryadPropertyTagType_OutputChannelDescription) +{ + m_val = val; +} + +DryadMTagOutputChannelDescription::~DryadMTagOutputChannelDescription() +{ + delete m_val; +} + +DryadMTagOutputChannelDescription* DryadMTagOutputChannelDescription:: + Create(UInt16 tag, DryadOutputChannelDescription* val) +{ + return new DryadMTagOutputChannelDescription(tag, val); +} + +DrError DryadMTagOutputChannelDescription:: + ReadFromStream(DrMemoryReader* reader, + DryadMTagOutputChannelDescriptionRef* outTag) +{ + UInt16 tagValue; + DrError err + = reader->PeekNextUInt16Property(Prop_Dryad_BeginTag, &tagValue); + if (err == DrError_OK) + { + DryadOutputChannelDescription* cDescription = + new DryadOutputChannelDescription(); + err = reader->ReadAggregate(tagValue, cDescription, NULL); + if (err == DrError_OK) + { + outTag->Attach(Create(tagValue, cDescription)); + } + else + { + delete cDescription; + (*outTag) = NULL; + } + } + else + { + (*outTag) = NULL; + } + return err; +} + +DryadOutputChannelDescription* DryadMTagOutputChannelDescription:: + GetOutputChannelDescription() +{ + return m_val; +} + +DrError DryadMTagOutputChannelDescription:: + Serialize(DrMemoryWriter* writer) +{ + return m_val->Serialize(writer); +} + + +DryadMTagVertexProcessStatus:: + DryadMTagVertexProcessStatus(UInt16 tag, + DVertexProcessStatus* val) : + DryadMTag(tag, DryadPropertyTagType_VertexProcessStatus) +{ + m_val = val; +} + +DryadMTagVertexProcessStatus::~DryadMTagVertexProcessStatus() +{ +} + +DryadMTagVertexProcessStatus* DryadMTagVertexProcessStatus:: + Create(UInt16 tag, 
DVertexProcessStatus* val) +{ + return new DryadMTagVertexProcessStatus(tag, val); +} + +DrError DryadMTagVertexProcessStatus:: + ReadFromStream(DrMemoryReader* reader, + DryadMTagVertexProcessStatusRef* outTag) +{ + UInt16 tagValue; + DrError err + = reader->PeekNextUInt16Property(Prop_Dryad_BeginTag, &tagValue); + if (err == DrError_OK) + { + DrRef pStatus; + pStatus.Attach(new DVertexProcessStatus()); + err = reader->ReadAggregate(tagValue, pStatus, NULL); + if (err == DrError_OK) + { + outTag->Attach(Create(tagValue, pStatus)); + } + else + { + (*outTag) = NULL; + } + } + else + { + (*outTag) = NULL; + } + return err; +} + +DVertexProcessStatus* DryadMTagVertexProcessStatus::GetVertexProcessStatus() +{ + return m_val; +} + +DrError DryadMTagVertexProcessStatus::Serialize(DrMemoryWriter* writer) +{ + return m_val->Serialize(writer); +} + + +DryadMTagVertexStatus::DryadMTagVertexStatus(UInt16 tag, + DVertexStatus* val) : + DryadMTag(tag, DryadPropertyTagType_VertexStatus) +{ + m_val = val; +} + +DryadMTagVertexStatus::~DryadMTagVertexStatus() +{ +} + +DryadMTagVertexStatus* DryadMTagVertexStatus:: + Create(UInt16 tag, DVertexStatus* val) +{ + return new DryadMTagVertexStatus(tag, val); +} + +DrError DryadMTagVertexStatus:: + ReadFromStream(DrMemoryReader* reader, + DryadMTagVertexStatusRef* outTag) +{ + UInt16 tagValue; + DrError err + = reader->PeekNextUInt16Property(Prop_Dryad_BeginTag, &tagValue); + if (err == DrError_OK) + { + DrRef vStatus; + vStatus.Attach(new DVertexStatus()); + err = reader->ReadAggregate(tagValue, vStatus, NULL); + if (err == DrError_OK) + { + outTag->Attach(Create(tagValue, vStatus)); + } + else + { + (*outTag) = NULL; + } + } + else + { + (*outTag) = NULL; + } + return err; +} + +DVertexStatus* DryadMTagVertexStatus::GetVertexStatus() +{ + return m_val; +} + +DrError DryadMTagVertexStatus::Serialize(DrMemoryWriter* writer) +{ + return m_val->Serialize(writer); +} + + +DryadMTagVertexCommandBlock:: + DryadMTagVertexCommandBlock(UInt16 
tag, + DVertexCommandBlock* val) : + DryadMTag(tag, DryadPropertyTagType_VertexCommandBlock) +{ + m_val = val; +} + +DryadMTagVertexCommandBlock::~DryadMTagVertexCommandBlock() +{ +} + +DryadMTagVertexCommandBlock* DryadMTagVertexCommandBlock:: + Create(UInt16 tag, DVertexCommandBlock* val) +{ + return new DryadMTagVertexCommandBlock(tag, val); +} + +DrError DryadMTagVertexCommandBlock:: + ReadFromStream(DrMemoryReader* reader, + DryadMTagVertexCommandBlockRef* outTag) +{ + UInt16 tagValue; + DrError err + = reader->PeekNextUInt16Property(Prop_Dryad_BeginTag, &tagValue); + if (err == DrError_OK) + { + DrRef cBlock; + cBlock.Attach(new DVertexCommandBlock()); + err = reader->ReadAggregate(tagValue, cBlock, NULL); + if (err == DrError_OK) + { + outTag->Attach(Create(tagValue, cBlock)); + } + else + { + (*outTag) = NULL; + } + } + else + { + (*outTag) = NULL; + } + return err; +} + +DVertexCommandBlock* DryadMTagVertexCommandBlock::GetVertexCommandBlock() +{ + return m_val; +} + +DrError DryadMTagVertexCommandBlock::Serialize(DrMemoryWriter* writer) +{ + return m_val->Serialize(writer); +} diff --git a/DryadVertex/VertexHost/system/common/src/dryadnativeport.cpp b/DryadVertex/VertexHost/system/common/src/dryadnativeport.cpp new file mode 100644 index 0000000..0fe0320 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/src/dryadnativeport.cpp @@ -0,0 +1,923 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include +#include +#include +#include + + +#pragma unmanaged + +// +// No cleanup required for handlerbase +// +DryadNativePort::HandlerBase::~HandlerBase() +{ +} + +OVERLAPPED* DryadNativePort::HandlerBase::GetOverlapped() +{ + return this; +} + +// +// Create uninitialized handler +// +DryadNativePort::Handler::Handler() +{ + hEvent = 0; + Offset = 0; + OffsetHigh = 0; + m_bytesToTransfer = 0; + + // todo: remove cosmos code + //JCm_cosmosPosition.ExtentIndex = DR_INVALID_EXTENT_INDEX; + //JCm_cosmosPosition.Offset = DR_UNKNOWN_OFFSET; +} + +// +// No cleanup required for handler +// +DryadNativePort::Handler::~Handler() +{ +} + +// +// Initialize handler offset and number of bytes to transfer +// +void DryadNativePort::Handler::InitializeInternal(UInt32 bytesToTransfer, + UInt64 requestOffset) +{ + hEvent = 0; + Offset = (DWORD) (requestOffset & 0xffffffff); + OffsetHigh = (DWORD) (requestOffset >> 32); + m_bytesToTransfer = (SIZE_T) bytesToTransfer; + + // todo: remove cosmost code + //JCm_cosmosPosition.ExtentIndex = 0; + //JCm_cosmosPosition.Offset = requestOffset; +} + +// +// Return number of bytes to tranfer +// +PSIZE_T DryadNativePort::Handler::GetNumberOfBytesToTransferPtr() +{ + return &m_bytesToTransfer; +} + +// todo: remove commented code +/*JC +DR_STREAM_POSITION *DryadNativePort::Handler::GetDryadPositionPtr() +{ + return &m_cosmosPosition; +}*/ + +// +// Constructor. Create a pool of worker thread handles for later use. 
+// +DryadNativePort::DryadNativePort(DWORD numWorkerThreads, + DWORD numConcurrentThreads) +{ + m_state = BPS_Stopped; + m_outstandingRequests = 0; + + m_numWorkerThreads = numWorkerThreads; + m_numConcurrentThreads = numConcurrentThreads; + m_completionPort = INVALID_HANDLE_VALUE; + + m_threadHandle = new HANDLE[m_numWorkerThreads]; + LogAssert(m_threadHandle != NULL); + DWORD i; + for (i=0; i(arg); + self->WriteFileThread(); + return 0; +} + +void DryadNativePort::WriteFileThread() +{ + bool mustStop = false; + + do + { + WaitForSingleObject(m_writeFileEvent, INFINITE); + + bool mustBreak = false; + do + { + DryadWriteFileList writeList; + { + AutoCriticalSection acs(&m_writeFileCS); + + writeList.TransitionToTail(&m_writeFileList); + mustStop = m_writeFileFinished; + + if (writeList.IsEmpty()) + { + ResetEvent(m_writeFileEvent); + mustBreak = true; + } + } + + while (writeList.IsEmpty() == false) + { + WriteFileRequest* wfr = + writeList.CastOut(writeList.RemoveHead()); + + DWORD bytesToTransfer = + (DWORD) *(wfr->m_request->GetNumberOfBytesToTransferPtr()); + BOOL bRet = ::WriteFile(wfr->m_fileHandle, + wfr->m_request->GetData(), + bytesToTransfer, + NULL, + wfr->m_request->GetOverlapped()); + + if (bRet == 0) + { + DWORD dErr = ::GetLastError(); + + DrError cse; + if (dErr == ERROR_HANDLE_EOF) + { + cse = DrError_EndOfStream; + } + else + { + cse = DrErrorFromWin32(dErr); + } + + if (dErr != ERROR_IO_PENDING) + { + wfr->m_request->ProcessIO(cse, 0); + + { + AutoCriticalSection acs (&m_baseDR); + + LogAssert(m_outstandingRequests > 0); + --m_outstandingRequests; + } + } + } + + delete wfr; + } + } while (mustBreak == false); + } while (mustStop == false); +} + +// +// Worker threads' "main" +// +unsigned __stdcall DryadNativePort::ThreadFunc(void* arg) +{ + // + // Get port reference and validate initialization + // + DryadNativePort* self = (DryadNativePort *) arg; + LogAssert(self->m_completionPort != INVALID_HANDLE_VALUE); + + 
DrLogI("DryadNativePort::ThreadFunc starting thread"); + + // + // Get I/O completion events until shutdown event received + // + bool finished = false; + do + { + DWORD numBytes; + ULONG_PTR completionKey; + LPOVERLAPPED overlapped; + + // todo: do we want this log? +// DrLogD( +// "DryadNativePort::ThreadFunc waiting for completion event"); + + // + // Attempt to dequeue a completion packet + // + BOOL retval = ::GetQueuedCompletionStatus(self->m_completionPort, + &numBytes, + &completionKey, + &overlapped, + INFINITE); + + // todo: do we want this log? +// DrLogD( +// "DryadNativePort::ThreadFunc received completion event", +// "retval: %d", retval); + + if (completionKey == (ULONG_PTR) self) + { + // + // If completionkey is the native port, finish or fail depending on return code + // + if (retval != 0) + { + // + // If GetQueuedCompletionStatus succeeded, make sure nothing was + // transfered and stop waiting + // + LogAssert(numBytes == 0); + LogAssert(overlapped == NULL); + finished = true; + + DrLogI("DryadNativePort::ThreadFunc received shutdown event"); + } + else + { + // + // If GetQueuedCompletitionStatus failed, log error and fail + // + DWORD errCode = GetLastError(); + DrLogA("DryadNativePort::GetQueuedCompletionStatus - error code: 0x%08x", HRESULT_FROM_WIN32(errCode)); + } + } + else + { + // + // If completionkey is not the native port, validate results and + // handle any errors + // + LogAssert(completionKey == NULL); + LogAssert(overlapped != NULL); + HandlerBase* handler = (HandlerBase *) overlapped; + DrError cse; + if (retval != 0) + { + // + // If success, then everything's ok + // + cse = DrError_OK; + } + else + { + // + // If failure, set reason. + // + DWORD errCode = GetLastError(); + if (errCode == ERROR_HANDLE_EOF) + { + cse = DrError_EndOfStream; + if (numBytes > 0) + { + // + // If end of file with bytes remaining, log unexpected + // completion. + // + DrLogD("Unexpected non-zero byte count on EOF. 
numbytes %u", (UInt32) numBytes); + numBytes = 0; + } + } + else + { + cse = DrErrorFromWin32(errCode); + } + } + + // todo: do we want this log? +// DrLogD( +// "Forwarding completed IO to client", +// "error %s; numbytes %u", +// DRERRORSTRING(cse), (UInt32) numBytes); + + // + // + // todo: figure out which polymorphic ProcessIO this goes to + // + handler->ProcessIO(cse, (UInt32) numBytes); + + // + // Enter critical section and decrement number of outstanding requests + // + { + AutoCriticalSection acs(&(self->m_baseDR)); + LogAssert(self->m_outstandingRequests > 0); + --(self->m_outstandingRequests); + } + } + } while (!finished); + + // + // Exit cleanly once shutdown event received + // + DrLogI("DryadNativePort::ThreadFunc exiting thread"); + return 0; +} + +void DryadNativePort::AssociateHandle(HANDLE fileHandle) +{ + { + AutoCriticalSection acs(&m_baseDR); + LogAssert(m_state == BPS_Running); + + HANDLE completionPort = ::CreateIoCompletionPort(fileHandle, + m_completionPort, + NULL, + 0); + if (completionPort == NULL) + { + DrLogA("CreateIoCompletionPort failed. 
error: %u", GetLastError()); + } + else + { + LogAssert(completionPort == m_completionPort); + } + } +} + +// +// Begin waiting for events on a completion port +// +void DryadNativePort::Start() +{ + // + // Enter a critical section to create the completion port and start worker threads + // + { + AutoCriticalSection acs(&m_baseDR); + + // + // Ensure that port is in initialized, but stopped state + // + LogAssert(m_state == BPS_Stopped); + LogAssert(m_completionPort == INVALID_HANDLE_VALUE); + LogAssert(m_outstandingRequests == 0); + + DrLogI("DryadNativePort::Start entered"); + + // + // Create an io completion port without associating it with a file + // Assert that it is correctly created - this is probably safe + // because of the lack of associated file + // + m_completionPort = ::CreateIoCompletionPort(INVALID_HANDLE_VALUE, + NULL, + NULL, + m_numConcurrentThreads); + LogAssert(m_completionPort != NULL); + + DrLogI("DryadNativePort::Start created completion port"); + + // + // Create each worker thread and have them wait for IO completion events + // + DWORD i; + for (i=0; i= WAIT_OBJECT_0 &&*/ + waitRet < (WAIT_OBJECT_0 + m_numWorkerThreads)); + + DrLogI("DryadNativePort::Stop all worker threads have terminated"); + + { + AutoCriticalSection acs(&m_writeFileCS); + + m_writeFileFinished = true; + } + + waitRet = ::WaitForSingleObject(m_writeFileHandle, INFINITE); + LogAssert(waitRet == WAIT_OBJECT_0); + + { + AutoCriticalSection acs(&m_writeFileCS); + + LogAssert(m_writeFileList.IsEmpty()); + } + + DrLogI("DryadNativePort::Stop all threads have terminated"); + + { + AutoCriticalSection acs(&m_baseDR); + + BOOL bRetval; + + for (i=0; iGetNumberOfBytesToTransferPtr()); + + // + // Read requested number of bytes from a file asynchronously + // + BOOL bRet = ::ReadFile(fileHandle, + request->GetData(), + bytesToTransfer, + NULL, + request->GetOverlapped()); + + if (bRet == 0) + { + DWORD dErr = ::GetLastError(); + + DrError cse; + if (dErr == ERROR_HANDLE_EOF) + 
{ + // + // If reached EOF, report EOS error + // + cse = DrError_EndOfStream; + } + else + { + cse = DrErrorFromWin32(dErr); + } + + if (dErr != ERROR_IO_PENDING) + { + // + // If not IO pending, asynchronous function has failed running, so process the error + // + request->ProcessIO(cse, 0); + + { + AutoCriticalSection acs (&m_baseDR); + + LogAssert(m_outstandingRequests > 0); + --m_outstandingRequests; + } + } + } +} + +void DryadNativePort::QueueNativeXComputeRead(XCPROCESSFILEHANDLE fileHandle, + Handler* request, + UInt64* readPosition, + DrError* operationStatePtr) +{ + LogAssert(request != NULL); + + { + AutoCriticalSection acs (&m_baseDR); + + LogAssert(m_state == BPS_Running); + ++m_outstandingRequests; + } + + XC_ASYNC_INFO asyncInfo; + memset(&asyncInfo, 0, sizeof(asyncInfo)); + asyncInfo.cbSize = sizeof(asyncInfo); + asyncInfo.IOCP = m_completionPort; + asyncInfo.pOperationState = operationStatePtr; + asyncInfo.pOverlapped = request->GetOverlapped(); + asyncInfo.CompletionKey = NULL; + + XCERROR hr = XcReadProcessFile(fileHandle, request->GetData(), request->GetNumberOfBytesToTransferPtr(), readPosition, &asyncInfo); + + if (hr != S_OK) + { + if (hr != HRESULT_FROM_WIN32(ERROR_IO_PENDING)) + { + request->ProcessIO(hr, 0); + { + AutoCriticalSection acs (&m_baseDR); + + LogAssert(m_outstandingRequests > 0); + --m_outstandingRequests; + } + } + } +} + +/*JCvoid DryadNativePort::QueueDryadRead(DRHANDLE streamHandle, + DrError* pendingStatePtr, + UInt64 streamOffset, + Handler* request) +{ + LogAssert(request != NULL); + + { + AutoCriticalSection acs (&m_baseDR); + + LogAssert(m_state == BPS_Running); + ++m_outstandingRequests; + } + + DR_ASYNC_INFO asyncInfo; + memset(&asyncInfo, 0, sizeof(asyncInfo)); + asyncInfo.cbSize = sizeof(asyncInfo); + asyncInfo.IOCP = m_completionPort; + asyncInfo.pOperationState = pendingStatePtr; + asyncInfo.pOverlapped = request->GetOverlapped(); + + DrLogI( + "Queuing cosmos read", + "streamhandle=%p, offset=%I64u, 
buffsize=%I64u", + streamHandle, streamOffset, + (UInt64) *(request->GetNumberOfBytesToTransferPtr())); + DR_STREAM_POSITION *pReadPosition; + pReadPosition = request->GetDryadPositionPtr(); + pReadPosition->ExtentIndex = 0; + pReadPosition->Offset = streamOffset; + DrError err = ::DrReadStream(streamHandle, + request->GetData(), + request->GetNumberOfBytesToTransferPtr(), + 0, + pReadPosition, + &asyncInfo); + + LogAssert(err != DrError_OK); + + if (err == DrErrorFromWin32( ERROR_IO_PENDING ) ) { + err = DrError_OK; + } else if (err == DrErrorFromWin32( ERROR_HANDLE_EOF ) ) { + err = DrError_EndOfStream; + } + + if (err != DrError_OK) + { + DrLogI( + "Dryad read failed immediately", + "streamhandle=%p, offset=%I64u, err=%s", + streamHandle, streamOffset, DRERRORSTRING(err)); + + *pendingStatePtr = err; + + request->ProcessIO(DrError_OK, 0); + + { + AutoCriticalSection acs (&m_baseDR); + + LogAssert(m_outstandingRequests > 0); + --m_outstandingRequests; + } + } +}*/ + +void DryadNativePort::QueueNativeWrite(HANDLE fileHandle, Handler* request) +{ + LogAssert(request != NULL); + + { + AutoCriticalSection acs (&m_baseDR); + + LogAssert(m_state == BPS_Running); + ++m_outstandingRequests; + } + + WriteFileRequest* wfr = new + WriteFileRequest(fileHandle, request); + { + AutoCriticalSection acs(&m_writeFileCS); + + BOOL mustWake = m_writeFileList.IsEmpty(); + m_writeFileList.InsertAsTail(m_writeFileList.CastIn(wfr)); + if (mustWake) + { + SetEvent(m_writeFileEvent); + } + } + +#if 0 + DWORD bytesToTransfer = + (DWORD) *(request->GetNumberOfBytesToTransferPtr()); + + BOOL bRet = ::WriteFile(fileHandle, + request->GetData(), + bytesToTransfer, + NULL, + request->GetOverlapped()); + + if (bRet == 0) + { + DWORD dErr = ::GetLastError(); + + DrError cse; + if (dErr == ERROR_HANDLE_EOF) + { + cse = DrError_EndOfStream; + } + else + { + cse = DrErrorFromWin32(dErr); + } + + if (dErr != ERROR_IO_PENDING) + { + request->ProcessIO(cse, 0); + + { + AutoCriticalSection acs 
(&m_baseDR); + + LogAssert(m_outstandingRequests > 0); + --m_outstandingRequests; + } + } + } +#endif +} + +/*JCvoid DryadNativePort::QueueDryadWrite(DRHANDLE streamHandle, + DrError* pendingStatePtr, + UInt64 streamOffset, + Handler* request) +{ + LogAssert(request != NULL); + + { + AutoCriticalSection acs (&m_baseDR); + + LogAssert(m_state == BPS_Running); + ++m_outstandingRequests; + } + + DR_ASYNC_INFO asyncInfo; + memset(&asyncInfo, 0, sizeof(asyncInfo)); + asyncInfo.cbSize = sizeof(asyncInfo); + asyncInfo.IOCP = m_completionPort; + asyncInfo.pOperationState = pendingStatePtr; + asyncInfo.pOverlapped = request->GetOverlapped(); + + DR_STREAM_POSITION *pAppendPosition; + pAppendPosition = request->GetDryadPositionPtr(); + pAppendPosition->ExtentIndex = 0; + pAppendPosition->Offset = streamOffset; + + DrLogI( + "Queueing Dryad append", + "streamhandle=%p, offset=%I64u, numBytes=%I64u", + streamHandle, streamOffset, + (UInt64) *(request->GetNumberOfBytesToTransferPtr())); + DrError err = ::DrAppendStream(streamHandle, + request->GetData(), + (*request->GetNumberOfBytesToTransferPtr()), + DR_FIXED_OFFSET_APPEND, + pAppendPosition, + request->GetNumberOfBytesToTransferPtr(), + &asyncInfo); + + LogAssert(err != DrError_OK); + + if (err == DrErrorFromWin32( ERROR_IO_PENDING ) ) { + err = DrError_OK; + } + + if (err != DrError_OK) + { + DrLogI( + "Dryad append failed immediately", + "streamhandle=%p, offset=%I64u, err=%s", + streamHandle, streamOffset, DRERRORSTRING(err)); + + *pendingStatePtr = err; + + request->ProcessIO(DrError_OK, 0); + + { + AutoCriticalSection acs (&m_baseDR); + + LogAssert(m_outstandingRequests > 0); + --m_outstandingRequests; + } + } +} + +void DryadNativePort::QueueDryadSetStreamProperties(const char* uri, + DrError* pendingStatePtr, + Handler* request) +{ + LogAssert(request != NULL); + + { + AutoCriticalSection acs (&m_baseDR); + + LogAssert(m_state == BPS_Running); + ++m_outstandingRequests; + } + + DR_ASYNC_INFO asyncInfo; + 
memset(&asyncInfo, 0, sizeof(asyncInfo)); + asyncInfo.cbSize = sizeof(asyncInfo); + asyncInfo.IOCP = m_completionPort; + asyncInfo.pOperationState = pendingStatePtr; + asyncInfo.pOverlapped = request->GetOverlapped(); + + PCDR_STREAM_PROPERTIES properties = + (PCDR_STREAM_PROPERTIES) request->GetData(); + + DrTimeInterval expireInterval = + properties->ExpirePeriod * DrTimeInterval_100ns; + DrLogI( + "Queueing Dryad set stream properties", + "stream=%s expirePeriod=%s", + uri, DRTIMEINTERVALSTRING(expireInterval)); + + DrError err = ::DrSetStreamProperties(uri, + properties, + &asyncInfo); + + LogAssert(err != DrError_OK); + + if (err == DrErrorFromWin32( ERROR_IO_PENDING ) ) { + err = DrError_OK; + } + + if (err != DrError_OK) + { + DrLogI( + "Dryad set stream properties failed immediately", + "stream=%s, err=%s", + uri, DRERRORSTRING(err)); + + *pendingStatePtr = err; + + request->ProcessIO(DrError_OK, 0); + + { + AutoCriticalSection acs (&m_baseDR); + + LogAssert(m_outstandingRequests > 0); + --m_outstandingRequests; + } + } +}*/ + +void DryadNativePort::IncrementOutstandingRequests() +{ + { + AutoCriticalSection acs (&m_baseDR); + + LogAssert(m_state == BPS_Running); + ++m_outstandingRequests; + } +} + +void DryadNativePort::DecrementOutstandingRequests() +{ + { + AutoCriticalSection acs (&m_baseDR); + + LogAssert(m_outstandingRequests > 0); + --m_outstandingRequests; + } +} + +HANDLE DryadNativePort::GetCompletionPort() +{ + return m_completionPort; +} diff --git a/DryadVertex/VertexHost/system/common/src/dryadopaqueresources.cpp b/DryadVertex/VertexHost/system/common/src/dryadopaqueresources.cpp new file mode 100644 index 0000000..e99d0fb --- /dev/null +++ b/DryadVertex/VertexHost/system/common/src/dryadopaqueresources.cpp @@ -0,0 +1,32 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include + + +#pragma unmanaged + +DryadProcessIdentifier::~DryadProcessIdentifier() +{ +} + +DryadMachineIdentifier::~DryadMachineIdentifier() +{ +} diff --git a/DryadVertex/VertexHost/system/common/src/dryadpropertydumper.cpp b/DryadVertex/VertexHost/system/common/src/dryadpropertydumper.cpp new file mode 100644 index 0000000..1ca06bc --- /dev/null +++ b/DryadVertex/VertexHost/system/common/src/dryadpropertydumper.cpp @@ -0,0 +1,187 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include +#include +#include +#include +#include +#include + +#pragma unmanaged + +const char* g_dVertexCommandText[DVertexCommand_Max] = { + "Start", + "ReOpenChannels", + "Terminate" +}; + +//JC +#if 0 +typedef DrError (*PropertyConverter)(DrPropertyDumper *pDumper, UInt16 enumId, const char *propertyName); + +typedef struct { + UInt16 value; + const char *pszDescription; + PropertyConverter pConverter; +} PropEntry; + +static PropEntry g_DryadPropertyMap[] = { + +#ifdef DEFINE_DRPROPERTY +#undef DEFINE_DRPROPERTY +#endif + +#ifdef DEFINE_DRYADPROPERTY +#undef DEFINE_DRYADPROPERTY +#endif + +#define DEFINE_DRPROPERTY(name, number, type, description) {name, description, DrPropertyToText_##type}, + +#define DEFINE_DRYADPROPERTY(name, number, type, description) {name, description, DryadPropertyToText_##type}, + +#include "dryadproperties.h" + +#undef DEFINE_DRPROPERTY +#undef DEFINE_DRYADPROPERTY +}; + + + +typedef struct { + UInt16 value; + const char *pszDescription; +} TagEntry; + +static TagEntry g_DryadTagMap[] = { + +#ifdef DEFINE_DRYADTAG +#undef DEFINE_DRYADTAG +#endif + +#define DEFINE_DRYADTAG(name, number, description, type) {name, description}, + +#include "dryadtags.h" + +#undef DEFINE_DRYADTAG +}; + + +typedef struct { + DrError value; + const char *pszDescription; +} ErrEntry; + +static ErrEntry g_DryadErrorMap[] = { +#ifdef DEFINE_DRYAD_ERROR +#undef DEFINE_DRYAD_ERROR +#endif + +#define DEFINE_DRYAD_ERROR(name, number, description) {name, description}, + +#include "dryaderror.h" + +#undef DEFINE_DRYAD_ERROR +}; + + +DrError DryadPropertyToText_Void(DrPropertyDumper *pDumper, + UInt16 enumId, + const char *propertyName) +{ + UInt16 actualEnumId; + UInt32 length; + DrError err = pDumper->GetReader()->ReadNextPropertyTag(&actualEnumId, &length); + if (err == DrError_OK) { + if (enumId != actualEnumId || length != 0) { + err = pDumper->GetReader()->SetStatus(DrError_InvalidProperty); + } else { + err = pDumper->WriteSimpleTagValue(propertyName, 
"Void"); + } + } + return err; +} + +DrError DryadPropertyToText_VertexCommand(DrPropertyDumper *pDumper, + UInt16 enumId, + const char *propertyName) +{ + char descBuffer[100]; + const char* desc; + UInt32 val; + DrError err = pDumper->GetReader()->ReadNextUInt32Property(enumId, &val); + if (err == DrError_OK) { + if (val < DVertexCommand_Max) + { + desc = g_dVertexCommandText[val]; + } + else + { + HRESULT hr = ::StringCbPrintfA(descBuffer, sizeof(descBuffer), + "Unknown state %u", val); + LogAssert(SUCCEEDED(hr)); + desc = descBuffer; + } + + err = pDumper->WriteSimpleTagValue(propertyName, desc); + } + return err; +} + +void DryadInitErrorTable() +{ + UInt32 n = sizeof(g_DryadErrorMap) / sizeof(g_DryadErrorMap[0]); + UInt32 i; + for (i=0; iGetOutstandingRequests(); +//VS fprintf(stdout, "DryadShutdownXCompute: %u outstanding requests\n", n); + while (n > 1) // There will always be 1 outstanding request for the next vertex command + { + // + // While their are requests remaining beyond 1, wait 10 seconds and try again + // todo: do we want this log? probably not since this could be indefinite + // todo: is 10 seconds appropriate? 
+ // + Sleep(10); + n = g_dryadNativePort->GetOutstandingRequests(); +//VS fprintf(stdout, "DryadShutdownXCompute: %u outstanding requests\n", n); + } + + // + // Close the xcompute session + // + err = XcCloseSession(s_session); + return err; +} + +// +// Initialize the dryad metadata and start up a completion port waiting for events +// +void DryadInitialize() +{ + // todo: check whether this commented out code matters +//JC DryadInitPropertyTable(); +//JC DryadInitTagTable(); +//JC DryadInitErrorTable(); + + // + // Initialize the tables defining the metadata + // + DryadInitMetaDataTable(); + + // + // Create a port and start it + // + g_dryadNativePort = new DryadNativePort(4, 2); + g_dryadNativePort->Start(); +} + + +XDRESSIONHANDLE GetSessionHandle() +{ + return s_session; +} + +XCPROCESSHANDLE GetProcessHandle() +{ + return s_processHandle; +} + +//JC +#if 0 +#include +#include +#include +#include +#include +#include +#include +#include + + +static char* s_configFileName = "cosmos.ini"; + +#if 0 +static char* s_cosmosOptionPrefix = "--inioption:"; +#endif + +static char* s_verbosePrefix = "--verbose"; +static char* s_popupPrefix = "--popup"; +static char* s_debugBreakPrefix = "--debugbreak"; +static char* s_assertHandlerPrefix = "--asserthandler"; +static char* s_disableProfiler = "--disableProfiler"; + +static const char* s_defaultParameters[] = { + "Counters", "NoCountersFile", "1", + "Dryad", "Cluster", "{!machinename}", + "Dryad", "DumpAllSentMessages", "false", + "Dryad", "DumpAllReceivedMessages", "false", + + "stdout", "LogSourceInfo", "1", + + "LogRules", "Rule1", "*,*,*,localLog", + "LogRules", "Rule120", "*,A,*,terminate", + "LogRules", "Rule121", "*,SEAW,*,stdout", + + "localLog", "FileNameBase", "local\\Log", + "localLog", "MaxFiles", "100", + "localLog", "MaxFileSize", "10000000", + "localLog", "BufferSize", "10000", + + NULL +}; + +static const char* s_dryadProfilerLogRule = "Rule122"; + +class DryadConfigurationManager : public 
ConfigurationManager +{ +private: + DrStr64 m_strNetLibName; + DrStr128 m_strIniFileName; + int m_argc; + char** m_argv; // heap allocated copy + int m_nOpts; // number of command line arguments consumed + +public: + DryadConfigurationManager( + const char* netLibName, + const char* iniFileName, + int argc, + char* argv[]) + { + m_strNetLibName = netLibName; + m_strIniFileName = iniFileName; + m_argc = argc; + m_nOpts = 0; + if (argc == 0) { + m_argv = NULL; + } else { + LogAssert(argv != NULL); + m_argv = new char *[argc]; + LogAssert(m_argv != NULL); + for (int i = 0; i < argc; i++) { + if (argv[i] == NULL) { + m_argv[i] = NULL; + } else { + size_t len = strlen(argv[i])+1; + m_argv[i] = new char[len]; + LogAssert(m_argv[i] != NULL); + memcpy(m_argv[i], argv[i], len); + } + } + } + } + + ~DryadConfigurationManager() + { + for (int i = 0; i < m_argc; i++) { + if (m_argv[i] != NULL) { + delete[] m_argv[i]; + } + } + if (m_argv != NULL) { + delete[] m_argv; + } + } + + int GetNumOptsConsumed() + { + return m_nOpts; + } + + DrError ApplyProfilerConfig(IMutableConfiguration *cosmosFile) { + DrError err = DrError_OK; + cosmosFile->SetParameter("LogRules", + s_dryadProfilerLogRule, "DryadProfiler,*,*,collectorDryadProfiler"); + cosmosFile->SetParameter("collectorDryadProfiler", + "FileNameBase", "collector\\dryadProfiler"); + cosmosFile->SetParameter("collectorDryadProfiler", + "MaxFiles", "0"); + cosmosFile->SetParameter("collectorDryadProfiler", + "MaxFileSize", "10000000"); + cosmosFile->SetParameter("collectorDryadProfiler", + "BufferSize", "10000"); + return err; + } + + // returns the number of arguments consumed in *pNOpts + DrError ApplyDryadConfigOverrides( + int argc, // Number of command line arguments eligible for parsing + char* argv[], // Command line arguments eligible for parsing + int* pNOpts, // Returned # of command line arguments consumed + IMutableConfiguration *cosmosFile) // Returned autopilot.ini contents + { + DrError err = DrError_OK; + bool 
disableProfiler = false; + + *pNOpts = 0; + + ++argv; + --argc; + while (argc > 0) + { +#if 0 + if (::_strnicmp(argv[0], s_cosmosOptionPrefix, + ::strlen(s_cosmosOptionPrefix)) == 0) + { + if (argc < 2) + { + err = DrError_InvalidParameter; + goto exit; + } + + if (::strcmp(argv[1], "-") == 0) + { + cosmosFile->RemoveSectionFromArg(argv[0]); + } + else if (::strchr(argv[1], '=') == NULL) + { + cosmosFile->RemoveParameterFromArg(argv[0], argv[1]); + } + else + { + cosmosFile->SetParameterFromArg(argv[0], argv[1]); + } + + argv += 2; + argc -= 2; + (*pNOpts) += 2; + } + else +#endif + if (::_strnicmp(argv[0], s_verbosePrefix, + ::strlen(s_verbosePrefix)) == 0) + { + cosmosFile->SetParameter("Dryad", + "DumpAllSentMessages", "true"); + cosmosFile->SetParameter("Dryad", + "DumpAllReceivedMessages", "true"); + cosmosFile->SetParameter("Dryad", + "MessageDumpFile", + "messages.{!component}.{!nodename}.txt"); + + cosmosFile->SetParameter("LogRules", + "Rule121", "*,ISEAW,*,stdout"); + + argv += 1; + argc -= 1; + (*pNOpts) += 1; + } + else if (::_strnicmp(argv[0], s_popupPrefix, + ::strlen(s_popupPrefix)) == 0) + { + cosmosFile->SetParameter("LogRules", + "Rule120", "*,A,*,popup"); + + argv += 1; + argc -= 1; + (*pNOpts) += 1; + } + else if (::_strnicmp(argv[0], s_assertHandlerPrefix, + ::strlen(s_assertHandlerPrefix)) == 0) + { + cosmosFile->SetParameter("LogRules", + "Rule122", "*,A,*,applicationcallback"); + + argv += 1; + argc -= 1; + (*pNOpts) += 1; + } + else if (::_strnicmp(argv[0], s_debugBreakPrefix, + ::strlen(s_debugBreakPrefix)) == 0) + { + argv += 1; + argc -= 1; + (*pNOpts) += 1; + + ::DebugBreak(); + } + else if (::_strnicmp(argv[0], s_disableProfiler, + ::strlen(s_disableProfiler)) == 0) + { + argv += 1; + argc -= 1; + (*pNOpts) += 1; + + disableProfiler = true; + } + else + { + break; + } + } + + err = DrError_OK; + if (!disableProfiler) { + err = ApplyProfilerConfig(cosmosFile); + } + + return err; + } + + + bool DoInitialize() + { + if (m_valid) { + // 
We don't want perf counter files + Counters::SetInitNoPerfFiles(); + } + + // Register ourselves as the singleton config manager + if (!m_valid || !Configuration::PreInitialize(this)) { + DrLogA( "DryadConfigurationManager", + "Failed to preinitialize dryad configuration manager"); + m_valid = false; + } + + if (m_valid) { + if (!CommonInit(m_strIniFileName.GetString(), 0, -1 )) { + DrLogA( "DryadConfigurationManager", + "Failed to initialize dryad configuration manager"); + m_valid = false; + } + } + + return m_valid; + } + + /** + * Called at the start of InitBootstrapConfiguration, this method is a last chance for a subclass to mess with the bootstrap configuration before it is used. + * + * On entry, bootstrapConfigPathname is the fully qualified name of the bootstrap file. + * bootstrapConfiguration is the raw (no macro expansion or override collapsing) configuration, or NULL if the configuration was not found. + * + * On exit, bootstrapConfiguration is the final raw (no macro expansion or override collapsing) bootstrap configuration. If NULL, initialization will fail. + * + * The default implementation does nothing. + * + * Returns false if initialization should fail. 
+ */ + virtual bool PreprocessBootstrapConfiguration(const char *bootstrapConfigPathname, Ptr& bootstrapConfiguration) + { + if (bootstrapConfiguration == NULL) { + // The bootstrap config file is missing -- build a default one + DrStr64 strDataDirLocation; + DrStr32 strRelDataDirLocation; + strRelDataDirLocation.SetF(".\\DataDir.%u", GetCurrentProcessId()); + DrError err = DrCanonicalizeFilePath(strDataDirLocation, strRelDataDirLocation); + if (err != DrError_OK) { + DrLogE( "DryadConfigurationManager", + "Failed to canonicalize data directory name %s error=%s", + strRelDataDirLocation.GetString(), DRERRORSTRING(err)); + return false; + } + + Ptr cfg = Configuration::GenerateDefaultBootstrapConfig( + strDataDirLocation.GetString(), + "...", + "default", + NULL); + if (cfg == NULL) { + DrLogE( "DryadConfigurationManager", + "Failed to create default bootstrap file"); + return false; + } + + bootstrapConfiguration = cfg; + } + + return true; + } + + /** + * Called immediately after attempted loading of the default configuration, this method is a last chance for a subclass to + * mess with the default configuration before it is used. + * + * On entry, defaultConfigPathname is the fully qualified name of the default configuration. + * "configuration" is the default filtered view of the configuration, or NULL if the configuration was not found + * rawConfiguration is the raw (no macro expansion or override collapsing) configuration, or NULL if the configuration was not found. + * + * On exit, rawConfiguration is the final raw (no macro expansion or override collapsing) default configuration. If NULL, initialization will fail. + * + * The default implementation does nothing. + * + * Returns false if initialization should fail. 
+ */ + virtual bool PreprocessDefaultConfiguration( + const char *defaultConfigPathname, + const IConfiguration *configuration, + Ptr& rawConfiguration) + { + // Create an editable version of the configuration + Ptr newConfig; + if (configuration == NULL) { + // The config file is missing -- build a default one + newConfig = Configuration::GenerateDefaultConfig(); + if (newConfig == NULL) { + DrLogE( "DryadConfigurationManager", + "Failed to create default config file"); + return false; + } + for (const char **ppDefaults = s_defaultParameters; *ppDefaults != NULL; ppDefaults += 3) { + const char *section = ppDefaults[0]; + const char *param = ppDefaults[1]; + LogAssert(param != NULL); + const char *value = ppDefaults[2]; + LogAssert(value != NULL); + newConfig->SetParameter(section, param, value); + } + } else { + newConfig = new ConfigurationMap(configuration); + if (newConfig == NULL) { + DrLogE( "DryadConfigurationManager", + "Failed to create copy of config file"); + return false; + } + } + + // process the command line to override values + DrError err = ApplyDryadConfigOverrides(m_argc, m_argv, &m_nOpts, newConfig); + if (err != DrError_OK) { + DrLogE( "DryadConfigurationManager", + "Failed to apply command line overrides to config file: %s", err); + return false; + } + + // replace the configuration with the edited one. Note that overrides, etc. have already been applied and removed. 
+ rawConfiguration = newConfig; + return true; + } + +}; + + + +DrJobTicket* CreateGlobalJob() +{ + DrError hr = S_OK; + DrServiceDescriptor sd; + DrRef jobTicket = g_pDryadConfig->GetDefaultJobTicket(); + + sd.Set("xcps", g_pDryadConfig->GetDefaultClusterName(), NULL, "rd.RDRBasic.XComputeProcessScheduler_0"); + + DrRef msg; + msg.Attach(new XcPsCreateJobRequest()); + msg->SetCreateJobTicket(jobTicket); + XcJobConstraint& constraint = msg->CreateJobConstraint(); + constraint.SetMaxConcurrentProcesses(999); + constraint.SetMaxExecutionTime(DrTimeInterval_Hour); + g_pDrClient->SendTo(msg, sd); + msg->WaitForResponse( &hr ); + if (hr != DrError_OK) + { + DrLogE( "DryadConfigurationManager", + "CreateJob failied, error=%s", + DRERRORSTRING(hr)); + LogAssert(false); + } + return jobTicket.Detach(); +} + + + +#endif // if 0 diff --git a/DryadVertex/VertexHost/system/common/src/dryadxcomputeresources.cpp b/DryadVertex/VertexHost/system/common/src/dryadxcomputeresources.cpp new file mode 100644 index 0000000..e92e3f9 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/src/dryadxcomputeresources.cpp @@ -0,0 +1,150 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include +#include + +#pragma unmanaged + +// +// Constructor - store process handle. Set guid to process id. 
+// +DryadXComputeProcessIdentifier:: + DryadXComputeProcessIdentifier(XCPROCESSHANDLE handle) +{ + m_handle = handle; + ReCacheGUID(); +} + +// +// Refresh the cached GUID and its string form from the xcompute process id +// +void DryadXComputeProcessIdentifier::ReCacheGUID() +{ + GUID id; + + // + // Get the process ID from xcompute layer + // + XCERROR err = XcGetProcessId(m_handle, &id); + if (err == DrError_OK) + { + // + // If the process ID was found, cache it along with its string form + // + m_guid = id; + m_guidString = DRGUIDSTRING(id); + } + else + { + // + // If the process ID was not found, fall back to a null GUID and a placeholder string. + // todo: is this fatal? If so, log and fail. + // + m_guid.SetToNull(); + m_guidString = "{No Guid}"; + } +} + +// +// Return handle to process +// +XCPROCESSHANDLE DryadXComputeProcessIdentifier::GetHandle() +{ + return m_handle; +} + +// +// Return pointer to the cached GUID member +// +DrGuid* DryadXComputeProcessIdentifier::GetGuid() +{ + return &m_guid; +} + +// +// Return string version of GUID +// +const char* DryadXComputeProcessIdentifier::GetGuidString() +{ + return m_guidString; +} + +// +// Store in dst the URI for relativeFileName (optionally under baseDirectory) beneath the "wd/" prefix +// +void DryadXComputeProcessIdentifier:: + MakeURIForRelativeFile(DrStr* dst, + const char* baseDirectory, + const char* relativeFileName) +{ + // + // Combine base directory and relative file name; the base directory is omitted when it is NULL + // + DrStr256 fullPath; + if (baseDirectory == NULL) + { + fullPath.SetF("wd/%s",relativeFileName); + } + else + { + fullPath.SetF("wd/%s/%s", baseDirectory, relativeFileName); + } + + // + // Get URI relative to process following xcompute access rules + // todo: ask why this matters one way or the other + // + char* uri = NULL; + XCERROR err = XcGetProcessUri(m_handle, fullPath, &uri); + if (err == DrError_OK) + { + // + // If return successful, use the URI + // + dst->Set(uri); + } + else + { + // + // If the return isn't successful, just use the full path + // + dst->Set(fullPath); + } + + // + // Free URI if created + // + if (uri != NULL) + { + XcFreeMemory(uri); + } +} + 
+DryadXComputeMachineIdentifier:: + DryadXComputeMachineIdentifier(XCPROCESSNODEID node) +{ + m_node = node; +} + +XCPROCESSNODEID DryadXComputeMachineIdentifier::GetNodeID() +{ + return m_node; +} diff --git a/DryadVertex/VertexHost/system/common/src/dvertexcommand.cpp b/DryadVertex/VertexHost/system/common/src/dvertexcommand.cpp new file mode 100644 index 0000000..ccf3175 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/src/dvertexcommand.cpp @@ -0,0 +1,997 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include +#include +#include +#include + +#pragma unmanaged + +static const char* s_StatusPropertyLabel = "DVertexStatus"; +static const char* s_CommandPropertyLabel = "DVertexCommand"; + +DryadPnProcessPropertyRequest::~DryadPnProcessPropertyRequest() +{ +} + +DryadPnProcessPropertyResponse::~DryadPnProcessPropertyResponse() +{ +} + +DryadChannelDescription::DryadChannelDescription(bool isInputChannel) +{ + m_state = DryadError_ChannelAbort; + m_totalLength = 0; + m_processedLength = 0; + m_isInputChannel = isInputChannel; +} + +DryadChannelDescription::~DryadChannelDescription() +{ +} + +DrError DryadChannelDescription::GetChannelState() const +{ + return m_state; +} + +void DryadChannelDescription::SetChannelState(DrError state) +{ + m_state = state; +} + +const char* DryadChannelDescription::GetChannelURI() const +{ + return m_URI; +} + +void DryadChannelDescription::SetChannelURI(const char* channelURI) +{ + m_URI.Set(channelURI); +} + +DryadMetaData* DryadChannelDescription::GetChannelMetaData() const +{ + return m_metaData; +} + +void DryadChannelDescription::SetChannelMetaData(DryadMetaData* metaData) +{ + m_metaData.Set(metaData); +} + +UInt64 DryadChannelDescription::GetChannelTotalLength() const +{ + return m_totalLength; +} + +void DryadChannelDescription::SetChannelTotalLength(UInt64 totalLength) +{ + m_totalLength = totalLength; +} + +UInt64 DryadChannelDescription::GetChannelProcessedLength() const +{ + return m_processedLength; +} + +void DryadChannelDescription:: + SetChannelProcessedLength(UInt64 processedLength) +{ + m_processedLength = processedLength; +} + +DrError DryadChannelDescription::Serialize(DrMemoryWriter* writer) +{ + UInt16 tagValue = (m_isInputChannel) ? 
+ DryadTag_InputChannelDescription : + DryadTag_OutputChannelDescription; + + writer->WriteUInt16Property(Prop_Dryad_BeginTag, tagValue); + + writer->WriteDrErrorProperty(Prop_Dryad_ChannelState, m_state); + writer->WriteLongDrStrProperty(Prop_Dryad_ChannelURI, m_URI); + writer->WriteUInt64Property(Prop_Dryad_ChannelTotalLength, m_totalLength); + writer->WriteUInt64Property(Prop_Dryad_ChannelProcessedLength, + m_processedLength); + if (m_metaData.Ptr() != NULL) + { + m_metaData.Ptr()->WriteAsAggregate(writer, + DryadTag_ChannelMetaData, false); + } + + writer->WriteUInt16Property(Prop_Dryad_EndTag, tagValue); + + return writer->GetStatus(); +} + +DrError DryadChannelDescription::OnParseProperty(DrMemoryReader *reader, + UInt16 enumID, + UInt32 dataLen, + void *cookie) +{ + DrError err; + + switch (enumID) + { + default: + DrLogW("Unknown property in channel description. enumID %u", (DWORD ) enumID); + err = reader->SkipNextPropertyOrAggregate(); + break; + + case Prop_Dryad_ChannelState: + err = reader->ReadNextDrErrorProperty(enumID, &m_state); + break; + + case Prop_Dryad_ChannelURI: + { + const char* URI; + err = reader->ReadNextStringProperty(enumID, &URI); + if (err == DrError_OK) + { + SetChannelURI(URI); + } + } + break; + + case Prop_Dryad_ChannelTotalLength: + err = reader->ReadNextUInt64Property(enumID, &m_totalLength); + break; + + case Prop_Dryad_ChannelProcessedLength: + err = reader->ReadNextUInt64Property(enumID, &m_processedLength); + break; + + case Prop_Dryad_BeginTag: + { + UInt16 tagID; + err = reader->PeekNextUInt16Property(Prop_Dryad_BeginTag, &tagID); + if (err == DrError_OK) + { + if (tagID == DryadTag_ChannelMetaData) + { + DryadMetaDataParser parser; + err = reader->ReadAggregate(tagID, &parser, NULL); + if (err == DrError_OK) + { + SetChannelMetaData(parser.GetMetaData()); + } + } + else + { + DrLogW("Unknown aggregate in channel description. 
tagID %u", (DWORD) tagID); + } + } + } + break; + } + + return err; +} + +void DryadChannelDescription::CopyFrom(DryadChannelDescription* src, + bool includeLengths) +{ + LogAssert(m_isInputChannel == src->m_isInputChannel); + + SetChannelURI(src->GetChannelURI()); + SetChannelState(src->GetChannelState()); + SetChannelMetaData(src->GetChannelMetaData()); + if (includeLengths) + { + SetChannelProcessedLength(src->GetChannelProcessedLength()); + SetChannelTotalLength(src->GetChannelTotalLength()); + } +} + + +DryadInputChannelDescription::DryadInputChannelDescription() : + DryadChannelDescription(true) +{ +} + + +DryadOutputChannelDescription::DryadOutputChannelDescription() : + DryadChannelDescription(false) +{ +} + + +// +// Create new Process status with default property values +// +DVertexProcessStatus::DVertexProcessStatus() +{ + m_id = 0; + m_version = 0; + m_nInputChannels = 0; + m_inputChannel = NULL; + m_maxInputChannels = 0; + m_nOutputChannels = 0; + m_outputChannel = NULL; + m_maxOutputChannels = 0; + m_canShareWorkQueue = false; + + m_nextInputChannelToRead = 0; + m_nextOutputChannelToRead = 0; +} + +DVertexProcessStatus::~DVertexProcessStatus() +{ + delete [] m_inputChannel; + delete [] m_outputChannel; +} + +UInt32 DVertexProcessStatus::GetVertexId() +{ + return m_id; +} + +void DVertexProcessStatus::SetVertexId(UInt32 channelId) +{ + m_id = channelId; +} + +UInt32 DVertexProcessStatus::GetVertexInstanceVersion() +{ + return m_version; +} + +void DVertexProcessStatus::SetVertexInstanceVersion(UInt32 version) +{ + m_version = version; +} + +DryadMetaData* DVertexProcessStatus::GetVertexMetaData() +{ + return m_metaData.Ptr(); +} + +void DVertexProcessStatus::SetVertexMetaData(DryadMetaData* metaData) +{ + m_metaData.Set(metaData); +} + +UInt32 DVertexProcessStatus::GetInputChannelCount() +{ + return m_nInputChannels; +} + +DryadInputChannelDescription* DVertexProcessStatus::GetInputChannels() +{ + return m_inputChannel; +} + +void 
DVertexProcessStatus::SetInputChannelCount(UInt32 nInputChannels) +{ + delete [] m_inputChannel; + m_inputChannel = new DryadInputChannelDescription[nInputChannels]; + m_nInputChannels = nInputChannels; + m_nextInputChannelToRead = 0; +} + +UInt32 DVertexProcessStatus::GetMaxOpenInputChannelCount() +{ + return m_maxInputChannels; +} + +void DVertexProcessStatus::SetMaxOpenInputChannelCount(UInt32 channelCount) +{ + m_maxInputChannels = channelCount; +} + +UInt32 DVertexProcessStatus::GetOutputChannelCount() +{ + return m_nOutputChannels; +} + +DryadOutputChannelDescription* DVertexProcessStatus::GetOutputChannels() +{ + return m_outputChannel; +} + +void DVertexProcessStatus::SetOutputChannelCount(UInt32 nOutputChannels) +{ + delete [] m_outputChannel; + m_outputChannel = new DryadOutputChannelDescription[nOutputChannels]; + m_nOutputChannels = nOutputChannels; + m_nextOutputChannelToRead = 0; +} + +UInt32 DVertexProcessStatus::GetMaxOpenOutputChannelCount() +{ + return m_maxOutputChannels; +} + +void DVertexProcessStatus::SetMaxOpenOutputChannelCount(UInt32 channelCount) +{ + m_maxOutputChannels = channelCount; +} + +bool DVertexProcessStatus::GetCanShareWorkQueue() +{ + return m_canShareWorkQueue; +} + +void DVertexProcessStatus::SetCanShareWorkQueue(bool canShareWorkQueue) +{ + m_canShareWorkQueue = canShareWorkQueue; +} + +DrError DVertexProcessStatus::Serialize(DrMemoryWriter* writer) +{ + UInt32 i; + + writer->WriteUInt16Property(Prop_Dryad_BeginTag, + DryadTag_VertexProcessStatus); + + writer->WriteUInt32Property(Prop_Dryad_VertexId, m_id); + writer->WriteUInt32Property(Prop_Dryad_VertexVersion, m_version); + if (m_metaData.Ptr() != NULL) + { + m_metaData.Ptr()->WriteAsAggregate(writer, + DryadTag_VertexMetaData, false); + } + + writer->WriteUInt32Property(Prop_Dryad_VertexInputChannelCount, + m_nInputChannels); + for (i=0; iGetStatus() == DrError_OK) + { + writer->SetStatus(err); + } + } + 
writer->WriteUInt32Property(Prop_Dryad_VertexMaxOpenInputChannelCount, + m_maxInputChannels); + + writer->WriteUInt32Property(Prop_Dryad_VertexOutputChannelCount, + m_nOutputChannels); + for (i=0; iGetStatus() == DrError_OK) + { + writer->SetStatus(err); + } + } + writer->WriteUInt32Property(Prop_Dryad_VertexMaxOpenOutputChannelCount, + m_maxOutputChannels); + + writer->WriteBoolProperty(Prop_Dryad_CanShareWorkQueue, + m_canShareWorkQueue); + + writer->WriteUInt16Property(Prop_Dryad_EndTag, + DryadTag_VertexProcessStatus); + + return writer->GetStatus(); +} + +DrError DVertexProcessStatus::OnParseProperty(DrMemoryReader *reader, + UInt16 enumID, + UInt32 dataLen, + void *cookie) +{ + DrError err; + + switch (enumID) + { + default: + DrLogW("Unknown property in vertex status message. enumID %u", (DWORD ) enumID); + err = reader->SkipNextPropertyOrAggregate(); + break; + + case Prop_Dryad_VertexId: + err = reader->ReadNextUInt32Property(enumID, &m_id); + break; + + case Prop_Dryad_VertexVersion: + err = reader->ReadNextUInt32Property(enumID, &m_version); + break; + + case Prop_Dryad_VertexInputChannelCount: + UInt32 nInputChannels; + err = reader->ReadNextUInt32Property(enumID, &nInputChannels); + if (err == DrError_OK) + { + SetInputChannelCount(nInputChannels); + } + break; + + case Prop_Dryad_VertexMaxOpenInputChannelCount: + UInt32 maxInputChannels; + err = reader->ReadNextUInt32Property(enumID, &maxInputChannels); + if (err == DrError_OK) + { + SetMaxOpenInputChannelCount(maxInputChannels); + } + break; + + case Prop_Dryad_VertexOutputChannelCount: + UInt32 nOutputChannels; + err = reader->ReadNextUInt32Property(enumID, &nOutputChannels); + if (err == DrError_OK) + { + SetOutputChannelCount(nOutputChannels); + } + break; + + case Prop_Dryad_VertexMaxOpenOutputChannelCount: + UInt32 maxOutputChannels; + err = reader->ReadNextUInt32Property(enumID, &maxOutputChannels); + if (err == DrError_OK) + { + SetMaxOpenOutputChannelCount(maxOutputChannels); + } + break; + + 
case Prop_Dryad_CanShareWorkQueue: + err = reader->ReadNextBoolProperty(enumID, &m_canShareWorkQueue); + break; + + case Prop_Dryad_BeginTag: + UInt16 tagValue; + err = reader->PeekNextUInt16Property(Prop_Dryad_BeginTag, &tagValue); + if (err != DrError_OK) + { + DrLogE("Error reading Prop_Dryad_BeginTag - 0x%08x", err); + } else { + switch (tagValue) + { + case DryadTag_InputChannelDescription: + if (m_nextInputChannelToRead >= m_nInputChannels) + { + DrLogE("Too many input channel descriptions. nextInputChannelToRead=%u, nInputChannels=%u", + m_nextInputChannelToRead, m_nInputChannels); + err = DrError_InvalidParameter; + } + else + { + DryadInputChannelDescription* channel = + &(m_inputChannel[m_nextInputChannelToRead]); + err = reader->ReadAggregate(tagValue, channel, NULL); + if (err == DrError_OK) + { + ++m_nextInputChannelToRead; + } + } + break; + + case DryadTag_OutputChannelDescription: + if (m_nextOutputChannelToRead >= m_nOutputChannels) + { + DrLogE( + "Too many output channel descriptions. 
nextOutputChannelToRead=%u, nOutputChannels=%u", + m_nextOutputChannelToRead, m_nOutputChannels); + err = DrError_InvalidParameter; + } + else + { + DryadOutputChannelDescription* channel = + &(m_outputChannel[m_nextOutputChannelToRead]); + err = reader->ReadAggregate(tagValue, channel, NULL); + if (err == DrError_OK) + { + ++m_nextOutputChannelToRead; + } + } + break; + + case DryadTag_VertexMetaData: + { + DryadMetaDataParser parser; + err = reader->ReadAggregate(tagValue, &parser, NULL); + if (err == DrError_OK) + { + SetVertexMetaData(parser.GetMetaData()); + } + } + break; + + default: + DrLogW( + "Unexpected tag - %hu", tagValue); + err = reader->SkipNextPropertyOrAggregate(); + } + } + break; + } + + return err; +} + +void DVertexProcessStatus::CopyFrom(DVertexProcessStatus* src, + bool includeLengths) +{ + UInt32 i; + + SetVertexId(src->GetVertexId()); + SetVertexInstanceVersion(src->GetVertexInstanceVersion()); + SetVertexMetaData(src->GetVertexMetaData()); + + SetInputChannelCount(src->GetInputChannelCount()); + DryadInputChannelDescription* srcInputs = src->GetInputChannels(); + for (i=0; iGetMaxOpenInputChannelCount()); + + SetOutputChannelCount(src->GetOutputChannelCount()); + DryadOutputChannelDescription* srcOutputs = src->GetOutputChannels(); + for (i=0; iGetMaxOpenOutputChannelCount()); +} + + +DVertexStatus::DVertexStatus() +{ + m_state = DrError_OK; + m_processStatus = new DVertexProcessStatus(); +} + +DrError DVertexStatus::GetVertexState() +{ + return m_state; +} + +// +// Update vertex state +// +void DVertexStatus::SetVertexState(DrError state) +{ + m_state = state; +} + +DVertexProcessStatus* DVertexStatus::GetProcessStatus() +{ + return m_processStatus; +} + +void DVertexStatus::SetProcessStatus(DVertexProcessStatus* processStatus) +{ + m_processStatus = processStatus; +} + +DrError DVertexStatus::Serialize(DrMemoryWriter* writer) +{ + writer->WriteUInt16Property(Prop_Dryad_BeginTag, DryadTag_VertexStatus); + + 
writer->WriteDrErrorProperty(Prop_Dryad_VertexState, m_state); + DrError err = m_processStatus->Serialize(writer); + if (err != DrError_OK && writer->GetStatus() == DrError_OK) + { + writer->SetStatus(err); + } + + writer->WriteUInt16Property(Prop_Dryad_EndTag, DryadTag_VertexStatus); + + return writer->GetStatus(); +} + +DrError DVertexStatus::OnParseProperty(DrMemoryReader *reader, + UInt16 enumID, + UInt32 dataLen, + void *cookie) +{ + DrError err; + + switch (enumID) + { + default: + DrLogW( + "Unknown property in vertex status message. enumID %u", (DWORD ) enumID); + err = reader->SkipNextPropertyOrAggregate(); + break; + + case Prop_Dryad_VertexState: + err = reader->ReadNextDrErrorProperty(enumID, &m_state); + break; + + case Prop_Dryad_BeginTag: + UInt16 tagValue; + err = reader->PeekNextUInt16Property(Prop_Dryad_BeginTag, &tagValue); + if (err != DrError_OK) + { + DrLogE("Error reading Prop_Dryad_BeginTag - 0x08x", err); + } else { + switch (tagValue) + { + case DryadTag_VertexProcessStatus: + err = reader->ReadAggregate(tagValue, m_processStatus, NULL); + break; + + default: + DrLogW("Unexpected tag - %hu", tagValue); + err = reader->SkipNextPropertyOrAggregate(); + } + } + break; + } + + return err; +} + +void DVertexStatus:: + StoreInRequestMessage(DryadPnProcessPropertyRequest* request) +{ + DrStr64 label; + GetPnPropertyLabel(&label, + m_processStatus->GetVertexId(), + m_processStatus->GetVertexInstanceVersion(), + false); + + DrStr64 controlLabel; + GetPnPropertyLabel(&controlLabel, + m_processStatus->GetVertexId(), + m_processStatus->GetVertexInstanceVersion(), + true); + + DrLogI( "Storing status update property. 
Label: %s", label.GetString()); + + request->SetPropertyLabel(label, controlLabel); + request->SetPropertyString(DRERRORSTRING(m_state)); + + DrMemoryBuffer* block = request->GetPropertyBlock(); + + { + DrMemoryBufferWriter writer(block); + DrError err = Serialize(&writer); + LogAssert(err == DrError_OK); + err = writer.FlushMemoryWriter(); + LogAssert(err == DrError_OK); + } +} + +DrError DVertexStatus:: + ReadFromResponseMessage(DryadPnProcessPropertyResponse* response, + UInt32 vertexId, UInt32 vertexVersion) +{ + DrStr64 label; + GetPnPropertyLabel(&label, vertexId, vertexVersion, false); + + response->RetrievePropertyLabel(label); + DrMemoryBuffer* block = response->GetPropertyBlock(); + if (block != NULL) + { + DrMemoryBufferReader reader(block); + return reader.ReadAggregate(DryadTag_VertexStatus, this, NULL); + } + else + { + return DrError_InvalidProperty; + } +} + +void DVertexStatus::GetPnPropertyLabel(DrStr* pDstString, + UInt32 vertexId, UInt32 vertexVersion, + bool notifyWaiters) +{ + pDstString->SetF("%s-%u.%u%s", + s_StatusPropertyLabel, vertexId, vertexVersion, + (notifyWaiters) ? 
"-update" : ""); +} + + +// +// Create new command block with default properties +// +DVertexCommandBlock::DVertexCommandBlock() +{ + m_command = DVertexCommand_Terminate; + m_processStatus = new DVertexProcessStatus(); + m_nArguments = 0; + m_argument = NULL; + m_serializedBlockLength = 0; + m_serializedBlock = NULL; + m_setBreakpointOnCommandArrival = false; + m_nextArgumentToRead = 0; +} + +DVertexCommandBlock::~DVertexCommandBlock() +{ + delete [] m_argument; +} + +DVertexCommand DVertexCommandBlock::GetVertexCommand() +{ + return m_command; +} + +void DVertexCommandBlock::SetVertexCommand(DVertexCommand command) +{ + m_command = command; +} + +DVertexProcessStatus* DVertexCommandBlock::GetProcessStatus() +{ + return m_processStatus; +} + +void DVertexCommandBlock::SetProcessStatus(DVertexProcessStatus* processStatus) +{ + m_processStatus = processStatus; +} + +UInt32 DVertexCommandBlock::GetArgumentCount() +{ + return m_nArguments; +} + +void DVertexCommandBlock::SetArgumentCount(UInt32 nArguments) +{ + delete [] m_argument; + m_nArguments = nArguments; + m_argument = new DrStr64[m_nArguments]; + m_nextArgumentToRead = 0; +} + +DrStr64* DVertexCommandBlock::GetArgumentVector() +{ + return m_argument; +} + +void DVertexCommandBlock::SetArgument(UInt32 argumentIndex, + const char* argument) +{ + LogAssert(argumentIndex < m_nArguments); + m_argument[argumentIndex].Set(argument); +} + +void* DVertexCommandBlock::GetRawSerializedBlock() +{ + return m_serializedBlock; +} + +UInt32 DVertexCommandBlock::GetRawSerializedBlockLength() +{ + return m_serializedBlockLength; +} + +void DVertexCommandBlock::SetRawSerializedBlock(UInt32 length, + const void* data) +{ + delete [] m_serializedBlock; + m_serializedBlockLength = length; + m_serializedBlock = new char[m_serializedBlockLength]; + LogAssert(m_serializedBlock != NULL); + ::memcpy(m_serializedBlock, data, m_serializedBlockLength); +} + +void DVertexCommandBlock::SetDebugBreak(bool setBreakpointOnCommandArrival) +{ + 
m_setBreakpointOnCommandArrival = setBreakpointOnCommandArrival; +} + +bool DVertexCommandBlock::GetDebugBreak() +{ + return m_setBreakpointOnCommandArrival; +} + +DrError DVertexCommandBlock::Serialize(DrMemoryWriter* writer) +{ + DrError err; + UInt32 i; + + writer->WriteUInt16Property(Prop_Dryad_BeginTag, DryadTag_VertexCommand); + + writer->WriteUInt32Property(Prop_Dryad_VertexCommand, m_command); + + err = m_processStatus->Serialize(writer); + if (err != DrError_OK && writer->GetStatus() == DrError_OK) + { + writer->SetStatus(err); + } + + writer->WriteUInt32Property(Prop_Dryad_VertexArgumentCount, + m_nArguments); + for (i=0; iWriteLongDrStrProperty(Prop_Dryad_VertexArgument, + m_argument[i]); + } + + writer->WriteLongBlobProperty(Prop_Dryad_VertexSerializedBlock, + m_serializedBlockLength, m_serializedBlock); + + writer->WriteBoolProperty(Prop_Dryad_DebugBreak, + m_setBreakpointOnCommandArrival); + + writer->WriteUInt16Property(Prop_Dryad_EndTag, DryadTag_VertexCommand); + + return writer->GetStatus(); +} + +DrError DVertexCommandBlock::OnParseProperty(DrMemoryReader *reader, + UInt16 enumID, + UInt32 dataLen, + void *cookie) +{ + DrError err; + + switch (enumID) + { + default: + DrLogW( + "Unknown property in vertex command message. enumID %u", (DWORD ) enumID); + err = reader->SkipNextPropertyOrAggregate(); + break; + + case Prop_Dryad_VertexCommand: + UInt32 marshaledCommand; + err = reader->ReadNextUInt32Property(enumID, &marshaledCommand); + if (err == DrError_OK) + { + if (m_command < DVertexCommand_Max) + { + m_command = (DVertexCommand) marshaledCommand; + } + else + { + err = DrError_InvalidProperty; + } + } + break; + + case Prop_Dryad_VertexArgumentCount: + UInt32 nArguments; + err = reader->ReadNextUInt32Property(enumID, &nArguments); + if (err == DrError_OK) + { + SetArgumentCount(nArguments); + } + break; + + case Prop_Dryad_VertexArgument: + if (m_nextArgumentToRead >= m_nArguments) + { + DrLogE( + "Too many arguments. 
nextArgumentToRead=%u, nArguments=%u", + m_nextArgumentToRead, m_nArguments); + err = DrError_InvalidParameter; + } + else + { + const char* arg; + err = reader->ReadNextStringProperty(enumID, &arg); + if (err == DrError_OK) + { + SetArgument(m_nextArgumentToRead, arg); + ++m_nextArgumentToRead; + } + } + break; + + case Prop_Dryad_VertexSerializedBlock: + UInt32 blockLength; + const void* blockData; + err = reader->ReadNextProperty(enumID, &blockLength, &blockData); + if (err == DrError_OK) + { + SetRawSerializedBlock(blockLength, blockData); + } + break; + + case Prop_Dryad_DebugBreak: + bool debugBreak; + err = reader->ReadNextBoolProperty(enumID, &debugBreak); + if (err == DrError_OK) + { + SetDebugBreak(debugBreak); + } + break; + + case Prop_Dryad_BeginTag: + UInt16 tagValue; + err = reader->PeekNextUInt16Property(Prop_Dryad_BeginTag, &tagValue); + if (err != DrError_OK) + { + DrLogE("Error reading Prop_Dryad_BeginTag - 0x08x", err); + } else { + switch (tagValue) + { + case DryadTag_VertexProcessStatus: + err = reader->ReadAggregate(tagValue, m_processStatus, NULL); + break; + + default: + DrLogW("Unexpected tag - %hu", tagValue); + err = reader->SkipNextPropertyOrAggregate(); + } + } + break; + } + + return err; +} + +void DVertexCommandBlock:: + StoreInRequestMessage(DryadPnProcessPropertyRequest* request) +{ + DrStr64 label; + GetPnPropertyLabel(&label, + m_processStatus->GetVertexId(), + m_processStatus->GetVertexInstanceVersion()); + + DrLogI( "Storing command property. 
Label: %s", label.GetString()); + + request->SetPropertyLabel(label, NULL); + + LogAssert(m_command < DVertexCommand_Max); + request->SetPropertyString(g_dVertexCommandText[m_command]); + + DrMemoryBuffer* block = request->GetPropertyBlock(); + + { + DrMemoryBufferWriter writer(block); + DrError err = Serialize(&writer); + LogAssert(err == DrError_OK); + err = writer.FlushMemoryWriter(); + LogAssert(err == DrError_OK); + } +} + +// +// Get response from message +// +DrError DVertexCommandBlock:: + ReadFromResponseMessage(DryadPnProcessPropertyResponse* response, + UInt32 vertexId, UInt32 vertexVersion) +{ + // + // Get the property associated with this vertex + // + DrStr64 label; + GetPnPropertyLabel(&label, vertexId, vertexVersion); + response->RetrievePropertyLabel(label); + + // + // Get the property contents + // + DrMemoryBuffer* block = response->GetPropertyBlock(); + if (block != NULL) + { + // + // If non-null return property contents + // + DrMemoryBufferReader reader(block); + return reader.ReadAggregate(DryadTag_VertexCommand, this, NULL); + } + else + { + // + // If nothing there, return invalid property + // + return DrError_InvalidProperty; + } +} + +void DVertexCommandBlock::GetPnPropertyLabel(DrStr* pDstString, + UInt32 vertexId, + UInt32 vertexVersion) +{ + pDstString->SetF("%s-%u.%u", + s_CommandPropertyLabel, vertexId, vertexVersion); +} diff --git a/DryadVertex/VertexHost/system/common/src/errorreporter.cpp b/DryadVertex/VertexHost/system/common/src/errorreporter.cpp new file mode 100644 index 0000000..de3d275f --- /dev/null +++ b/DryadVertex/VertexHost/system/common/src/errorreporter.cpp @@ -0,0 +1,88 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include +#include + +#pragma unmanaged + +DVErrorReporter::DVErrorReporter() +{ + m_errorCode = DrError_OK; +} + +bool DVErrorReporter::NoError() +{ + return m_errorCode == DrError_OK; +} + +DrError DVErrorReporter::GetErrorCode() +{ + return m_errorCode; +} + +// +// Return any error metadata accumulated +// +DryadMetaData* DVErrorReporter::GetErrorMetaData() +{ + return m_metaData; +} + +void DVErrorReporter::InterruptProcessing() +{ + ReportError(DryadError_ProcessingInterrupted); +} + +void DVErrorReporter::ReportError(DrError errorStatus) +{ + m_errorCode = errorStatus; +} + +void DVErrorReporter::ReportError(const char* errorFormat, ...) +{ + va_list ptr; va_start(ptr, errorFormat); + ReportFormattedErrorInternal(DryadError_VertexError, errorFormat, ptr); +} + +void DVErrorReporter::ReportError(DrError errorStatus, + const char* errorFormat, ...) 
+{ + va_list ptr; va_start(ptr, errorFormat); + ReportFormattedErrorInternal(errorStatus, errorFormat, ptr); +} + +void DVErrorReporter::ReportError(DrError errorStatus, + DryadMetaData* metaData) +{ + m_metaData = metaData; + m_errorCode = errorStatus; +} + +void DVErrorReporter::ReportFormattedErrorInternal(DrError errorStatus, + const char* errorFormat, + va_list args) +{ + DryadMetaData::Create(&m_metaData); + DrStr128 errorString; + errorString.VSetF(errorFormat, args); + m_metaData->AddErrorWithDescription(errorStatus, errorString); + m_errorCode = errorStatus; +} diff --git a/DryadVertex/VertexHost/system/common/src/portmemorybuffers.cpp b/DryadVertex/VertexHost/system/common/src/portmemorybuffers.cpp new file mode 100644 index 0000000..c17d418 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/src/portmemorybuffers.cpp @@ -0,0 +1,298 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include + +#pragma unmanaged + +DryadFixedMemoryBuffer::DryadFixedMemoryBuffer() +{ +} + +DryadFixedMemoryBuffer::DryadFixedMemoryBuffer(BYTE *pData, + Size_t allocatedSize, + Size_t availableSize) : + DrFixedMemoryBuffer(pData, allocatedSize, availableSize) +{ +} + +DryadFixedMemoryBuffer::~DryadFixedMemoryBuffer() +{ +} + + +DryadFixedMemoryBufferCopy:: + DryadFixedMemoryBufferCopy(DryadFixedMemoryBuffer* src) +{ + Size_t copySize = src->GetAvailableSize(); + Size_t availableSize; + BYTE* data = (BYTE *) src->GetDataAddress(0, &availableSize, NULL); + LogAssert(data != NULL && availableSize >= copySize); + + m_dataCopy = new BYTE[src->GetAvailableSize()]; + ::memcpy(m_dataCopy, data, copySize); + + Init(m_dataCopy, copySize, copySize); +} + +DryadFixedMemoryBufferCopy::~DryadFixedMemoryBufferCopy() +{ + delete [] m_dataCopy; +} + +DryadLockedMemoryBuffer::DryadLockedMemoryBuffer() +{ +} + +DryadLockedMemoryBuffer::DryadLockedMemoryBuffer(BYTE *pData, + Size_t allocatedSize) : + DrFixedMemoryBuffer(pData, allocatedSize, allocatedSize) +{ +} + +DryadLockedMemoryBuffer::~DryadLockedMemoryBuffer() +{ +} + +// +// Initialize memory buffer with array and size +// +void DryadLockedMemoryBuffer::Init(BYTE *pData, Size_t allocatedSize) +{ + DrFixedMemoryBuffer::Init(pData, allocatedSize, allocatedSize); +} + +void DryadLockedMemoryBuffer::SetAvailableSize(Size_t uSize) +{ + LogAssert(false); +} + +static CRITSEC s_readPoolCS; +static size_t s_readPoolAlignment = 0; +static size_t s_readPoolBufferSize = 0; +static const size_t s_readPoolSize = 20; +static void* s_readPool[s_readPoolSize]; +static size_t s_readPoolValid = 0; + +// +// Create a fixed length buffer aligned to a provided 2^N alignment +// +DryadAlignedReadBlock::DryadAlignedReadBlock(size_t size, + size_t alignment) +{ + // + // If alignment set, create buffer and set start to correct alignment, otherwise just use random address + // + if (alignment > 0) + { + void* poolData = NULL; + + { + 
AutoCriticalSection acs(&s_readPoolCS); + if (s_readPoolAlignment == 0) + { + s_readPoolAlignment = alignment; + } + if (s_readPoolBufferSize == 0) + { + s_readPoolBufferSize = size; + } + + if (s_readPoolAlignment == alignment && + s_readPoolBufferSize == size && + s_readPoolValid > 0) + { + --s_readPoolValid; + poolData = s_readPool[s_readPoolValid]; + } + } + + if (poolData == NULL) + { + // + // alignment must be a power of 2 (eg 1000 & 111 = 0) + // + LogAssert ((alignment & (alignment-1)) == 0); + + // + // Create buffer that's big enough to hold size even if base address has to move up by (alignment - 1) + // + m_data = new char[size + alignment - 1]; + } + else + { + m_data = poolData; + } + + ULONG_PTR baseAddress = (ULONG_PTR) m_data; + + // + // Round base address up to nearest multiple of alignment (=baseaddress if already aligned) + // eg: base = 13, alignment = 8: 1101 + 111 - ((1101 + 111) & 111) = 10100 - (10100 & 111) = 10000 = 16 + // + ULONG_PTR alignedAddress = baseAddress + alignment - 1; + alignedAddress -= (alignedAddress & (alignment - 1)); + LogAssert(alignedAddress + size < baseAddress + size + alignment); + + m_alignedData = (void *) alignedAddress; + } + else + { + m_data = new char[size]; + m_alignedData = m_data; + } + + // + // Initialize a fixed length buffer + // + this->Init((BYTE *) m_alignedData, size); + m_alignment = alignment; +} + +DryadAlignedReadBlock::~DryadAlignedReadBlock() +{ + bool mustDelete = true; + + if (m_alignment > 0) + { + AutoCriticalSection acs(&s_readPoolCS); + + if (s_readPoolAlignment == m_alignment && + s_readPoolBufferSize == GetAllocatedSize() && + s_readPoolValid < s_readPoolSize) + { + s_readPool[s_readPoolValid] = m_data; + ++s_readPoolValid; + mustDelete = false; + } + } + + if (mustDelete) + { + delete [] m_data; + } +} + +void* DryadAlignedReadBlock::GetData() +{ + return m_alignedData; +} + +// +// Update max data size in this buffer to supplied value +// +void DryadAlignedReadBlock::Trim(Size_t 
numBytes) +{ + LogAssert(numBytes <= GetAvailableSize()); + InternalSetAvailableSize(numBytes); +} + + +static CRITSEC s_writePoolCS; +static size_t s_writePoolAlignment = 0; +static size_t s_writePoolBufferSize = 0; +static const size_t s_writePoolSize = 20; +static void* s_writePool[s_writePoolSize]; +static size_t s_writePoolValid = 0; + +DryadAlignedWriteBlock::DryadAlignedWriteBlock(size_t size, + size_t alignment) +{ + if (alignment > 0) + { + void* poolData = NULL; + + { + AutoCriticalSection acs(&s_writePoolCS); + if (s_writePoolAlignment == 0) + { + s_writePoolAlignment = alignment; + } + if (s_writePoolBufferSize == 0) + { + s_writePoolBufferSize = size; + } + + if (s_writePoolAlignment == alignment && + s_writePoolBufferSize == size && + s_writePoolValid > 0) + { + --s_writePoolValid; + poolData = s_writePool[s_writePoolValid]; + } + } + + if (poolData == NULL) + { + /* alignment must be a power of 2 */ + LogAssert ((alignment & (alignment-1)) == 0); + + m_data = new char[size + alignment - 1]; + } + else + { + m_data = poolData; + } + + ULONG_PTR baseAddress = (ULONG_PTR) m_data; + ULONG_PTR alignedAddress = baseAddress + alignment - 1; + alignedAddress -= (alignedAddress & (alignment - 1)); + LogAssert(alignedAddress + size < baseAddress + size + alignment); + + m_alignedData = (void *) alignedAddress; + } + else + { + m_data = new char[size]; + m_alignedData = m_data; + } + + this->Init((BYTE *) m_alignedData, size, 0); + m_alignment = alignment; +} + +DryadAlignedWriteBlock::~DryadAlignedWriteBlock() +{ + bool mustDelete = true; + + if (m_alignment > 0) + { + AutoCriticalSection acs(&s_writePoolCS); + + if (s_writePoolAlignment == m_alignment && + s_writePoolBufferSize == GetAllocatedSize() && + s_writePoolValid < s_writePoolSize) + { + s_writePool[s_writePoolValid] = m_data; + ++s_writePoolValid; + mustDelete = false; + } + } + + if (mustDelete) + { + delete [] m_data; + } +} + +void* DryadAlignedWriteBlock::GetData() +{ + return m_alignedData; +} 
diff --git a/DryadVertex/VertexHost/system/common/src/workqueue.cpp b/DryadVertex/VertexHost/system/common/src/workqueue.cpp new file mode 100644 index 0000000..ba5f1be --- /dev/null +++ b/DryadVertex/VertexHost/system/common/src/workqueue.cpp @@ -0,0 +1,373 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "workqueue.h" +#include +#include +//JC#include "logging.h" + +#pragma unmanaged + +#define DWORKQUEUE_CONTINUE (0) +#define DWORKQUEUE_EXIT (1) + +WorkRequest::~WorkRequest() +{ +} + +// +// Create work queue using provided numbers of threads +// +WorkQueue::WorkQueue(DWORD numWorkerThreads, + DWORD numConcurrentThreads) +{ + m_state = WQS_Stopped; + m_numWorkerThreads = numWorkerThreads; + m_numConcurrentThreads = numConcurrentThreads; + m_completionPort = INVALID_HANDLE_VALUE; + m_threadHandle = new HANDLE[m_numWorkerThreads]; + LogAssert(m_threadHandle != NULL); + DWORD i; + for (i=0; im_completionPort != INVALID_HANDLE_VALUE); + + DrLogI("WorkQueue::ThreadFunc starting thread"); + + bool finished = false; + + do + { + DWORD numBytes; + ULONG_PTR completionKey; + LPOVERLAPPED overlapped; + +// DrLogD( +// "WorkQueue::ThreadFunc waiting for completion event"); + + BOOL retval = ::GetQueuedCompletionStatus(self->m_completionPort, + &numBytes, + &completionKey, + &overlapped, + INFINITE); + +// DrLogD( +// "WorkQueue::ThreadFunc received 
completion event", +// "retval: %d", retval); + + if (retval != 0) + { + finished = (numBytes == DWORKQUEUE_EXIT); + + if (finished) + { + DrLogI("WorkQueue::ThreadFunc received shutdown event"); + } + } + else + { + DWORD errCode = GetLastError(); + DrLogA("WorkQueue::GetQueuedCompletionStatus. error code: 0x%08x", HRESULT_FROM_WIN32(errCode)); + } + + bool queueDrained = false; + bool decrementedCount = false; + do + { + WorkRequest* request = NULL; + { + AutoCriticalSection acs(&(self->m_baseDR)); + + if (!decrementedCount) + { + decrementedCount = true; + LogAssert(self->m_numQueuedWakeUps > 0); + --(self->m_numQueuedWakeUps); +// DrLogD( +// "WorkQueue::ThreadFunc decremented queued wakeups", +// "new val: %d", self->m_numQueuedWakeUps); + } + + if (self->m_list.IsEmpty()) + { + queueDrained = true; +// DrLogD( +// "WorkQueue::ThreadFunc found empty work queue"); + } + else + { + request = self->m_list.CastOut(self->m_list.RemoveHead()); + LogAssert(request != NULL); +// DrLogD( +// "WorkQueue::ThreadFunc removed work item to process"); + } + } + + if (!queueDrained) + { + request->Process(); + delete request; + request = NULL; + } + } while (!queueDrained); + } while (!finished); + + DrLogI("WorkQueue::ThreadFunc exiting thread"); + + return 0; +} + +void WorkQueue::Start() +{ + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_state == WQS_Stopped); + LogAssert(m_completionPort == INVALID_HANDLE_VALUE); + + DrLogI("WorkQueue::Start entered"); + + m_completionPort = ::CreateIoCompletionPort(INVALID_HANDLE_VALUE, + NULL, + NULL, + m_numConcurrentThreads); + LogAssert(m_completionPort != NULL); + + DrLogI("WorkQueue::Start created completion port"); + + DWORD i; + for (i=0; i= WAIT_OBJECT_0 &&*/ + waitRet < (WAIT_OBJECT_0 + m_numWorkerThreads)); + + DrLogI("WorkQueue::Stop all threads have terminated"); + + { + AutoCriticalSection acs(&m_baseDR); + + BOOL bRetval; + + LogAssert(m_numQueuedWakeUps == 0); + LogAssert(m_list.IsEmpty()); + + for (i=0; 
iShouldAbort()) + { + m_list.InsertAsTail(m_list.CastIn(item)); + item = NULL; + + if (m_numQueuedWakeUps < m_numWorkerThreads) + { + // + // If additional worker threads are availble, post queued work + // + ++m_numQueuedWakeUps; + BOOL retval = ::PostQueuedCompletionStatus(m_completionPort, + DWORKQUEUE_CONTINUE, + NULL, + NULL); + if (retval == 0) + { + // + // Log any failure posting queued work item + // + DWORD errCode = GetLastError(); + DrLogA("WorkQueue::EnQueue post completion status. error code:0x%08x", HRESULT_FROM_WIN32(errCode)); + } + } + } + } + + // + // If item is non-null, then ShouldAbort returned true above. + // In this case, log, abort, and clean up + // + if (item != NULL) + { + DrLogD("WorkQueue::EnQueue processing aborting work item"); + + item->Process(); + delete item; + } + + return true; +} + +void WorkQueue::Clean() +{ + WorkRequestList cleanedList; + DrBListEntry* listEntry; + + { + AutoCriticalSection acs (&m_baseDR); + + listEntry = m_list.GetHead(); + while (listEntry != NULL) + { + WorkRequest* request = m_list.CastOut(listEntry); + listEntry = m_list.GetNext(listEntry); + + if (request->ShouldAbort()) + { + DrLogD("WorkQueue::Clean removing work item from list"); + cleanedList.TransitionToTail(cleanedList.CastIn(request)); + } + } + } + + listEntry = cleanedList.GetHead(); + while (listEntry != NULL) + { + WorkRequest* request = cleanedList.CastOut(listEntry); + listEntry = cleanedList.GetNext(listEntry); + + DrLogD("WorkQueue::Clean processing removed work item"); + + request->Process(); + cleanedList.Remove(cleanedList.CastIn(request)); + delete request; + } + LogAssert(cleanedList.IsEmpty()); +} diff --git a/DryadVertex/VertexHost/system/common/src/xcomputepropertyblock.cpp b/DryadVertex/VertexHost/system/common/src/xcomputepropertyblock.cpp new file mode 100644 index 0000000..f5eea89 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/src/xcomputepropertyblock.cpp @@ -0,0 +1,84 @@ +/* +Copyright (c) Microsoft 
Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "xcomputepropertyblock.h" + +#pragma unmanaged + +DryadXComputePnProcessPropertyRequest::DryadXComputePnProcessPropertyRequest() +{ + m_block.Attach(new DrSimpleHeapBuffer()); +} + +void DryadXComputePnProcessPropertyRequest:: + SetPropertyLabel(const char* label, const char* controlLabel) +{ + m_label.Set(label); + m_controlLabel.Set(controlLabel); +} + +void DryadXComputePnProcessPropertyRequest:: + SetPropertyString(const char* string) +{ + m_string.Set(string); +} + +DrMemoryBuffer* + DryadXComputePnProcessPropertyRequest::GetPropertyBlock() +{ + return m_block; +} + + +DryadXComputePnProcessPropertyResponse:: + DryadXComputePnProcessPropertyResponse(PXC_PROCESS_INFO response) +{ + m_processInfo = response; + m_propertyInfo = NULL; + m_block.Attach(new DrFixedMemoryBuffer()); +} + +void DryadXComputePnProcessPropertyResponse:: + RetrievePropertyLabel(const char* label) +{ + m_propertyInfo = NULL; + + UInt32 i; + for (i=0; iNumberofProcessProperties; ++i) + { + PXC_PROCESSPROPERTY_INFO propertyInfo = m_processInfo->ppProperties[i]; + if (::strcmp(propertyInfo->pPropertyLabel, label) == 0) + { + m_propertyInfo = propertyInfo; + break; + } + } + LogAssert(m_propertyInfo != NULL); +} + +DrMemoryBuffer* DryadXComputePnProcessPropertyResponse:: + GetPropertyBlock() +{ + m_block->Init((const BYTE *) 
m_propertyInfo->pPropertyBlock, + m_propertyInfo->PropertyBlockSize, + m_propertyInfo->PropertyBlockSize); + + return m_block; +} diff --git a/DryadVertex/VertexHost/system/common/src/yarnpropertyblock.cpp b/DryadVertex/VertexHost/system/common/src/yarnpropertyblock.cpp new file mode 100644 index 0000000..9dbcbd7 --- /dev/null +++ b/DryadVertex/VertexHost/system/common/src/yarnpropertyblock.cpp @@ -0,0 +1,90 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include + + +DryadYarnPnProcessPropertyRequest::DryadYarnPnProcessPropertyRequest() +{ + m_block.Attach(new DrSimpleHeapBuffer()); +} + +void DryadYarnPnProcessPropertyRequest:: + SetPropertyLabel(const char* label, const char* controlLabel) +{ + m_label.Set(label); + m_controlLabel.Set(controlLabel); +} + +void DryadYarnPnProcessPropertyRequest:: + SetPropertyString(const char* string) +{ + m_string.Set(string); +} + +DrMemoryBuffer* + DryadYarnPnProcessPropertyRequest::GetPropertyBlock() +{ + return m_block; +} + +DryadYarnPnProcessPropertyResponse:: + DryadYarnPnProcessPropertyResponse() +{ + m_block.Attach(new DrFixedMemoryBuffer()); +} + +void DryadYarnPnProcessPropertyResponse:: + RetrievePropertyLabel(const char* /*label*/) +{ + /* + m_propertyInfo = NULL; + + UInt32 i; + for (i=0; iNumberofProcessProperties; ++i) + { + PXC_PROCESSPROPERTY_INFO propertyInfo = m_processInfo->ppProperties[i]; + if (::strcmp(propertyInfo->pPropertyLabel, label) == 0) + { + m_propertyInfo = propertyInfo; + break; + } + } + + LogAssert(m_propertyInfo != NULL); + */ + +} + +DrMemoryBuffer* DryadYarnPnProcessPropertyResponse:: + GetPropertyBlock() +{ + /* + m_block->Init((const BYTE *) m_propertyInfo->pPropertyBlock, + m_propertyInfo->PropertyBlockSize, + m_propertyInfo->PropertyBlockSize); + + return m_block; + */ + return NULL; +} + + + diff --git a/DryadVertex/VertexHost/system/dprocess/dprocess.vcxproj b/DryadVertex/VertexHost/system/dprocess/dprocess.vcxproj new file mode 100644 index 0000000..92bf839 --- /dev/null +++ b/DryadVertex/VertexHost/system/dprocess/dprocess.vcxproj @@ -0,0 +1,170 @@ + + + + + Debug + Win32 + + + Debug + x64 + + + Release + Win32 + + + Release + x64 + + + + {AA529122-F51C-48D7-A8C1-C0B24F570885} + dprocess + Win32Proj + + + + StaticLibrary + + + StaticLibrary + + + StaticLibrary + + + StaticLibrary + Unicode + true + + + + + + + + + + + + + + + + + + + <_ProjectFileVersion>10.0.40219.1 + Debug\ + Debug\ + $(Platform)\$(Configuration)\ + 
$(Platform)\$(Configuration)\ + Release\ + Release\ + $(Platform)\$(Configuration)\ + $(Platform)\$(Configuration)\ + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + + + + Disabled + WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions) + true + EnableFastChecks + MultiThreadedDebugDLL + + + Level3 + EditAndContinue + + + + + X64 + + + Disabled + include;..\common\include;..\classlib\include;..\channel\include;src;%(AdditionalIncludeDirectories) + WIN32;_DEBUG;_LIB;WIN32_LEAN_AND_MEAN;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + false + Default + MultiThreadedDebugDLL + + + Level3 + ProgramDatabase + + + %(AdditionalDependencies) + %(AdditionalLibraryDirectories) + + + + + WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + + + + + X64 + + + WIN32;NDEBUG;_LIB;WIN32_LEAN_AND_MEAN;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + include;..\common\include;..\classlib\include;..\channel\include;src;%(AdditionalIncludeDirectories) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + {e092e2b9-d3c9-4ce2-8201-bda442574c97} + + + + + + \ No newline at end of file diff --git a/DryadVertex/VertexHost/system/dprocess/include/dryadvertex.h b/DryadVertex/VertexHost/system/dprocess/include/dryadvertex.h new file mode 100644 index 0000000..b190eb4 --- /dev/null +++ b/DryadVertex/VertexHost/system/dprocess/include/dryadvertex.h @@ -0,0 +1,590 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once +class DryadVertexFactoryBase; +class DVertexProcessStatus; +class DryadChannelDescription; +class WorkQueue; + +#include +#include +#include +#include +#include + +enum TransformType { + TT_NullTransform = 0, + TT_GzipCompression, + TT_GzipFastCompression, + TT_GzipDecompression, + TT_DeflateCompression, + TT_DeflateDecompression, + TT_DeflateFastCompression + + /* Xpress removed, but left in comments as an example of + * supporting an alternate compression scheme. + TT_XpressCompression, + TT_XpressDecompression, + TT_XpressFastCompression + */ +}; + + +enum VertexAffinityConstraint { + VAC_HardConstraint = 0, + VAC_OptimizationConstraint, + VAC_Preference, + VAC_DontCare +}; + +class DryadVertexController +{ +public: + virtual void AssimilateNewStatus(DVertexProcessStatus* status, + bool sendUpdate, bool notifyWaiters) = 0; +}; + +class DryadVertexBase : public IDrRefCounter +{ +public: + DryadVertexBase(); + virtual ~DryadVertexBase(); + + /* initialize is called once when the process starts up */ + void Initialize(DryadVertexController* controller); + + + /* PrepareDryadVertex is called by the controller before a call to + RunDryadVertex. It will never overlap with RunDryadVertex or + ReOpenChannels. The initial state includes the vertex id and + version, and the initial channel information. 
*/ + virtual void PrepareDryadVertex(DVertexProcessStatus* initialState, + DryadVertexFactoryBase* factory, + UInt32 argumentCount, + DrStr64* argumentList, + UInt32 serializedBlockLength, + const void* serializedBlock) = 0; + + /* RunDryadVertex is the main body of the vertex + execution. initialState is guaranteed to be the same object as + was passed to the preceding call to PrepareDryadVertex. When + RunDryadVertex returns the vertex is considered to have + completed. If the return value is DrExitCode_StillActive then the + controller may re-use the process for a subsequent execution + (i.e. call PrepareDryadVertex again). The vertex may only call + ReportStatus while it is within an executing call to + RunDryadVertex, though it is permissible for RunDryadVertex to + return while there is an outstanding call to ReportStatus. + */ + virtual DrError RunDryadVertex(DVertexProcessStatus* initialState, + UInt32 argumentCount, + DrStr64* argumentList) = 0; + + /* ReOpenChannels may be called by the controller at any time + after a call to PrepareDryadVertex has returned and before the + next call to PrepareDryadVertex. 
*/ + virtual void ReOpenChannels(DVertexProcessStatus* newChannelStatus) = 0; + + void ReportStatus(DVertexProcessStatus* status, + bool sendUpdate, bool notifyWaiters); + +private: + DryadVertexController* m_controller; +}; + +typedef DrRef DryadVertexRef; + +class DryadVertex : public DryadVertexBase +{ +public: + virtual ~DryadVertex(); + DRREFCOUNTIMPL +}; + +class DryadVertexProgramCompletionHandler +{ +public: + virtual void ProgramCompleted() = 0; +}; + +/* this is a container for a resource which will be copied by the PN + to the vertex process's local working directory before it is executed */ +class JobCopiedResource : public DrRefCounter +{ +public: + JobCopiedResource(const char* localResourceName, + const char* remoteResourceName, + void *content, size_t contentLen); + ~JobCopiedResource(); + + const char* GetLocalResourceName(); + const char* GetRemoteResourceName(); + const void* GetContent(); + size_t GetContentLen(); + UInt64 GetContentFingerprint(); + +private: + DrStr64 m_localResourceName; + DrStr64 m_remoteResourceName; + // new[] array of embedded resource content, or NULL if not embedded + unsigned char* m_content; + size_t m_contentLen; + UInt64 m_contentFingerprint; +}; + +typedef DrRef JobCopiedResourceRef; + +class VertexInvocationBase : public RChannelContext +{ +public: + VertexInvocationBase(); + virtual ~VertexInvocationBase(); + + /* The CommandLine is the string sent to Windows' CreateProcess + when the vertex is remotely executed. This can be set + explicitly at the Job Manager, though usually it is done + automatically by a DryadJointApp which is in charge of making + sure the executable is present remotely as a resource and + therefore knows its name. It can be read by a running vertex, + but is probably not of much use (and it may be NULL, e.g. for a + subgraph vertex). 
*/ + const char* GetCommandLine(); + void SetCommandLine(const char* commandLine); + + /* The argument list is the array of strings passed to the vertex + in its Main routine. It can be set at the Job Manager, and is + transmitted by the vertex Start command and deserialized when + the vertex is created remotely, it can be read by the vertex at + any time. argument[0] is reserved for a string which uniquely + identifies the type of the vertex: this is set by all vertex + factories, and is used by the remote executable to decide which + vertex to create. */ + UInt32 GetArgumentCount(); + DrStr64* GetArgumentList(); + const char* GetArgument(UInt32 whichArgument); + void AddArgument(const char* argument); + + /* If these are set to non-zero values, the number of + files/streams that the vertex can hold open at any given time + is throttled. This will lead to deadlock if a vertex blocks on + reads or writes to more than the allowed number of channels, + however in cases where read order is unimportant (such as + non-deterministic merge) this will automatically block the + first read of some channels until the last read of others has + completed. */ + UInt32 GetMaxOpenInputChannelCount(); + void SetMaxOpenInputChannelCount(UInt32 channelCount); + UInt32 GetMaxOpenOutputChannelCount(); + void SetMaxOpenOutputChannelCount(UInt32 channelCount); + + /* The metadata is analogous to the argument list but allows more + complex structured data to be sent from the job manager to the + running vertex. For example this is used by a subgraph vertex + to serialize an arbitrary graph */ + DryadMetaData* GetMetaData(); + void SetMetaData(DryadMetaData* metaData); + + /* There is a raw serialized block containing opaque data sent + from the job manager to the running vertex. It is not typically + accessed directly. 
Instead a vertex program writer will + override the Serialize method which is called at the job + manager before a vertex is executed, and the DeSerialize method + which is called at the remote process to restore the + state. Serialize/DeSerialize are also used by the graph builder + when cloning vertices. If an error occurs during + deserialization the method should call ReportError to report + it. */ + virtual void Serialize(DrMemoryBufferWriter* writer); + virtual void DeSerialize(DrMemoryBufferReader* reader); + + /* The resources are files which the PN ensures will be available + to the vertex when it is run remotely. They can be set + explicitly at the Job Manager but are usually handled + automatically by a DryadJointApp. They do not appear at the + running vertex (GetResourceCount() will always return 0). */ + UInt32 GetResourceCount(); + JobCopiedResourceRef* GetResourceList(); + JobCopiedResource* GetResource(UInt32 whichResource); + void AddResource(JobCopiedResource* resource); + void AttachResource(JobCopiedResource* resource); + + /* This flag can be set at the job manager and will cause the + specified vertex to break into the debugger on startup. It will + always return false at the running vertex. */ + bool GetDebugBreak(); + void SetDebugBreak(bool debugBreak); + + /* This flag can be set at the job manager and will cause the + specified vertex to simulate failure after every execution, for + testing purposes. It will always return false at the running + vertex. */ + bool GetFakeVertexFailure(); + void SetFakeVertexFailure(bool fakeVertexFailure); + + /* This flag can be set at the job manager and will cause the + specified vertex to simulate failure of its inputs after every + execution, for testing purposes. It will always return false at + the running vertex. */ + bool GetFakeVertexInputFailure(); + void SetFakeVertexInputFailure(bool fakeVertexInputFailure); + + /* This enumeration specifies where a vertex would like to run. 
By + default it is VAC_DontCare. If it is anything else, the + location list describes which machines are required or + preferred. */ + VertexAffinityConstraint GetAffinityConstraint(); + void SetAffinityConstraint(VertexAffinityConstraint constraint); + std::list* GetAffinityLocationList(); + + /* This flag can be set at the job manager and will cause the + specified vertex to allow itself to be run in a subgraph on a + shared work queue. It is generally set automatically by the + vertex implementation. */ + void SetCanShareWorkQueue(bool canShareWorkQueue); + bool GetCanShareWorkQueue(); + + void SetDisplayName(const char* displayName); + const char* GetDisplayName(); + + void SetCpuUsage(UInt32 cpu); + UInt32 GetCpuUsage(); + + void SetMemoryUsage(UInt64 memory); + UInt64 GetMemoryUsage(); + + void SetDiskUsage(UInt32 disk); + UInt32 GetDiskUsage(); + + /* if an output size hint vector is set, this consists of one + number per output channel that may be used by the channel + writing code to pre-allocate files, increasing performance and + reducing fragmentation. If it has not been set, then + GetOutputSizeHintVector returns NULL. If the number of outputs + changes when the hint vector is non-NULL, there is an assertion + failure. If a total output size hint is set, the number of + outputs can vary and the system will estimate that each output + will get the same amount of data. 
*/ + UInt32 GetOutputSizeHintVectorLength() const; + UInt64* GetOutputSizeHintVector() const; + void SetOutputSizeHintVector(UInt32 numberOfOutputs, UInt64* hints); + UInt64 GetOutputTotalSizeHint() const; + void SetOutputTotalSizeHint(UInt64 hint); + + VertexInvocationBase* CloneInvocation(); + +private: + DrStr64 m_commandLine; + UInt32 m_argumentArraySize; + UInt32 m_numberOfArguments; + DrStr64* m_argument; + UInt32 m_maxInputChannels; + UInt32 m_maxOutputChannels; + DryadMetaDataRef m_metaData; + UInt32 m_resourceArraySize; + UInt32 m_numberOfResources; + JobCopiedResourceRef* m_resource; + bool m_debugBreak; + bool m_fakeVertexFailure; + bool m_fakeVertexInputFailure;; + VertexAffinityConstraint m_affinityConstraint; + std::list m_affinityLocations; + bool m_canShareWorkQueue; + DrStr64 m_displayName; + UInt32 m_cpuUsage; + UInt64 m_memoryUsage; + UInt32 m_diskUsage; + UInt64 m_outputTotalSizeHint; + UInt32 m_outputSizeHintVectorLength; + UInt64* m_outputSizeHintVector; +}; + +class VertexInvocationRecord : public VertexInvocationBase +{ +public: + virtual ~VertexInvocationRecord(); + + DRREFCOUNTIMPL +}; + +class DryadVertexProgramBase : + public VertexInvocationRecord, public DVErrorReporter +{ +public: + DryadVertexProgramBase(); + virtual ~DryadVertexProgramBase(); + + virtual void Usage(FILE* f); + + virtual void Initialize(UInt32 numberOfInputChannels, + UInt32 numberOfOutputChannels); + + UInt32 GetVertexId(); + void SetVertexId(UInt32 vertexId); + + UInt32 GetVertexVersion(); + void SetVertexVersion(UInt32 vertexVersion); + + UInt64 GetExpectedInputLength(UInt32 inputChannel); + void SetExpectedInputLength(UInt32 numberOfChannels, + UInt64* expectedLength); + + void SetNumberOfParserFactories(UInt32 numberOfFactories); + void SetParserFactory(UInt32 whichFactory, + DryadParserFactoryBase* factory); + void SetCommonParserFactory(DryadParserFactoryBase* factory); + DryadParserFactoryBase* GetCommonParserFactory(); + DryadParserFactoryBase* 
GetParserFactory(UInt32 whichFactory); + + void SetNumberOfMarshalerFactories(UInt32 numberOfFactories); + void SetMarshalerFactory(UInt32 whichFactory, + DryadMarshalerFactoryBase* factory); + void SetCommonMarshalerFactory(DryadMarshalerFactoryBase* factory); + DryadMarshalerFactoryBase* GetCommonMarshalerFactory(); + DryadMarshalerFactoryBase* GetMarshalerFactory(UInt32 whichFactory); + + virtual void MakeInputParser(UInt32 whichInput, + RChannelItemParserRef* pParser); + + virtual void MakeOutputMarshaler(UInt32 whichOutput, + RChannelItemMarshalerRef* pMarshaler); + + /* a vertex program must override at least one of the Main or + MainAsync methods. By default, the Main method calls MainAsync + and waits on an event for the completion handler to be + called. By default the MainAsync method creates a thread which + calls Main and triggers the handler when Main completes. */ + virtual void Main(WorkQueue* workQueue, + UInt32 numberOfInputChannels, + RChannelReader** inputChannel, + UInt32 numberOfOutputChannels, + RChannelWriter** outputChannel); + virtual void MainAsync(WorkQueue* workQueue, + UInt32 numberOfInputChannels, + RChannelReader** inputChannel, + UInt32 numberOfOutputChannels, + RChannelWriter** outputChannel, + DryadVertexProgramCompletionHandler* handler); + /* After MainAsync has triggered its handler, AsyncPostCompletion + is called to allow the vertex to do any required cleanup on the + main calling thread (i.e. not within any handlers). The current + status when the completion handler was called can be read using + GetError and GetMetaData and overridden using ReportError.
*/ + virtual void AsyncPostCompletion(); + + void NotifyChannelsOfCompletion(UInt32 numberOfInputChannels, + RChannelReader** inputChannel, + UInt32 numberOfOutputChannels, + RChannelWriter** outputChannel); + void DrainChannels(UInt32 numberOfInputChannels, + RChannelReader** inputChannel, + UInt32 numberOfOutputChannels, + RChannelWriter** outputChannel); + +private: + class ThreadBlock + { + public: + ThreadBlock(DryadVertexProgramBase* parent, + WorkQueue* workQueue, + UInt32 numberOfInputChannels, + RChannelReader** inputChannel, + UInt32 numberOfOutputChannels, + RChannelWriter** outputChannel, + DryadVertexProgramCompletionHandler* handler); + void Run(); + + private: + DryadVertexProgramBase* m_parent; + WorkQueue* m_workQueue; + UInt32 m_numberOfInputChannels; + RChannelReader** m_inputChannel; + UInt32 m_numberOfOutputChannels; + RChannelWriter** m_outputChannel; + DryadVertexProgramCompletionHandler* m_handler; + }; + + class DefaultHandler : public DryadVertexProgramCompletionHandler + { + public: + DefaultHandler(HANDLE completionEvent); + + void ProgramCompleted(); + + private: + HANDLE m_completionEvent; + }; + + static unsigned MainThreadFunc(void* arg); + + UInt32 m_vertexId; + UInt32 m_vertexVersion; + UInt32 m_expectedSizeArrayLength; + UInt64* m_expectedSizeArray; + UInt32 m_numberOfParserFactories; + DryadParserFactoryRef* m_parserFactoryArray; + UInt32 m_numberOfMarshalerFactories; + DryadMarshalerFactoryRef* m_marshalerFactoryArray; + bool m_defaultMainCalled; + bool m_defaultMainAsyncCalled; +}; + +typedef DrRef DryadVertexProgramRef; + +class DryadVertexProgram : public DryadVertexProgramBase +{ +public: + virtual ~DryadVertexProgram(); + DRREFCOUNTIMPL +}; + +class DryadSimpleChannelVertexBase : + public DryadVertexBase, DryadVertexProgramCompletionHandler +{ +public: + DryadSimpleChannelVertexBase(); + virtual ~DryadSimpleChannelVertexBase(); + + void SetStatusInterval(DrTimeInterval interval); + DrTimeInterval GetStatusInterval(); + + 
void PrepareDryadVertex(DVertexProcessStatus* initialState, + DryadVertexFactoryBase* factory, + UInt32 argumentCount, + DrStr64* argumentList, + UInt32 serializedBlockLength, + const void* serializedBlock); + DrError RunDryadVertex(DVertexProcessStatus* initialState, + UInt32 argumentCount, + DrStr64* argumentList); + void ReOpenChannels(DVertexProcessStatus* newChannelStatus); + + void ProgramCompleted(); + +private: + class ReaderData + { + public: + ReaderData(); + + RChannelItemParserRef m_parser; + RChannelBufferReader* m_bufferReader; + RChannelReader* m_reader; + bool m_isFifo; + }; + + class WriterData + { + public: + WriterData(); + + RChannelItemMarshalerRef m_marshaler; + RChannelBufferWriter* m_bufferWriter; + RChannelWriter* m_writer; + bool m_isFifo; + }; + + void UpdateChannelProgress(DVertexProcessStatus* status, + UInt32 inputChannelCount, + RChannelReaderHolderRef* rData, + UInt32 outputChannelCount, + RChannelWriterHolderRef* wData); + void RunProgram(DVertexProcessStatus* status, + WorkQueue* workQueue, + UInt32 inputChannelCount, + RChannelReaderHolderRef* rData, + UInt32 outputChannelCount, + RChannelWriterHolderRef* wData); + + DrError m_initializationError; + HANDLE m_programCompleted; + DrTimeInterval m_statusInterval; +public: + DryadVertexProgramRef m_vertexProgram; + UInt32 m_maxParseBatchSize; + UInt32 m_maxMarshalBatchSize; +}; + +class DryadSimpleChannelVertex : public DryadSimpleChannelVertexBase +{ +public: + virtual ~DryadSimpleChannelVertex(); + DRREFCOUNTIMPL +}; + +// +// Remove argentia vertex +// +#if 0 +class DryadArgentiaVertexBase : + public DryadVertexBase, DryadVertexProgramCompletionHandler +{ +public: + DryadArgentiaVertexBase(); + virtual ~DryadArgentiaVertexBase(); + + void SetStatusInterval(DrTimeInterval interval); + DrTimeInterval GetStatusInterval(); + + void PrepareDryadVertex( + DVertexProcessStatus* initialState, + DryadVertexFactoryBase* factory, + UInt32 argumentCount, + DrStr64* argumentList, + UInt32 
serializedBlockLength, + const void* serializedBlock + ); + + DrError RunDryadVertex( + DVertexProcessStatus* initialState, + UInt32 argumentCount, + DrStr64* argumentList + ); + + void ReOpenChannels(DVertexProcessStatus* newChannelStatus); + + void ProgramCompleted(); + +private: + + DrError WriteVertexPlan( + DVertexProcessStatus* initialState, + UInt32 argumentCount, + DrStr64* argumentList, + CString & sFileName + ); + + DrError RunProgram(LPCWSTR pszVertexPlanPath, LPCWSTR pszExecutablePath); + + UInt32 m_maxParseBatchSize; + UInt32 m_maxMarshalBatchSize; + DrError m_initializationError; + DryadVertexProgramRef m_vertexProgram; + HANDLE m_programCompleted; + DrTimeInterval m_statusInterval; +}; + +class DryadArgentiaVertex : public DryadArgentiaVertexBase +{ +public: + virtual ~DryadArgentiaVertex(); + DRREFCOUNTIMPL +}; +#endif diff --git a/DryadVertex/VertexHost/system/dprocess/include/dvertexcosmosenvironment.h b/DryadVertex/VertexHost/system/dprocess/include/dvertexcosmosenvironment.h new file mode 100644 index 0000000..80c40e2 --- /dev/null +++ b/DryadVertex/VertexHost/system/dprocess/include/dvertexcosmosenvironment.h @@ -0,0 +1,38 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include + +class DVertexDryadEnvironment : public DVertexEnvironment +{ +public: + DVertexDryadEnvironment(); + + DrError InitializeFromEnvironment(); + + const char* GetCluster(); + +private: + DrPortNumber m_processNodePort; + DrStr32 m_processCluster; + DrServiceDescriptor m_sd; +}; diff --git a/DryadVertex/VertexHost/system/dprocess/include/dvertexenvironment.h b/DryadVertex/VertexHost/system/dprocess/include/dvertexenvironment.h new file mode 100644 index 0000000..cfdc814 --- /dev/null +++ b/DryadVertex/VertexHost/system/dprocess/include/dvertexenvironment.h @@ -0,0 +1,47 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include +#include + +class DVertexEnvironment +{ +public: + DVertexEnvironment(); + virtual ~DVertexEnvironment(); + + virtual DrError InitializeFromEnvironment() = 0; + + DryadProcessIdentifier* GetPNProcess(); + DryadMachineIdentifier* GetPNMachine(); + + UInt32 GetPNQuota() const; + UInt32 GetMaxFailureThreshold() const; + UInt32 GetMinNumberOfFailuresBeforeAbort() const; + +protected: + DrRef m_process; + DrRef m_machine; + UInt32 m_pnQuota; + UInt32 m_minNumberOfFailuresBeforeAbort; + UInt32 m_maxFailureThreshold; +}; diff --git a/DryadVertex/VertexHost/system/dprocess/include/dvertexmain.h b/DryadVertex/VertexHost/system/dprocess/include/dvertexmain.h new file mode 100644 index 0000000..ed24811 --- /dev/null +++ b/DryadVertex/VertexHost/system/dprocess/include/dvertexmain.h @@ -0,0 +1,40 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include +#include + +class DryadVertexFactoryBase; + +// This is the default implementation of main() for a vertex. If the application does not +// implement main(), this function gets called by a stub main(). 
+extern int DryadVertexMain(int argc, char* argv[], + DryadVertexFactoryBase* factory); +extern int DryadJobManagerMain(int argc, char* argv[]); +extern int DryadJointMain(int argc, char* argv[]); + +/* if an app wants to use DVERTEXMAIN instead of DVERTEXJOINTMAIN it + needs to implement this call to include registration of all its + static vertex factories. If it returns NULL then the vertex to run + will be determined by the command line. Otherwise the vertex will + be the one generated by the factory that is returned. */ +DryadVertexFactoryBase* DryadRegisterFactories(); diff --git a/DryadVertex/VertexHost/system/dprocess/include/dvertexxcomputeenvironment.h b/DryadVertex/VertexHost/system/dprocess/include/dvertexxcomputeenvironment.h new file mode 100644 index 0000000..5d98a4c --- /dev/null +++ b/DryadVertex/VertexHost/system/dprocess/include/dvertexxcomputeenvironment.h @@ -0,0 +1,29 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include + +class DVertexXComputeEnvironment : public DVertexEnvironment +{ +public: + DrError InitializeFromEnvironment(); +}; diff --git a/DryadVertex/VertexHost/system/dprocess/include/vertexfactory.h b/DryadVertex/VertexHost/system/dprocess/include/vertexfactory.h new file mode 100644 index 0000000..5e376a2 --- /dev/null +++ b/DryadVertex/VertexHost/system/dprocess/include/vertexfactory.h @@ -0,0 +1,183 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include +#pragma warning(disable:4995) +#include +#include +#include + +class DryadVertexFactoryBase +{ +public: + // + // Constructor/Destructor + // + DryadVertexFactoryBase(const char* name); + virtual ~DryadVertexFactoryBase(); + + // + // Return the name of the factory + // + const char* GetName(); + + // + // The Register method does nothing, but can be used to pull in + // static factories from other compilation units. + // + void Register(); + + // + // Can make new typed or untyped dryad vertex. 
Used differently depending on + // child implementation + // + DryadVertexProgramRef MakeUntyped(); + DryadVertexProgramBase* NewUntyped(); + +private: + // + // NewUntypedInternal is not defined in factory base + // + virtual DryadVertexProgramBase* NewUntypedInternal() = 0; + + // + // Name of factory + // + DrStr64 m_name; +}; + +// +// Vertex factory using defined type +// +template class TypedVertexFactory : public DryadVertexFactoryBase +{ +public: + typedef _T VertexClass; + typedef DrRef<_T> Vertex; + + TypedVertexFactory(const char* name) : DryadVertexFactoryBase(name) {} + virtual ~TypedVertexFactory() {} + + // + // Create new vertex of the defined type and add the factory's name as an argument + // + VertexClass* New() + { + VertexClass* v = NewInternal(); + v->AddArgument(GetName()); + return v; + } + + // + // Make a DrRef to a new vertex of the defined type + // + Vertex Make() + { + Vertex v; + v.Attach(New()); + return v; + } + +private: + // + // Create a vertex of the defined type + // + DryadVertexProgramBase* NewUntypedInternal() + { + return New(); + } + + // + // NewInternal is pure virtual here; concrete factories must implement it + // + virtual VertexClass* NewInternal() = 0; +}; + +// +// Standard implementation of a vertex factory +// +template class StdTypedVertexFactory : + public TypedVertexFactory<_T> +{ +public: + StdTypedVertexFactory(const char* name) : TypedVertexFactory<_T>(name) {} + +private: + // + // Create new vertex of the defined type + // + VertexClass* NewInternal() + { + return new VertexClass(); + } +}; + +// +// Registry for vertex factories +// +class VertexFactoryRegistry +{ +public: + VertexFactoryRegistry(); + + static void ShowAllVertexUsageMessages(FILE* f); + + // + // Get reference to vertex factory in registry. Returns null if it does not exist.
+ // + static DryadVertexFactoryBase* LookUpFactory(const char* name); + + // + // Place a factory in the registry + // + static void RegisterFactory(DryadVertexFactoryBase* factory); + + // + // Create a vertex + // + static DrError MakeVertex(UInt32 vertexId, + UInt32 vertexVersion, + UInt32 numberOfInputChannels, + UInt32 numberOfOutputChannels, + UInt64* expectedInputLength, + DryadVertexFactoryBase* factory, + DryadMetaData* metaData, + UInt32 maxInputChannels, + UInt32 maxOutputChannels, + UInt32 argumentCount, + DrStr64* argumentList, + UInt32 serializedBlockLength, + const void* serializedBlock, + DryadMetaDataRef* pErrorData, + DryadVertexProgramRef* pProgram); + +private: + typedef std::map< std::string, DryadVertexFactoryBase*, + std::less > FactoryMap; + typedef std::set< std::string, std::less > DuplicateSet; + + FactoryMap m_factories; + DuplicateSet m_errorSet; + bool m_registeredNULL; +}; + +extern DryadVertexFactoryBase* g_subgraphFactory; diff --git a/DryadVertex/VertexHost/system/dprocess/src/dryadvertex.cpp b/DryadVertex/VertexHost/system/dprocess/src/dryadvertex.cpp new file mode 100644 index 0000000..881da6b --- /dev/null +++ b/DryadVertex/VertexHost/system/dprocess/src/dryadvertex.cpp @@ -0,0 +1,2009 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +// +// Includes +// +#include +#include "DrExecution.h" +#include "dryadvertex.h" +#include "dvertexcommand.h" +#include "dvertexpncontrol.h" +#include "workqueue.h" +#include "concreterchannel.h" +#include "dryadmetadata.h" +#include "dryaderrordef.h" +#include "vertexfactory.h" +#include "subgraphvertex.h" +#include "fingerprint.h" + +#pragma managed + +// +// Use 2 minute update period to avoid flooding the GM +// +static const DrTimeInterval s_defaultStatusInterval = DrTimeInterval_Minute * 2; + +static TransformType StripCompressionModeFromUri(char *uri) +{ + TransformType mode = TT_NullTransform; + char *start = strstr(uri, "?c="); + if (start != NULL) + { + int modeInt = atoi(start + 3); + switch (modeInt) + { + // Microsoft.Hpc.Dsc.Internal.DscCompressionSchemeInternal.None + case 0: + mode = TT_NullTransform; + break; + // Microsoft.Hpc.Dsc.Internal.DscCompressionSchemeInternal.Gzip + case 1: + mode = TT_GzipFastCompression; + break; + default: + DrLogA("Invalid compression scheme %d specified in URI: %s", modeInt, uri); + break; + } + *start = 0; + } + + return mode; +} + +// +// Constructor. No associated controller +// +DryadVertexBase::DryadVertexBase() +{ + m_controller = NULL; +} + +// +// Destructor. Does nothing. +// +DryadVertexBase::~DryadVertexBase() +{ +} + +// +// Associate controller with vertex +// +void DryadVertexBase::Initialize(DryadVertexController* controller) +{ + LogAssert(controller != NULL); + + m_controller = controller; +} + +void DryadVertexBase::ReportStatus(DVertexProcessStatus* status, + bool sendUpdate, bool notifyWaiters) +{ + LogAssert(m_controller != NULL); + + m_controller->AssimilateNewStatus(status, sendUpdate, notifyWaiters); +} + +// +// Destructor. Does Nothing. 
+// +DryadVertex::~DryadVertex() +{ +} + + +JobCopiedResource::JobCopiedResource(const char* localResourceName, + const char* remoteResourceName, + void *content, + size_t contentLen) +{ + LogAssert(localResourceName != NULL); + m_localResourceName.Set(localResourceName); + m_remoteResourceName.Set(remoteResourceName); + m_content = (unsigned char *) content; + m_contentLen = contentLen; + if (m_content == NULL) + { + LogAssert(m_contentLen == 0); + m_contentFingerprint = 0; + } + else + { + m_contentFingerprint = FingerPrint64::GetInstance()-> + GetFingerPrint(m_content, m_contentLen); + } +} + +JobCopiedResource::~JobCopiedResource() +{ + delete [] m_content; +} + +const char* JobCopiedResource::GetLocalResourceName() +{ + return m_localResourceName; +} + +const char* JobCopiedResource::GetRemoteResourceName() +{ + return m_remoteResourceName; +} + +const void* JobCopiedResource::GetContent() +{ + return m_content; +} + +size_t JobCopiedResource::GetContentLen() +{ + return m_contentLen; +} + +UInt64 JobCopiedResource::GetContentFingerprint() +{ + return m_contentFingerprint; +} + + +VertexInvocationBase::VertexInvocationBase() +{ + m_argumentArraySize = 8; + m_argument = new DrStr64[m_argumentArraySize]; + m_numberOfArguments = 0; + + m_maxInputChannels = 0; + m_maxOutputChannels = 0; + + m_resourceArraySize = 8; + m_resource = new JobCopiedResourceRef[m_resourceArraySize]; + m_numberOfResources = 0; + + m_debugBreak = false; + m_fakeVertexFailure = false; + m_fakeVertexInputFailure = false; + + m_affinityConstraint = VAC_DontCare; + + m_canShareWorkQueue = false; + + m_cpuUsage = MAX_UINT32; + m_memoryUsage = MAX_UINT64; + m_diskUsage = MAX_UINT32; + + m_outputTotalSizeHint = 0; + m_outputSizeHintVectorLength = 0; + m_outputSizeHintVector = NULL; +} + +VertexInvocationBase::~VertexInvocationBase() +{ + delete [] m_argument; + delete [] m_resource; + delete [] m_outputSizeHintVector; +} + +const char* VertexInvocationBase::GetCommandLine() +{ + return 
m_commandLine; +} + +void VertexInvocationBase::SetCommandLine(const char* commandLine) +{ + m_commandLine = commandLine; +} + +UInt32 VertexInvocationBase::GetMaxOpenInputChannelCount() +{ + return m_maxInputChannels; +} + +void VertexInvocationBase::SetMaxOpenInputChannelCount(UInt32 channelCount) +{ + m_maxInputChannels = channelCount; +} + +UInt32 VertexInvocationBase::GetMaxOpenOutputChannelCount() +{ + return m_maxOutputChannels; +} + +void VertexInvocationBase::SetMaxOpenOutputChannelCount(UInt32 channelCount) +{ + m_maxOutputChannels = channelCount; +} + +UInt32 VertexInvocationBase::GetArgumentCount() +{ + return m_numberOfArguments; +} + +DrStr64* VertexInvocationBase::GetArgumentList() +{ + return m_argument; +} + +// +// Return an argument if in valid range +// +const char* VertexInvocationBase::GetArgument(UInt32 whichArgument) +{ + LogAssert(whichArgument < m_numberOfArguments); + return m_argument[whichArgument]; +} + +// +// Add an argument to a vertex +// +void VertexInvocationBase::AddArgument(const char* argument) +{ + // + // If argument list is full, grow it by factor of two + // + if (m_numberOfArguments == m_argumentArraySize) + { + // + // Create new array twice as big + // + m_argumentArraySize *= 2; + DrStr64* newArray = new DrStr64[m_argumentArraySize]; + LogAssert(newArray != NULL); + + // + // Copy each element in old array into new array + // + UInt32 i; + for (i=0; iGetLocalResourceName(); + for(UInt32 i = 0; i < m_numberOfResources; ++i) + { + if (m_resource[i] == NULL) + continue; + if(_stricmp(m_resource[i]->GetLocalResourceName(), localResourceName) == 0) + { + if((m_resource[i]->GetContentLen() != resource->GetContentLen()) || + (m_resource[i]->GetContentFingerprint() != resource->GetContentFingerprint()) || + (memcmp(m_resource[i]->GetContent(), resource->GetContent(), resource->GetContentLen()) != 0)) + { + DrLogE( "Existing resource", "RemoteName=%s, LocalName=%s, Len=%u", m_resource[i]->GetRemoteResourceName(), 
m_resource[i]->GetLocalResourceName(), m_resource[i]->GetContentLen()); + DrLogE( "New resource", "RemoteName=%s, LocalName=%s, Len=%u", resource->GetRemoteResourceName(), resource->GetLocalResourceName(), resource->GetContentLen()); + DrLogA( "Duplicated embedded resource with different content"); + } + else + { + DrLogW( "Duplicated embedded resource ignored", "RemoteName=%s, LocalName=%s, Len=%u", resource->GetRemoteResourceName(), resource->GetLocalResourceName(), resource->GetContentLen()); + // Resource will be free'd automatically + return; + } + } + } + m_resource[m_numberOfResources] = resource; + ++m_numberOfResources; +} + +void VertexInvocationBase::AttachResource(JobCopiedResource* resource) +{ + AddResource(resource); + resource->DecRef(); +} + +bool VertexInvocationBase::GetDebugBreak() +{ + return m_debugBreak; +} + +void VertexInvocationBase::SetDebugBreak(bool debugBreak) +{ + m_debugBreak = debugBreak; +} + +bool VertexInvocationBase::GetFakeVertexFailure() +{ + return m_fakeVertexFailure; +} + +void VertexInvocationBase::SetFakeVertexFailure(bool fakeVertexFailure) +{ + m_fakeVertexFailure = fakeVertexFailure; +} + +bool VertexInvocationBase::GetFakeVertexInputFailure() +{ + return m_fakeVertexInputFailure; +} + +void VertexInvocationBase::SetFakeVertexInputFailure(bool fakeVertexFailure) +{ + m_fakeVertexInputFailure = fakeVertexFailure; +} + +void VertexInvocationBase:: + SetAffinityConstraint(VertexAffinityConstraint constraint) +{ + m_affinityConstraint = constraint; +} + +VertexAffinityConstraint VertexInvocationBase::GetAffinityConstraint() +{ + return m_affinityConstraint; +} + +std::list* VertexInvocationBase::GetAffinityLocationList() +{ + return &m_affinityLocations; +} + +void VertexInvocationBase::SetCanShareWorkQueue(bool canShareWorkQueue) +{ + m_canShareWorkQueue = canShareWorkQueue; +} + +bool VertexInvocationBase::GetCanShareWorkQueue() +{ + return m_canShareWorkQueue; +} + +void VertexInvocationBase::SetDisplayName(const char* 
displayName) +{ + m_displayName = displayName; +} + +const char* VertexInvocationBase::GetDisplayName() +{ + return m_displayName; +} + +void VertexInvocationBase::SetCpuUsage(UInt32 cpu) +{ + m_cpuUsage = cpu; +} + +UInt32 VertexInvocationBase::GetCpuUsage() +{ + return m_cpuUsage; +} + +void VertexInvocationBase::SetMemoryUsage(UInt64 memory) +{ + m_memoryUsage = memory; +} + +UInt64 VertexInvocationBase::GetMemoryUsage() +{ + return m_memoryUsage; +} + +void VertexInvocationBase::SetDiskUsage(UInt32 disk) +{ + m_diskUsage = disk; +} + +UInt32 VertexInvocationBase::GetDiskUsage() +{ + return m_diskUsage; +} + +UInt64 VertexInvocationBase::GetOutputTotalSizeHint() const +{ + return m_outputTotalSizeHint; +} + +void VertexInvocationBase::SetOutputTotalSizeHint(UInt64 hint) +{ + if (hint != 0) + { + LogAssert(m_outputSizeHintVector == NULL); + } + m_outputTotalSizeHint = hint; +} + +UInt32 VertexInvocationBase::GetOutputSizeHintVectorLength() const +{ + return m_outputSizeHintVectorLength; +} + +UInt64* VertexInvocationBase::GetOutputSizeHintVector() const +{ + return m_outputSizeHintVector; +} + +void VertexInvocationBase::SetOutputSizeHintVector(UInt32 numberOfOutputs, + UInt64* hints) +{ + if (numberOfOutputs != 0) + { + LogAssert(m_outputTotalSizeHint == 0); + } + + delete [] m_outputSizeHintVector; + + m_outputSizeHintVectorLength = numberOfOutputs; + if (numberOfOutputs == 0) + { + m_outputSizeHintVector = NULL; + } + else + { + m_outputSizeHintVector = new UInt64[m_outputSizeHintVectorLength]; + + UInt32 i; + for (i=0; i 0); + + DryadVertexFactoryBase* factory = + VertexFactoryRegistry::LookUpFactory(m_argument[0]); + + LogAssert(factory != NULL, "Clone called on invalid vertex: %s", m_argument[0].GetString()); + + DryadVertexProgramBase* clone = factory->NewUntyped(); + + clone->SetCommandLine(m_commandLine); + + UInt32 i; + + for (i=1; iAddArgument(m_argument[i]); + } + + clone->SetMetaData(m_metaData); + + DrRef buffer; + buffer.Attach(new 
DrSimpleHeapBuffer()); + { + DrMemoryBufferWriter writer(buffer); + Serialize(&writer); + } + + { + DrMemoryBufferReader reader(buffer); + clone->DeSerialize(&reader); + if (clone->GetErrorCode() != DrError_OK) + { + DrLogA( "Deserialize clone failed. Error %s", DRERRORSTRING(clone->GetErrorCode())); + } + } + + for (i=0; iAddResource(m_resource[i]); + } + + clone->SetMaxOpenInputChannelCount(m_maxInputChannels); + clone->SetMaxOpenOutputChannelCount(m_maxOutputChannels); + + clone->SetCanShareWorkQueue(m_canShareWorkQueue); + + clone->SetDisplayName(m_displayName); + clone->SetDebugBreak(m_debugBreak); + clone->SetFakeVertexFailure(m_fakeVertexFailure); + clone->SetFakeVertexInputFailure(m_fakeVertexInputFailure); + clone->SetAffinityConstraint(m_affinityConstraint); + clone->GetAffinityLocationList()->assign(m_affinityLocations.begin(), + m_affinityLocations.end()); + clone->SetCpuUsage(m_cpuUsage); + clone->SetMemoryUsage(m_memoryUsage); + clone->SetDiskUsage(m_diskUsage); + + clone->SetOutputTotalSizeHint(m_outputTotalSizeHint); + clone->SetOutputSizeHintVector(m_outputSizeHintVectorLength, + m_outputSizeHintVector); + + return clone; +} + +VertexInvocationRecord::~VertexInvocationRecord() +{ +} + +class DVPDummyHandler : public RChannelItemWriterHandler +{ +public: + void ProcessWriteCompleted(RChannelItemType status, + RChannelItem* marshalFailureItem); +}; + +void DVPDummyHandler::ProcessWriteCompleted(RChannelItemType /*status*/, + RChannelItem* /*failureItem*/) +{ + delete this; +} + + +DryadVertexProgramBase::ThreadBlock:: + ThreadBlock(DryadVertexProgramBase* parent, + WorkQueue* workQueue, + UInt32 numberOfInputChannels, + RChannelReader** inputChannel, + UInt32 numberOfOutputChannels, + RChannelWriter** outputChannel, + DryadVertexProgramCompletionHandler* handler) +{ + m_parent = parent; + m_workQueue = workQueue; + m_numberOfInputChannels = numberOfInputChannels; + m_inputChannel = inputChannel; + m_numberOfOutputChannels = numberOfOutputChannels; + 
m_outputChannel = outputChannel; + m_handler = handler; +} + +// +// Run dryad vertex program and report completion +// +void DryadVertexProgramBase::ThreadBlock::Run() +{ + // + // Execute main function + // parent is of type ManagedWrapperVertex + // + m_parent->Main(m_workQueue, + m_numberOfInputChannels, m_inputChannel, + m_numberOfOutputChannels, m_outputChannel); + + // + // Report completion + // + m_handler->ProgramCompleted(); +} + + +DryadVertexProgramBase::DefaultHandler::DefaultHandler(HANDLE completionEvent) +{ + m_completionEvent = completionEvent; +} + +void DryadVertexProgramBase::DefaultHandler::ProgramCompleted() +{ + BOOL bRet = ::SetEvent(m_completionEvent); + LogAssert(bRet != 0); +} + + +DryadVertexProgramBase::DryadVertexProgramBase() +{ + m_defaultMainCalled = false; + m_defaultMainAsyncCalled = false; + m_parserFactoryArray = NULL; + m_numberOfParserFactories = 0; + m_marshalerFactoryArray = NULL; + m_numberOfMarshalerFactories = 0; + m_expectedSizeArrayLength = 0; + m_expectedSizeArray = NULL; +} + +DryadVertexProgramBase::~DryadVertexProgramBase() +{ + delete [] m_parserFactoryArray; + delete [] m_marshalerFactoryArray; + delete [] m_expectedSizeArray; +} + +// +// Print out usage information to file +// If no arguments, vertex wasn't created +// If vertex was created, there's no info because this is just the base vertex class +// +void DryadVertexProgramBase::Usage(FILE* f) +{ + if (GetArgumentCount() < 1) + { + fprintf(f, "usage called before vertex creation completed\n\n"); + return; + } + + fprintf(f, "vertex %s: no usage information\n\n", + GetArgument(0)); +} + +UInt32 DryadVertexProgramBase::GetVertexId() +{ + return m_vertexId; +} + +void DryadVertexProgramBase::SetVertexId(UInt32 vertexId) +{ + m_vertexId = vertexId; +} + +UInt32 DryadVertexProgramBase::GetVertexVersion() +{ + return m_vertexVersion; +} + +void DryadVertexProgramBase::SetVertexVersion(UInt32 vertexVersion) +{ + m_vertexVersion = vertexVersion; +} + +UInt64 
DryadVertexProgramBase::GetExpectedInputLength(UInt32 inputChannel) +{ + if (m_expectedSizeArray == NULL) + { + return (UInt64) -1; + } + + LogAssert(inputChannel < m_expectedSizeArrayLength); + return m_expectedSizeArray[inputChannel]; +} + +// +// Foreach input channel, set the expected length of that channel +// +void DryadVertexProgramBase::SetExpectedInputLength(UInt32 numberOfChannels, + UInt64* expectedLength) +{ + // + // Clear current array and create new one + // + delete [] m_expectedSizeArray; + m_expectedSizeArray = new UInt64[numberOfChannels]; + m_expectedSizeArrayLength = numberOfChannels; + + // + // Foreach channel, remember length + // + UInt32 i; + for (i=0; i= m_numberOfParserFactories) + { + return NULL; + } + + return m_parserFactoryArray[whichFactory]; +} + +void DryadVertexProgramBase:: + SetNumberOfMarshalerFactories(UInt32 numberOfFactories) +{ + delete [] m_marshalerFactoryArray; + m_numberOfMarshalerFactories = numberOfFactories; + m_marshalerFactoryArray = + new DryadMarshalerFactoryRef[m_numberOfMarshalerFactories]; +} + +void DryadVertexProgramBase:: + SetMarshalerFactory(UInt32 whichFactory, + DryadMarshalerFactoryBase* factory) +{ + LogAssert(whichFactory < m_numberOfMarshalerFactories); + m_marshalerFactoryArray[whichFactory] = factory; +} + +void DryadVertexProgramBase:: + SetCommonMarshalerFactory(DryadMarshalerFactoryBase* factory) +{ + SetNumberOfMarshalerFactories(1); + SetMarshalerFactory(0, factory); + m_numberOfMarshalerFactories = 0; +} + +DryadMarshalerFactoryBase* DryadVertexProgramBase::GetCommonMarshalerFactory() +{ + if (m_marshalerFactoryArray != NULL && m_numberOfMarshalerFactories == 0) + { + return m_marshalerFactoryArray[0]; + } + else + { + return NULL; + } +} + +DryadMarshalerFactoryBase* DryadVertexProgramBase:: + GetMarshalerFactory(UInt32 whichFactory) +{ + if (m_marshalerFactoryArray == NULL) + { + return NULL; + } + + if (m_numberOfMarshalerFactories == 0) + { + return m_marshalerFactoryArray[0]; + } + + 
if (whichFactory >= m_numberOfMarshalerFactories) + { + return NULL; + } + + return m_marshalerFactoryArray[whichFactory]; +} + +void DryadVertexProgramBase::MakeInputParser(UInt32 whichInput, + RChannelItemParserRef* pParser) +{ + if (m_parserFactoryArray != NULL) + { + if (m_numberOfParserFactories == 0) + { +// DrLogI( +// "Attaching parser from common factory", +// "Input channel %u", whichInput); + m_parserFactoryArray[0]->MakeParser(pParser, this); + } + else if (whichInput >= m_numberOfParserFactories) + { + ReportError(DryadError_VertexInitialization, + "Parser for channel %u requested but " + "only %u factories supplied", + whichInput, m_numberOfParserFactories); + } + else if (m_parserFactoryArray[whichInput] == NULL) + { + ReportError(DryadError_VertexInitialization, + "Parser for channel %u requested but " + "no factory supplied for that channel", + whichInput); + } + else + { +// DrLogI( +// "Attaching parser from individual factory", +// "Input channel %u", whichInput); + m_parserFactoryArray[whichInput]->MakeParser(pParser, this); + } + + (*pParser)->SetParserIndex(whichInput); + (*pParser)->SetParserContext(this); + } + else + { + ReportError(DrError_NotImplemented, "No factory implemented"); + } +} + +// +// Make a marshaler for an output channel +// +void DryadVertexProgramBase:: + MakeOutputMarshaler(UInt32 whichOutput, + RChannelItemMarshalerRef* pMarshaler) +{ + if (m_marshalerFactoryArray != NULL) + { + if (m_numberOfMarshalerFactories == 0) + { +// DrLogI( +// "Attaching marshaler from common factory", +// "Output channel %u", whichOutput); + // + // Make a marshaller using the existing factory + // + m_marshalerFactoryArray[0]->MakeMarshaler(pMarshaler, this); + } + else if (whichOutput >= m_numberOfMarshalerFactories) + { + // + // If output index doesn't have it's own factory, report error + // todo: why does this matter + // + ReportError(DryadError_VertexInitialization, + "Marshaler for channel %u requested but " + "only %u factories 
supplied", + whichOutput, m_numberOfMarshalerFactories); + } + else if (m_marshalerFactoryArray[whichOutput] == NULL) + { + // + // If factory for this output channel is null, report error + // + ReportError(DryadError_VertexInitialization, + "Marshaler for channel %u requested but " + "no factory supplied for that channel", + whichOutput); + } + else + { +// DrLogI( +// "Attaching marshaler from individual factory", +// "Input channel %u", whichOutput); + // + // If factory exists for this output channel, use it to create a marshaler + // + m_marshalerFactoryArray[whichOutput]-> + MakeMarshaler(pMarshaler, this); + } + + // + // Set marshaler index and context + // todo: figure out what happens in middle two cases above + // + (*pMarshaler)->SetMarshalerIndex(whichOutput); + (*pMarshaler)->SetMarshalerContext(this); + } + else + { +// DrLogI( "Attaching default marshaler", +// "Output channel %u", whichOutput); + // + // If there is no factories, just use generic marshaler + // + pMarshaler->Attach(new RChannelStdItemMarshaler()); + } +} + +// +// No initialization in base class +// +void DryadVertexProgramBase::Initialize(UInt32 numberOfInputChannels, + UInt32 numberOfOutputChannels) +{ +} + +void DryadVertexProgramBase::Main(WorkQueue* workQueue, + UInt32 numberOfInputChannels, + RChannelReader** inputChannel, + UInt32 numberOfOutputChannels, + RChannelWriter** outputChannel) +{ + LogAssert(m_defaultMainAsyncCalled == false); + m_defaultMainCalled = true; + + HANDLE completionEvent = ::CreateEvent(NULL, TRUE, FALSE, NULL); + LogAssert(completionEvent != NULL); + + DefaultHandler handler(completionEvent); + + MainAsync(workQueue, + numberOfInputChannels, inputChannel, + numberOfOutputChannels, outputChannel, &handler); + + DWORD dRet = ::WaitForSingleObject(completionEvent, INFINITE); + LogAssert(dRet == WAIT_OBJECT_0); + + BOOL bRet = ::CloseHandle(completionEvent); + LogAssert(bRet != 0); + + AsyncPostCompletion(); +} + +// +// Executes thread block containing 
reference to user code +// +unsigned DryadVertexProgramBase::MainThreadFunc(void* arg) +{ + // + // Get thread block and run it. Blocking. + // + ThreadBlock* threadBlock = (ThreadBlock *) arg; + threadBlock->Run(); + + // + // Clean up and return + // + delete threadBlock; + return 0; +} + +// +// Create a new thread to run user vertex code. This returns so that status can be reported periodically to GM +// +void DryadVertexProgramBase::MainAsync(WorkQueue* workQueue, + UInt32 numberOfInputChannels, + RChannelReader** inputChannel, + UInt32 numberOfOutputChannels, + RChannelWriter** outputChannel, + DryadVertexProgramCompletionHandler* + handler) +{ + LogAssert(m_defaultMainCalled == false); + m_defaultMainAsyncCalled = true; + + // + // Creat a new thread block with worker threads and data channels + // + ThreadBlock* threadBlock = + new ThreadBlock(this, workQueue, + numberOfInputChannels, inputChannel, + numberOfOutputChannels, outputChannel, handler); + + // + // Start main method using thread block + // + unsigned threadAddr; + HANDLE threadHandle = + (HANDLE) ::_beginthreadex(NULL, + 0, + MainThreadFunc, + threadBlock, + 0, + &threadAddr); + LogAssert(threadHandle != 0); + + // + // Close the thread handle and return + // + BOOL bRet = ::CloseHandle(threadHandle); + LogAssert(bRet != 0); +} + +// +// Nothing to do post-completion +// +void DryadVertexProgramBase::AsyncPostCompletion() +{ +} + +// +// Notify each I/O channel that they are done and/or not needed +// +void DryadVertexProgramBase:: + NotifyChannelsOfCompletion(UInt32 numberOfInputChannels, + RChannelReader** inputChannel, + UInt32 numberOfOutputChannels, + RChannelWriter** outputChannel) +{ + RChannelItemRef terminationItem; + + // + // Create different channel completion codes depending on reason for ending + // + DrError status = GetErrorCode(); + switch (status) + { + case DrError_OK: + case DryadError_ProcessingInterrupted: + // + // EndOfStream used in successful completion or 
user-interuption + // + terminationItem. + Attach(RChannelMarkerItem::Create(RChannelItem_EndOfStream, + false)); + break; + + case DryadError_ChannelRestart: + // + // Restart used when requested + // + terminationItem. + Attach(RChannelMarkerItem::Create(RChannelItem_Restart, + false)); + terminationItem->ReplaceMetaData(GetErrorMetaData()); + break; + + default: + // + // Use ProcessingInterrupted for all other reasons + // + terminationItem. + Attach(RChannelMarkerItem:: + CreateErrorItem(RChannelItem_Abort, + DryadError_ProcessingInterrupted)); + break; + } + + // + // Interrupt each input channel with completion reason + // + UInt32 i; + for (i=0; iClone(&clone); + inputChannel[i]->Interrupt(clone); + } + + // + // Write completion reason to each output channel + // + for (i=0; iClone(&clone); + outputChannel[i]->WriteItem(clone, false, dummyHandler); + } +} + +// +// Drain all I/O channels +// +void DryadVertexProgramBase::DrainChannels(UInt32 numberOfInputChannels, + RChannelReader** inputChannel, + UInt32 numberOfOutputChannels, + RChannelWriter** outputChannel) +{ + // + // Drain all input channels + // + UInt32 i; + for (i=0; iDrain(); + } + + bool errorCondition = false; + DryadMetaDataRef errorData; + + // + // Drain all output channels + // + for (i=0; iDrain(DrTimeInterval_Zero, &writeCompletion); + if (writeCompletion != NULL) + { + if (writeCompletion->GetType() != RChannelItem_EndOfStream && errorCondition == false) + { + errorCondition = true; + errorData = writeCompletion->GetMetaData(); + } + } + } + + if (GetErrorCode() == DrError_OK && errorCondition) + { + // + // If completed writing but did not reach end of stream, this is an error + // + ReportError(DryadError_ChannelWriteError, errorData); + } +} + + +DryadVertexProgram::~DryadVertexProgram() +{ +} + + +DryadSimpleChannelVertexBase::ReaderData::ReaderData() +{ + m_bufferReader = NULL; + m_reader = NULL; + m_isFifo = false; +} + +DryadSimpleChannelVertexBase::WriterData::WriterData() +{ + 
m_bufferWriter = NULL; + m_writer = NULL; + m_isFifo = false; +} + +// +// Constructor - creates simple channel vertex +// +DryadSimpleChannelVertexBase::DryadSimpleChannelVertexBase() +{ + // + // Use defaults + // + m_maxParseBatchSize = RChannelItem::s_defaultItemBatchSize; + m_maxMarshalBatchSize = RChannelItem::s_defaultItemBatchSize; + m_statusInterval = s_defaultStatusInterval; + m_programCompleted = ::CreateEvent(NULL, TRUE, FALSE, NULL); + LogAssert(m_programCompleted != NULL); + DrLogI( "Ensuring subgraph factory is registered. factory name %s", g_subgraphFactory->GetName()); +} + +// +// Destructor. Frees program completed handle +// +DryadSimpleChannelVertexBase::~DryadSimpleChannelVertexBase() +{ + BOOL bRet = ::CloseHandle(m_programCompleted); + LogAssert(bRet != 0); +} + +// +// Sets status interval +// +void DryadSimpleChannelVertexBase::SetStatusInterval(DrTimeInterval interval) +{ + m_statusInterval = interval; +} + +// +// Gets status interval +// +DrTimeInterval DryadSimpleChannelVertexBase::GetStatusInterval() +{ + return m_statusInterval; +} + +// +// Initialize m_vertexProgram with arguments and channels +// +void DryadSimpleChannelVertexBase:: + PrepareDryadVertex(DVertexProcessStatus* initialState, + DryadVertexFactoryBase* factory, + UInt32 argumentCount, + DrStr64* argumentList, + UInt32 serializedBlockLength, + const void* serializedBlock) +{ + DryadMetaDataRef errorData; + + // + // Get input channels' length + // + UInt64* expectedLength = new UInt64[initialState->GetInputChannelCount()]; + UInt32 i; + for (i=0; iGetInputChannelCount(); ++i) + { + expectedLength[i] = + initialState->GetInputChannels()[i].GetChannelTotalLength(); + initialState->GetInputChannels()[i].SetChannelTotalLength(0); + } + + // + // Get vertex program and initialize it with parameters and channels + // + m_initializationError = VertexFactoryRegistry:: + MakeVertex(initialState->GetVertexId(), + initialState->GetVertexInstanceVersion(), + 
initialState->GetInputChannelCount(), + initialState->GetOutputChannelCount(), + expectedLength, + factory, + initialState->GetVertexMetaData(), + initialState->GetMaxOpenInputChannelCount(), + initialState->GetMaxOpenOutputChannelCount(), + argumentCount, + argumentList, + serializedBlockLength, + serializedBlock, + &errorData, + &m_vertexProgram); + + // + // Clean up channel length array + // + delete [] expectedLength; + + if (m_initializationError != DrError_OK) + { + // + // Remember error data and report failure if unsuccessful + // + initialState->SetVertexMetaData(errorData); + ReportStatus(initialState, false, false); + } +} + +// +// Set event that program has completed so that vertex host knows +// +void DryadSimpleChannelVertexBase::ProgramCompleted() +{ + BOOL bRet = ::SetEvent(m_programCompleted); + LogAssert(bRet != 0); +} + +// +// Update the progress based on the data channels +// +void DryadSimpleChannelVertexBase:: + UpdateChannelProgress(DVertexProcessStatus* status, + UInt32 inputChannelCount, + RChannelReaderHolderRef* rData, + UInt32 outputChannelCount, + RChannelWriterHolderRef* wData) +{ + UInt32 i; + + UInt64 processedLength = 0; + UInt64 totalLength = 0, fifoLength = 0; + + // + // Foreach input channel, get the processed length + // + for (i=0; iGetInputChannels()[i]); + rData[i]->FillInStatus(input); + UInt64 len = input->GetChannelProcessedLength(); + + // + // If fewer than 64 input channels, log each progress + // todo: this seems arbitrary. Probably should be debug feature. + // + if (inputChannelCount < 64) + { + DrLogD( "Input channel progress. 
%u:%s %I64u/%I64u", i, input->GetChannelURI(), + len, + input->GetChannelTotalLength()); + } + + // + // If fifo channel, add processed length to fifo count, + // otherwise add it to general processed count + // + if (!strncmp(input->GetChannelURI(), "fifo://", 7)) + { + fifoLength += len; + } + else + { + processedLength += len; + } + + // + // Increment total channel length for this channel + // + totalLength += input->GetChannelTotalLength(); + } + + // + // Report the processed length in the log + // + DrLogD( "Aggregated input progress. vertex %u.%u fifo=%I64u ext=%I64u total=%I64u", + status->GetVertexId(), status->GetVertexInstanceVersion(), + fifoLength, processedLength, totalLength); + + // + // Foreach output channel, get the processed length + // + processedLength = fifoLength = 0; + for (i=0; iGetOutputChannels()[i]); + wData[i]->FillInStatus(output); + UInt64 len = output->GetChannelProcessedLength(); + + // + // If fewer than 64 output channels, log each progress + // todo: this seems arbitrary. Probably should be debug feature. + // + if (outputChannelCount < 64) + { + DrLogD( "Output channel progress. %u:%s %I64u", i, output->GetChannelURI(), len); + } + + // + // If fifo channel, add processed length to fifo count, + // otherwise add it to general processed count + // + if (!strncmp(output->GetChannelURI(), "fifo://", 7)) + { + fifoLength += len; + } + else + { + processedLength += len; + } + } + + // + // Report the processed length in the log + // + DrLogD( "Aggregated output progress. 
vertex %u.%u fifo=%I64u ext=%I64u", status->GetVertexId(), + status->GetVertexInstanceVersion(), fifoLength, processedLength); +} + +// +// Start Async thread to call vertex code and periodically reports status to GM +// (called from RunDryadVertex) +// +void DryadSimpleChannelVertexBase::RunProgram(DVertexProcessStatus* status, + WorkQueue* workQueue, + UInt32 inputChannelCount, + RChannelReaderHolderRef* rData, + UInt32 outputChannelCount, + RChannelWriterHolderRef* wData) +{ + UInt32 i; + + // + // Get channel reader for each channel from reader holder + // + RChannelReader** rArray = new RChannelReader* [inputChannelCount]; + LogAssert(rArray != NULL); + for (i=0; iGetReader()->Start(NULL); + rArray[i] = rData[i]->GetReader(); + rArray[i]->SetExpectedLength(m_vertexProgram-> + GetExpectedInputLength(i)); + } + + // + // Get channel writer for each channel from writer holder + // + RChannelWriter** wArray = new RChannelWriter* [outputChannelCount]; + LogAssert(wArray != NULL); + for (i=0; iGetWriter()->Start(); + wArray[i] = wData[i]->GetWriter(); + } + + // + // Clear the program completed event and report status is good + // + BOOL bRet = ::ResetEvent(m_programCompleted); + LogAssert(bRet != 0); + m_vertexProgram->ReportError(DrError_OK, (DryadMetaData *) NULL); + + // + // Run the "main" method in the vertex program asynchronously + // + m_vertexProgram->MainAsync(workQueue, + inputChannelCount, rArray, + outputChannelCount, wArray, + this); + + // + // While waiting for program completed event, Update Vertex service about progress. 
+ // + DWORD dRet; + do + { + // + // Wait for configurable timeout + // + DWORD waitTimeout = DrGetTimerMsFromInterval(m_statusInterval); + dRet = ::WaitForSingleObject(m_programCompleted, waitTimeout); + LogAssert(dRet == WAIT_OBJECT_0 || dRet == WAIT_TIMEOUT); + + // + // Report data progress to log and ok/not_ok to GM + // + UpdateChannelProgress(status, + inputChannelCount, rData, + outputChannelCount, wData); + ReportStatus(status, true, false); + } while (dRet == WAIT_TIMEOUT); + + // + // After completing vertex execution, perform any cleanup steps + // + m_vertexProgram->AsyncPostCompletion(); + + // + // Record any error information in status + // + status->SetVertexMetaData(m_vertexProgram->GetErrorMetaData()); + + // + // Notify all I/O channels that they're done + // + m_vertexProgram->NotifyChannelsOfCompletion(inputChannelCount, rArray, + outputChannelCount, wArray); + + // + // If vertex status shows that it was interrupted, mark it as ok + // todo: this looks like a hack to avoid reporting an error that can be set as + // vertex program stops. is it? 
+ // + if (m_vertexProgram->GetErrorCode() == DryadError_ProcessingInterrupted) + { + m_vertexProgram->ReportError(DrError_OK); + } + + // + // Drain all I/O channels + // + m_vertexProgram->DrainChannels(inputChannelCount, rArray, + outputChannelCount, wArray); + + // + // Update input channels with correct status of completion + // + for (i=0; iGetTerminationStatus(&channelErrorData); + status->GetInputChannels()[i].SetChannelState(channelStatus); + status->GetInputChannels()[i].SetChannelMetaData(channelErrorData); + } + + // + // Update output channels with correct status of completion + // + for (i=0; iGetTerminationStatus(&channelErrorData); + status->GetOutputChannels()[i].SetChannelState(channelStatus); + status->GetOutputChannels()[i].SetChannelMetaData(channelErrorData); + } + + // + // Write out I/O channel status to logs + // + UpdateChannelProgress(status, + inputChannelCount, rData, + outputChannelCount, wData); + + // + // Clean up I/O channels + // + delete [] rArray; + delete [] wArray; +} + +// +// Prepare the channel readers, writers, and work queue for processing I/O +// Then calls RunProgram to continue quest to invoke user code +// +DrError DryadSimpleChannelVertexBase:: + RunDryadVertex(DVertexProcessStatus* initialState, + UInt32 argumentCount, + DrStr64* argumentList) +{ + // + // Return initialization error if there is already a problem + // + if (m_initializationError != DrError_OK) + { + return m_initializationError; + } + + DrLogI( "Opening vertex channels. Name %s", + (argumentCount > 0) ? 
argumentList[0].GetString() : "unknown"); + + if(initialState == NULL) + { + DrLogE("RunDryadVertex invoked with NULL initialState"); + return DrError_Fail; + } + + UInt32 i; + UInt32 iCC = initialState->GetInputChannelCount(); + UInt32 oCC = initialState->GetOutputChannelCount(); + + // + // Create a work queue that has two threads per core, + // but only one thread per core able to run concurrently + // start the work queue + // + SYSTEM_INFO systemInfo; + GetSystemInfo(&systemInfo); + WorkQueue* workQueue = new WorkQueue(systemInfo.dwNumberOfProcessors*2, + systemInfo.dwNumberOfProcessors); + workQueue->Start(); + + // + // Create readers for input channels and writers for output channels + // + RChannelReaderHolderRef* rData = new RChannelReaderHolderRef[iCC]; + RChannelWriterHolderRef* wData = new RChannelWriterHolderRef[oCC]; + + bool failed = false; + + // + // Throttle input connections to maximum + // + RChannelOpenThrottler* readThrottler = NULL; + UInt32 maxReaders = m_vertexProgram->GetMaxOpenInputChannelCount(); + if (maxReaders > 0) + { + readThrottler = RChannelFactory::MakeOpenThrottler(maxReaders, + workQueue); + } + + // + // Set maximum parsing batch size to the total maximum over the number of + // input channels, but allow at least 4. + // Done because: if there are lots of input channels, scale back on the + // batching so we don't use so much memory + // + UInt32 maxParseBatchSize = 1; + if (iCC > 0) + { + maxParseBatchSize = m_maxParseBatchSize / iCC; + } + if (maxParseBatchSize < 4) + { + maxParseBatchSize = 4; + } + + UInt32 maxParseUnitsInFlight = maxParseBatchSize * 4; + + // + // Create locality monitoring stuff + // + DWORD localInputChannels = 0; + bool channelLocalCreated = true; + bool* channelLocal = (bool*)malloc(sizeof(bool)*iCC); + if(channelLocal == NULL) + { + channelLocalCreated = false; + DrLogW("Channel locality list could not be created. 
Logging detail reduced."); + } + + // + // Foreach input channel, open a reader and record any errors + // + for (i=0; iGetInputChannels()[i]); + const char* uri = input->GetChannelURI(); + + // Get compression mode parameter + DrLogD("Original URI: %s", uri); + TransformType mode = StripCompressionModeFromUri(const_cast(uri)); + DrLogD("Transform type: %d. New URI: %s", mode, uri); + + // + // Assume channel is remote unless proven otherwise + // + if(channelLocalCreated) + { + channelLocal[i] = false; + } + + // + // Make an input parser + // + RChannelItemParserRef parser; + m_vertexProgram->MakeInputParser(i, &parser); + if (m_vertexProgram->NoError()) + { + DWORD localInputChannelsSnapshot = localInputChannels; + + // + // Open the channel reader if parser successfully created + // + DrError err = RChannelFactory::OpenReader(uri, input->GetChannelMetaData(), + parser, + iCC, readThrottler, + maxParseBatchSize, + maxParseUnitsInFlight, + workQueue, &errorReporter, + &(rData[i]), + &localInputChannels); + if(err != DrError_OK) + { + DrLogE("RChannelFactory::OpenReader failed."); + return err; + } + + if(rData[i] == NULL) + { + DrLogE("RChannelFactory::OpenReader returned a NULL RChannelReaderHolder object"); + return DrError_Fail; + } + + // + // Set the transform (compression) type + // + rData[i]->GetReader()->SetTransformType(mode); + + + // + // If number of local input channels increased, this one was local + // + if(channelLocalCreated && (localInputChannelsSnapshot < localInputChannels)) + { + channelLocal[i] = true; + } + } + else + { + // + // Report errors making parse + // + errorReporter.ReportError(m_vertexProgram->GetErrorCode(), + m_vertexProgram->GetErrorMetaData()); + } + + // + // Remember failure if anything goes wrong and update channel + // state with any errors. 
+ // + if (errorReporter.GetErrorCode() != DrError_OK) + { + failed = true; + } + + input->SetChannelState(errorReporter.GetErrorCode()); + input->SetChannelMetaData(errorReporter.GetErrorMetaData()); + } + + + // + // Record number of local input files + // + fprintf(stdout, "HpcQueryVertex: Reading %lu input file(s) from local disk and %lu input file(s) over network\n", localInputChannels, iCC - localInputChannels); + + // + // Record list of channel inputs if list was created successfully + // + if(channelLocalCreated) + { + DrStr localChannelString(""); + bool firstChannelDone = false; + for(i = 0; i < iCC; ++i) + { + if (channelLocal[i]) + { + if(firstChannelDone) + { + localChannelString.AppendF(", %u", i); + } + else + { + localChannelString.SetF("%u", i); + firstChannelDone = true; + } + } + } + + fprintf(stdout, "HpcQueryVertex: Channels reading from local input files - {%s}\n", localChannelString.GetString()); + free(channelLocal); + } + + // + // Flush stdout before linq starts using it + // + fflush(stdout); + + // + // Create an output channel throttler + // + RChannelOpenThrottler* writeThrottler = NULL; + UInt32 maxWriters = m_vertexProgram->GetMaxOpenOutputChannelCount(); + if (maxWriters > 0) + { + writeThrottler = RChannelFactory::MakeOpenThrottler(maxWriters, + workQueue); + } + + // + // Set maximum parsing batch size to the total maximum over the number of + // input channels, but allow at least 4. 
+ // Done because: if there are lots of output channels, scale back on the + // batching so we don't use so much memory + // + UInt32 maxMarshalBatchSize = 1; + if (oCC > 0) + { + maxMarshalBatchSize = m_maxMarshalBatchSize / oCC; + } + if (maxMarshalBatchSize < 4) + { + maxMarshalBatchSize = 4; + } + + // + // Foreach input channel, open a reader and record any errors + // + for (i=0; iGetOutputChannels()[i]); + const char* uri = output->GetChannelURI(); + + // Get the compression mode parameter + DrLogD("Original URI: %s", uri); + TransformType mode = StripCompressionModeFromUri(const_cast(uri)); + DrLogD("Transform type: %d. New URI: %s", mode, uri); + + // + // Create marshaller for writing output to this channel + // + RChannelItemMarshalerRef marshaler; + m_vertexProgram->MakeOutputMarshaler(i, &marshaler); + if (m_vertexProgram->NoError()) + { + // + // If marshaler successfully created, open a writer to the output channel + // + RChannelFactory::OpenWriter(uri, output->GetChannelMetaData(), + marshaler, + oCC, writeThrottler, + maxMarshalBatchSize, + workQueue, &errorReporter, + &(wData[i])); + + // + // Set the transform (compression) type if writer was created (no error) + // + if(errorReporter.GetErrorCode() == DrError_OK) + { + wData[i]->GetWriter()->SetTransformType(mode); + } + } + else + { + // + // If marshaler not created successfully, report error + // + errorReporter. + ReportError(m_vertexProgram->GetErrorCode(), + m_vertexProgram->GetErrorMetaData()); + } + + + // + // Remember failure if anything goes wrong and update channel + // state with any errors. 
+ // + if (errorReporter.GetErrorCode() != DrError_OK) + { + failed = true; + } + + output->SetChannelState(errorReporter.GetErrorCode()); + output->SetChannelMetaData(errorReporter.GetErrorMetaData()); + } + + // + // If any failures, record vertex initialization failure + // + DrError err; + if (failed) + { + err = DryadError_VertexInitialization; + } + else + { + // + // Log start of vertex execution + // + DrLogI( "Starting vertex. Name %s", + (argumentCount > 0) ? argumentList[0].GetString() : "unknown"); + + // + // Move forward in process of invoking user vertex - blocking + // + RunProgram(initialState, workQueue, iCC, rData, oCC, wData); + + // + // Get any error code, record, it and mark it as completed if successful + // + err = m_vertexProgram->GetErrorCode(); + DrLogI( "Vertex complete. Name %s status %s", + (argumentCount > 0) ? argumentList[0].GetString() : "unknown", + DRERRORSTRING(err)); + if (err == DrError_OK) + { + err = DryadError_VertexCompleted; + } + } + + DrLogI( "Closing vertex channels. Name %s", + (argumentCount > 0) ? argumentList[0].GetString() : "unknown"); + + // + // Close any channel readers and writers that have been opened + // + delete [] rData; + delete [] wData; + + // + // Clean up read and write throttlers + // + if (readThrottler != NULL) + { + RChannelFactory::DiscardOpenThrottler(readThrottler); + readThrottler = NULL; + } + + if (writeThrottler != NULL) + { + RChannelFactory::DiscardOpenThrottler(writeThrottler); + writeThrottler = NULL; + } + + // + // Stop and clean up thread pool for vertex + // + DrLogI( "Stopping vertex work queue. Name %s", + (argumentCount > 0) ? argumentList[0].GetString() : "unknown"); + workQueue->Stop(); + delete workQueue; + + // + // Report the updated job status and exit + // + ReportStatus(initialState, false, false); + DrLogI( "Exiting vertex. Name %s", + (argumentCount > 0) ? 
argumentList[0].GetString() : "unknown"); + return err; +} + +// +// Report error if "ReOpenChannels" command ever received, as it is not supported +// +void DryadSimpleChannelVertexBase:: + ReOpenChannels(DVertexProcessStatus* /*newChannelStatus*/) +{ + // + // since a simple channel vertex never requests channels to be + // reopened, the job manager should never call here + // + LogAssert(false); +} + +// +// Do nothing to clean up the vertex. +// +DryadSimpleChannelVertex::~DryadSimpleChannelVertex() +{ +} diff --git a/DryadVertex/VertexHost/system/dprocess/src/dvertexcmdlinecontrol.cpp b/DryadVertex/VertexHost/system/dprocess/src/dvertexcmdlinecontrol.cpp new file mode 100644 index 0000000..6ddc049 --- /dev/null +++ b/DryadVertex/VertexHost/system/dprocess/src/dvertexcmdlinecontrol.cpp @@ -0,0 +1,1073 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include +#include +#include +#include +#include + +// +// todo: just have one of these - it's redefined a few places +// +static void EliminateArguments(int* pArgc, char* argv[], + int startingLocation, int numberToRemove) +{ + int argc = *pArgc; + + LogAssert(argc >= startingLocation+numberToRemove); + + int i; + for (i=startingLocation+numberToRemove; i 1); + UInt32 nChannels; + + // + // Parse number of channels and ensure it is valid + // + cse = DrStringToUInt32(argv[1], &nChannels); + LogAssert(cse == DrError_OK); + LogAssert((UInt32) argc > nChannels+1); + + // + // Update the number of channels of the specified type + // + if (isInput) + { + initialState->SetInputChannelCount(nChannels); + } + else + { + initialState->SetOutputChannelCount(nChannels); + } + + // + // Set the channel URI to explicitly what is provided in next + // nChannels arguments + // + UInt32 i; + for (i = 0; i < nChannels; ++i) + { + DryadChannelDescription* channel; + if (isInput) + { + channel = &(initialState->GetInputChannels()[i]); + } + else + { + channel = &(initialState->GetOutputChannels()[i]); + } + + channel->SetChannelState(DrError_OK); + channel->SetChannelURI(argv[i+2]); + } + + return nChannels+1; +} + +int DVertexCmdLineController:: + GetChannelOverride(int argc, char* argv[], + DVertexProcessStatus* initialState, + bool isInput) +{ + DrError cse; + LogAssert(argc > 2); + UInt32 whichChannel; + cse = DrStringToUInt32(argv[1], &whichChannel); + LogAssert(cse == DrError_OK); + + DryadChannelDescription* channel; + if (isInput) + { + LogAssert(whichChannel < initialState->GetInputChannelCount()); + channel = &(initialState->GetInputChannels()[whichChannel]); + } + else + { + LogAssert(whichChannel < initialState->GetOutputChannelCount()); + channel = &(initialState->GetOutputChannels()[whichChannel]); + } + + channel->SetChannelURI(argv[2]); + + return 2; +} + +int DVertexCmdLineController:: + GetTextOverride(int argc, char* argv[], + DVertexCommandBlock* 
startCommand) +{ + DVertexProcessStatus* initialState = startCommand->GetProcessStatus(); + + DrError err = DrError_OK; + LogAssert(argc > 1); + + FILE* fText = fopen(argv[1], "r"); + if (fText == NULL) + { + err = DrGetLastError(); + DrLogE("Failed to open dump text file '%s' --- %s", + argv[1], DRERRORSTRING(err)); + return 1; + } + + static const UInt32 bufSize = 32 * 1024; + char* buf = new char[bufSize]; + + char* s = fgets(buf, bufSize, fText); + if (s == NULL) + { + DrLogE("Couldn't read dump text input count"); + err = DrError_Fail; + } + else + { + UInt32 nInputs; + int nRead = sscanf(buf, "%p", &nInputs); + if (nRead == 1) + { + if (nInputs != initialState->GetInputChannelCount()) + { + DrLogE("Dump input count %u differs from override count %u", + initialState->GetInputChannelCount(), + nInputs); + err = DrError_Fail; + } + else + { + UInt32 i; + for (i=0; i 0 && + (buf[length-1] == '\n' || + buf[length-1] == '\r')) + { + --length; + buf[length] = '\0'; + } + DryadChannelDescription* channel; + channel = &(initialState->GetInputChannels()[i]); + channel->SetChannelURI(buf); + } + } + } + } + else + { + DrLogE("Malformed dump text output count %s", buf); + err = DrError_Fail; + } + } + + if (err != DrError_OK) + { + delete [] buf; + fclose(fText); + return 1; + } + + s = fgets(buf, bufSize, fText); + if (s == NULL) + { + DrLogE("Couldn't read dump text output count"); + err = DrError_Fail; + } + else + { + UInt32 nOutputs; + int nRead = sscanf(buf, "%p", &nOutputs); + if (nRead == 1) + { + if (nOutputs != initialState->GetOutputChannelCount()) + { + DrLogE("Dump output count %u differs from override count %u", + initialState->GetOutputChannelCount(), + nOutputs); + err = DrError_Fail; + } + else + { + UInt32 i; + for (i=0; i 0 && + (buf[length-1] == '\n' || + buf[length-1] == '\r')) + { + --length; + buf[length] = '\0'; + } + DryadChannelDescription* channel; + channel = &(initialState->GetOutputChannels()[i]); + channel->SetChannelURI(buf); + } + } + } + 
} + else + { + DrLogE("Malformed dump text output count %s", buf); + err = DrError_Fail; + } + } + + s = fgets(buf, bufSize, fText); + if (s == NULL) + { + DrLogE("Couldn't read dump text argument count"); + err = DrError_Fail; + } + else + { + UInt32 nArguments; + int nRead = sscanf(buf, "%u", &nArguments); + if (nRead == 1) + { + if (nArguments != startCommand->GetArgumentCount()) + { + DrLogE("Dump argument count %u differs from override count %u", + startCommand->GetArgumentCount(), + nArguments); + err = DrError_Fail; + } + else + { + UInt32 i; + for (i=0; i 0 && + (buf[length-1] == '\n' || + buf[length-1] == '\r')) + { + --length; + buf[length] = '\0'; + } + startCommand->SetArgument(i, buf); + } + } + } + } + else + { + DrLogE("Malformed dump text argument count %s", buf); + err = DrError_Fail; + } + } + + delete [] buf; + fclose(fText); + return 1; +} + +int DVertexCmdLineController:: + RestoreDumpedStartCommand(int argc, char* argv[], + DVertexCommandBlock* startCommand) +{ + LogAssert(argc > 1); + + DrError err = DrError_OK; + + HANDLE h = CreateFileA(argv[1], + GENERIC_READ, + FILE_SHARE_READ, + NULL, + OPEN_EXISTING, + FILE_ATTRIBUTE_NORMAL, + NULL); + if (h == INVALID_HANDLE_VALUE) + { + err = DrGetLastError(); + DrLogE("Failed to open dump file '%s' --- %s", + argv[1], DRERRORSTRING(err)); + } + else + { + LARGE_INTEGER filelen; + ::memset(&filelen, 0, sizeof(filelen)); + + BOOL fSuccess = GetFileSizeEx(h, &filelen); + if (!fSuccess) + { + err = DrGetLastError(); + DrLogE("Failed to get dump file size for '%s' --- %s", + argv[1], DRERRORSTRING(err)); + } + else + { + LogAssert(filelen.HighPart == 0); + DWORD nToRead = (DWORD)filelen.LowPart; + BYTE* pBuff = new BYTE[nToRead+1]; + LogAssert(pBuff != NULL); + BYTE *pRead = pBuff; + while (err == DrError_OK && nToRead > 0) + { + DWORD nRead = 0; + fSuccess = ReadFile(h, pRead, nToRead, &nRead, NULL); + if (fSuccess) + { + if (nRead == 0) + { + break; + } + pRead += nRead; + nToRead -= nRead; + } + else + 
{ + err = DrGetLastError(); + } + } + + if (err == DrError_OK) + { + Size_t bufSize = (Size_t) filelen.LowPart; + DrRef buf; + buf.Attach(new DrFixedMemoryBuffer(pBuff, + bufSize, bufSize)); + DrMemoryBufferReader reader(buf); + err = reader.ReadAggregate(DryadTag_VertexCommand, + startCommand, NULL); + if (err != DrError_OK) + { + DrLogE("Failed to deserialize dump file '%s' --- %s", + argv[1], DRERRORSTRING(err)); + } + } + else + { + DrLogE("Failed to read dump file '%s' --- %s", + argv[1], DRERRORSTRING(err)); + } + + delete [] pBuff; + } + } + + LogAssert(err == DrError_OK); + + return 1; +} + +// +// Parses remaining command line arguments +// +void DVertexCmdLineController:: + ParseExplicitCmdLine(int argc, char* argv[], + DVertexCommandBlock* startCommand) +{ + DVertexProcessStatus* initialState = startCommand->GetProcessStatus(); + bool restoredDump = false; + + // + // For each argument starting with '-' + // todo: ensure dump is first if it's used at all + // + while (argc > 1 && argv[1][0] == '-') + { + int consumed = 0; + + if (::_stricmp(argv[1], "-dump") == 0) + { + --argc; + ++argv; + consumed = RestoreDumpedStartCommand(argc, argv, startCommand); + restoredDump = true; + } + else if (::_stricmp(argv[1], "-overrideinput") == 0) + { + // + // Only allowed when restored dump + // + LogAssert(restoredDump == true); + --argc; + ++argv; + consumed = GetChannelOverride(argc, argv, + initialState, true); + } + else if (::_stricmp(argv[1], "-overrideoutput") == 0) + { + // + // Only allowed when restored dump + // + LogAssert(restoredDump == true); + --argc; + ++argv; + consumed = GetChannelOverride(argc, argv, + initialState, false); + } + else if (::_stricmp(argv[1], "-overridetext") == 0) + { + // + // Only allowed when restored dump + // + LogAssert(restoredDump == true); + --argc; + ++argv; + consumed = GetTextOverride(argc, argv, startCommand); + } + else if (::strcmp(argv[1], "-i") == 0) + { + // + // Only allowed when not restored dump + // + 
LogAssert(restoredDump == false); + --argc; + ++argv; + + // + // Get channel descriptions for input + // + consumed = GetChannelDescriptions(argc, argv, + initialState, true); + } + else if (::strcmp(argv[1], "-o") == 0) + { + // + // Only allowed when not restored dump + // + LogAssert(restoredDump == false); + --argc; + ++argv; + + // + // Get channel descriptions for output + // + consumed = GetChannelDescriptions(argc, argv, + initialState, false); + } + else + { + DrLogA("Unknown argument '%s'", argv[1]); + } + + // + // Update the argument count and pointer based on used parameters + // + LogAssert(argc > consumed); + argc -= consumed; + argv += consumed; + } + + if (restoredDump) + { + // + // If restoring from dump file, all arguments should have been handled + // + LogAssert(argc == 1); + } + else + { + // + // Otherwise, remaining arguments are arguments to start command + // copy remaining arguments into arg vector + // + startCommand->SetArgumentCount(argc-1); + DrStr64* argVector = startCommand->GetArgumentVector(); + int i; + for (i=1; iSetF("file://%s", channel); + } +} + +// +// Parse command line arguments +// +DrError DVertexCmdLineController:: + ParseImplicitCmdLine(int argc, char* argv[], + DVertexCommandBlock* startCommand) +{ + DrStr128 prefix; + bool expandInputs = false; + bool expandOutputs = false; + DVertexProcessStatus* initialState = startCommand->GetProcessStatus(); + + // + // todo: why 1024? + // + static UInt32 s_maxChannels = 1024;//2 + + DrStr128* inputChannels = new DrStr128[s_maxChannels]; + DrStr128* outputChannels = new DrStr128[s_maxChannels]; + + UInt32 numberOfInputChannels = 0; + UInt32 numberOfOutputChannels = 0; + + // + // foreach argument, check for well-known designators (-?, -i, -o, etc) + // and interpret according to designator rules + // + int i = 1; + while (i []\\\n" + " [-i input1 [-i input2] ...] 
[-o output1 [-o output2] ...]\n\n" + "Vertices that are registered in this executable:\n", + leafName); + + // + // Also, show and errors and usage messages for verticies + // + VertexFactoryRegistry::ShowAllVertexUsageMessages(stderr); + return DrError_Fail; + } + else if (::_stricmp(argv[i], "-i") == 0) + { + // + // If input specified, make sure parameter is valid, and then store input URI + // + if (expandInputs) + { + // + // If already specified --inputs, fail on duplicate method of specifying inputs + // + DrLogE("Cannot use -i together with --inputs"); + return DrError_Fail; + } + if (i+1 >= argc) + { + // + // If -i is the last parameter, there is an error, because no URI could be specified + // + DrLogE("-i flag with no input channel"); + return DrError_Fail; + } + if (numberOfInputChannels == s_maxChannels) + { + // + // If all input channels are already set, fail due to too many input channels + // + DrLogE("Too many input channels. Can't have more than %u inputs", s_maxChannels); + + return DrError_Fail; + } + + // + // Get the URI for the input channel + // + GetURI(&(inputChannels[numberOfInputChannels]), argv[i+1]); + ++numberOfInputChannels; + + // + // Remove the used arguments from the list + // + EliminateArguments(&argc, argv, i, 2); + } + else if (::_stricmp(argv[i], "-o") == 0) + { + // + // If input specified, make sure parameter is valid, and then store output URI + // + if (expandOutputs) + { + // + // If already specified --outputs, fail on duplicate method of specifying outputs + // + DrLogE("Cannot use -o together with --outputs"); + return DrError_Fail; + } + if (i+1 >= argc) + { + // + // If -o is the last parameter, there is an error, because no URI could be specified + // + DrLogE("-o flag with no output channel"); + return DrError_Fail; + } + if (numberOfOutputChannels == s_maxChannels) + { + // + // If all output channels are already set, fail due to too many output channels + // + DrLogE("Too many output channels. 
Can't have more than %u outputs", s_maxChannels); + return DrError_Fail; + } + + // + // Get the URI for the output channel + // + GetURI(&(outputChannels[numberOfOutputChannels]), argv[i+1]); + ++numberOfOutputChannels; + + // + // Remove the used arguments from the list + // + EliminateArguments(&argc, argv, i, 2); + } + else if (::_stricmp(argv[i], "--prefix") == 0) + { + // + // If I/O prefix specified, make sure parameters are valid and store the prefix URI + // + if (i+1 >= argc) + { + // + // If --prefix is the last parameter, there is an error, because no prefix string could be specified + // todo: if prefix specified multiple times, that should be logged at least, if not error + // + DrLogE("--prefix flag without prefix string"); + return DrError_Fail; + } + + // + // Get the URI for the prefix + // + GetURI(&prefix, argv[i+1]); + + // + // Remove the used arguments from the list + // + EliminateArguments(&argc, argv, i, 2); + } + else if (::_stricmp(argv[i], "--inputs") == 0) + { + // + // If --inputs specified, make sure parameters are valid and store number of channels + // + if (i+1 >= argc) + { + // + // If --inputs is the last parameter, there is an error, because no count can be specified + // + DrLogE("--inputs flag without a number"); + return DrError_Fail; + } + + if (numberOfInputChannels > 0) + { + // + // If input channels are already set, fail due to duplicate method for specifying inputs + // todo: Can also fail in this code path if --inputs set twice. Need to update error message. + // + DrLogE("Cannot use -i together with --inputs"); + return DrError_Fail; + } + + // + // Try to parse the next parameter as a count + // If this fails, numberOfInputChannels = 0 + // + numberOfInputChannels = atoi(argv[i + 1]); + if (numberOfInputChannels >= s_maxChannels) + { + // + // If more than max channels, fail + // + DrLogE("Too many input channels. 
Can't have more than %u inputs", s_maxChannels); + + return DrError_Fail; + } + + // + // Remember to expand the input list + // + expandInputs = true; + + // + // Remove the used arguments from the list + // + EliminateArguments(&argc, argv, i, 2); + } + else if (::_stricmp(argv[i], "--outputs") == 0) + { + // + // If --outputs specified, make sure parameters are valid and store number of channels + // + if (i+1 >= argc) + { + // + // If --outputs is the last parameter, there is an error, because no count can be specified + // + DrLogE("--outputs flag without a number"); + return DrError_Fail; + } + + if (numberOfOutputChannels > 0) + { + // + // If output channels are already set, fail due to duplicate method for specifying outputs + // todo: Can also fail in this code path if --outputs set twice. Need to update error message. + // + DrLogE("Cannot use -o together with --outputs"); + return DrError_Fail; + } + + // + // Try to parse the next parameter as a count + // If this fails, numberOfInputChannels = 0 + // + numberOfOutputChannels = atoi(argv[i + 1]); + if (numberOfOutputChannels >= s_maxChannels) + { + DrLogE("Too many output channels. 
Can't have more than %u outputs", s_maxChannels); + return DrError_Fail; + } + + // + // Remember to expand the output list + // + expandOutputs = true; + + // + // Remove the used arguments from the list + // + EliminateArguments(&argc, argv, i, 2); + } + else + { + // + // If parameter isn't well-known or removed while processing well-known + // arguments, just increment index and continue + // + ++i; + } + } + + UInt32 c; + DrStr128 channelURI; + + // + // If prefix not specified, and --inputs or --outputs used, fail + // + if (prefix.GetLength() == 0) + { + if (expandInputs) + { + DrLogE("--inputs flag without --prefix flag"); + return DrError_Fail; + } + if (expandOutputs) + { + DrLogE("--outputs flag without --prefix flag"); + return DrError_Fail; + } + } + + // + // Foreach input channel, set the channel URI + // + initialState->SetInputChannelCount(numberOfInputChannels); + for (c=0; cGetInputChannels()[c]); + channel->SetChannelState(DrError_OK); + if (!expandInputs) + { + // + // If using explicit names, just copy the temp name built in argument processing + // + channel->SetChannelURI(inputChannels[c]); + } + else + { + // + // If using implicit names, format the URI as [prefix]i[input index] + // + channelURI.SetF("%si%u", prefix.GetString(), c); + channel->SetChannelURI(channelURI); + } + } + + // + // Foreach output channel, set the channel URI + // + initialState->SetOutputChannelCount(numberOfOutputChannels); + for (c=0; cGetOutputChannels()[c]); + channel->SetChannelState(DrError_OK); + if (!expandOutputs) + { + // + // If using explicit names, just copy the temp name built in argument processing + // + channel->SetChannelURI(outputChannels[c]); + } + else + { + // + // If using implicit names, format the URI as [prefix]o[output index] + // + channelURI.SetF("%so%u", prefix.GetString(), c); + channel->SetChannelURI(channelURI); + } + } + + // + // Use all arguments not used here as arguments to start command + // + 
startCommand->SetArgumentCount(argc-1); + DrStr64* argVector = startCommand->GetArgumentVector(); + for (i=1; i startCommand; + startCommand.Attach(new DVertexCommandBlock()); + + { + // todo: make sure the vertex gets deleted before the command block + + // + // Create vertex reference + // + DrRef vertex; + vertex.Attach(new DryadSimpleChannelVertex()); + + + if (useExplicitCmdLine) + { + // + // Use explicit command line if --cmd or --cmdwait option used + // allows for dump restore or explicit -i/-o usage + // + ParseExplicitCmdLine(argc, argv, startCommand); + } + else + { + // + // Use implicit command line if no cmd option used + // allows for explicit -i/-o usage (different from above) or implicit --inputs/--outputs usage + // + DrError err = + ParseImplicitCmdLine(argc, argv, startCommand); + if (err != DrError_OK) + { + exitCode = 1; + } + } + + // + // If everything successful, initialize and run the vertex + // + if (exitCode == 0) + { + vertex->Initialize(this); + vertex->PrepareDryadVertex(startCommand->GetProcessStatus(), + factory, + startCommand->GetArgumentCount(), + startCommand->GetArgumentVector(), + startCommand->GetRawSerializedBlockLength(), + startCommand->GetRawSerializedBlock()); + + // + // Run the vertex - blocking + // + DrError vertexState = + vertex->RunDryadVertex(startCommand->GetProcessStatus(), + startCommand->GetArgumentCount(), + startCommand->GetArgumentVector()); + + // + // Log completion status + // + if (vertexState == DryadError_VertexCompleted) + { + DrLogI("Vertex exited without error"); + } + else + { + DrLogE("Vertex exited with error: %s", + DRERRORSTRING(vertexState)); + exitCode = 1; + } + } + } + + return exitCode; +} + +static void ShowChannelError(const char* channelType, + DrError err, + const char* channelURI, DryadMetaData* errorData) +{ + char* cosmosErrorString = DrGetErrorText(err); + if (errorData == NULL) + { + DrLogE("%s channel %s aborted %s, no error data", + channelType, channelURI, 
cosmosErrorString); + } + else + { + char* errorText = errorData->GetText(); + DrLogE("%s channel %s aborted %s: %s", + channelType, channelURI, cosmosErrorString, errorText); + delete [] errorText; + } + free(cosmosErrorString); +} + +void DVertexCmdLineController:: + AssimilateNewStatus(DVertexProcessStatus* status, + bool /*sendUpdate*/, bool /*notifyWaiters*/) +{ + UInt32 i; + + for (i=0; iGetInputChannelCount(); ++i) + { + DryadChannelDescription* c = &(status->GetInputChannels()[i]); + if (c->GetChannelState() != DrError_OK && + c->GetChannelState() != DryadError_ProcessingInterrupted && + c->GetChannelState() != DrError_EndOfStream) + { + LogAssert(c->GetChannelState() != DryadError_ChannelRestart); + ShowChannelError("Input", + c->GetChannelState(), + c->GetChannelURI(), + c->GetChannelMetaData()); + } + } + + for (i=0; iGetOutputChannelCount(); ++i) + { + DryadChannelDescription* c = &(status->GetOutputChannels()[i]); + if (c->GetChannelState() != DrError_OK && + c->GetChannelState() != DryadError_ProcessingInterrupted && + c->GetChannelState() != DrError_EndOfStream) + { + LogAssert(c->GetChannelState() != DryadError_ChannelRestart); + ShowChannelError("Output", + c->GetChannelState(), + c->GetChannelURI(), + c->GetChannelMetaData()); + } + } + + DryadMetaData* metaData = status->GetVertexMetaData(); + if (metaData != NULL) + { + const char* metaDataText = metaData->GetText(); + DrLogI("Vertex output status: %s", metaDataText); + delete [] metaDataText; + } +} diff --git a/DryadVertex/VertexHost/system/dprocess/src/dvertexcmdlinecontrol.h b/DryadVertex/VertexHost/system/dprocess/src/dvertexcmdlinecontrol.h new file mode 100644 index 0000000..b30cad9 --- /dev/null +++ b/DryadVertex/VertexHost/system/dprocess/src/dvertexcmdlinecontrol.h @@ -0,0 +1,53 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include +#include +#include +#include + +class DVertexCmdLineController : public DryadVertexController +{ +public: + UInt32 Run(int argc, char* argv[], + DryadVertexFactoryBase* factory, + bool useExplicitCmdLine); + void AssimilateNewStatus(DVertexProcessStatus* status, + bool sendUpdate, bool notifyWaiters); + +private: + void GetURI(DrStr* dst, const char* channel); + int GetChannelDescriptions(int argc, char* argv[], + DVertexProcessStatus* initialState, + bool isInput); + int GetChannelOverride(int argc, char* argv[], + DVertexProcessStatus* initialState, + bool isInput); + int GetTextOverride(int argc, char* argv[], + DVertexCommandBlock* startCommand); + int RestoreDumpedStartCommand(int argc, char* argv[], + DVertexCommandBlock* startCommand); + DrError ParseImplicitCmdLine(int argc, char* argv[], + DVertexCommandBlock* startCommand); + void ParseExplicitCmdLine(int argc, char* argv[], + DVertexCommandBlock* startCommand); +}; diff --git a/DryadVertex/VertexHost/system/dprocess/src/dvertexenvironment.cpp b/DryadVertex/VertexHost/system/dprocess/src/dvertexenvironment.cpp new file mode 100644 index 0000000..0878d46 --- /dev/null +++ b/DryadVertex/VertexHost/system/dprocess/src/dvertexenvironment.cpp @@ -0,0 +1,84 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include +#include + +// +// Constructor - set defaults +// no process or machine info +// quota, failure before abort, and failure threshold all set to maxint +// +DVertexEnvironment::DVertexEnvironment() +{ + m_process = NULL; + m_machine = NULL; + m_pnQuota = (UInt32)-1; + m_minNumberOfFailuresBeforeAbort = (UInt32)-1; + m_maxFailureThreshold = (UInt32)-1; +} + +// +// Destructor - do nothing +// todo: does m_process and m_machine need to be freed? +// +DVertexEnvironment::~DVertexEnvironment() +{ +} + +// +// Get the process info +// +DryadProcessIdentifier* DVertexEnvironment::GetPNProcess() +{ + return m_process; +} + +// +// Get the machine info +// +DryadMachineIdentifier* DVertexEnvironment::GetPNMachine() +{ + return m_machine; +} + +// +// Get the quota +// +UInt32 DVertexEnvironment::GetPNQuota() const +{ + return m_pnQuota; +} + +// +// Get the failure threshold +// +UInt32 DVertexEnvironment::GetMaxFailureThreshold() const +{ + return m_maxFailureThreshold; +} + +// +// Get the number of failures before abort +// +UInt32 DVertexEnvironment::GetMinNumberOfFailuresBeforeAbort() const +{ + return m_minNumberOfFailuresBeforeAbort; +} diff --git a/DryadVertex/VertexHost/system/dprocess/src/dvertexmain.cpp b/DryadVertex/VertexHost/system/dprocess/src/dvertexmain.cpp new file mode 100644 index 0000000..4ac7213 --- /dev/null +++ b/DryadVertex/VertexHost/system/dprocess/src/dvertexmain.cpp @@ -0,0 +1,146 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include +#include +#include +#include +#include +#include +#include + +#pragma managed + +// +// Eliminates arguments used already. leaves remaining arguments +// todo: should only have one copy of this and export it around. Already defined in vertexHost.cpp +// +static void EliminateArguments(int* pArgc, char* argv[], + int startingLocation, int numberToRemove) +{ + // + // Get total number of arguments. Check if there number to remove makes sense. 
+ // + int argc = *pArgc; + LogAssert(argc >= startingLocation+numberToRemove); + + int i; + for (i=startingLocation+numberToRemove; i 1 && + ::strcmp(argv[1], "--cmd") == 0) + { + useExplicitCmdLine = true; + EliminateArguments(&argc, argv, 1, 1); + } + else if (argc > 1 && + ::strcmp(argv[1], "--cmdwait") == 0) + { + useExplicitCmdLine = true; + waitForInput = true; + EliminateArguments(&argc, argv, 1, 1); + } + else if (argc > 1 && + ::strcmp(argv[1], "--startfrompn") == 0) + { + useCmdLineController = false; + EliminateArguments(&argc, argv, 1, 1); + } + + UInt32 exitCode; + + // + // If --cmdwait was used, wait for user to press a key before continuing + // + if (waitForInput) + { + fprintf(stdout, "Press any key\n"); + fflush(stdout); + char buf[100]; + char* s = fgets(buf, 99, stdin); + LogAssert(s != NULL); + } + + // + // Run controller + // + { + if (useCmdLineController) + { + // + // If --cmd or --cmdwait supplied, create a cmd line controller and run + // + DVertexCmdLineController controller; + exitCode = controller.Run(argc, argv, + factory, useExplicitCmdLine); + } + else + { + // + // If --startfrompn used, create a Vertex Service controller and run + // + DVertexXComputePnControllerOuter controller; + exitCode = controller.Run(argc, argv); + } + } + + // + // After controller terminates, wait for user input again if --cmdwait + // + if (waitForInput) + { + fprintf(stdout, "Press any key\n"); fflush(stdout); + char buf[100]; + char* s = fgets(buf, 99, stdin); + LogAssert(s != NULL); + } + + return exitCode; +} diff --git a/DryadVertex/VertexHost/system/dprocess/src/dvertexpncontrol.cpp b/DryadVertex/VertexHost/system/dprocess/src/dvertexpncontrol.cpp new file mode 100644 index 0000000..741402e --- /dev/null +++ b/DryadVertex/VertexHost/system/dprocess/src/dvertexpncontrol.cpp @@ -0,0 +1,1192 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "dvertexpncontrol.h" +#include +#include +#include +#include +#include + +#pragma managed + +// +// Constructor for controller associated with single vertex process +// +// +DVertexPnController::DVertexPnController(DVertexPnControllerOuter* parent, + UInt32 vertexId, UInt32 vertexVersion) +{ + m_currentCommandVersion = 0; + m_activeVertex = false; + m_waitingForTermination = false; + + m_parent = parent; + m_vertexId = vertexId; + m_vertexVersion = vertexVersion; + + // + // Create a new vertex + // + m_vertex.Attach(new DryadSimpleChannelVertex()); + DrLogI( "PN controller creating simple channel vertex. Vertex %u.%u", m_vertexId, m_vertexVersion); + + // + // Initialize the vertex and current vertex status + // + m_vertex->Initialize(this); + m_currentStatus.Attach(new DVertexStatus()); + m_currentStatus->SetVertexState(DrError_OK); + + DrLogI( "PN controller managing vertex. Vertex %u.%u", m_vertexId, m_vertexVersion); +} + +DVertexPnController::~DVertexPnController() +{ +} + +// +// Send updated status to GM +// +void DVertexPnController::SendStatus(UInt32 exitOnCompletion, + bool notifyWaiters) +{ + DrRef request; + + // + // Enter critical section + // + { + AutoCriticalSection acs(&m_baseDR); + + // + // If waiting for termination, don't send more status updates. + // Must report exiting as next status. 
+ // todo: make sure m_waitingForTermination is set false when done + // + if (m_waitingForTermination) + { + DrLogI( "Skipping status update"); + return; + } + else if (exitOnCompletion != DrExitCode_StillActive) + { + // + // If not waiting for termination and still running, wait for termination + // + m_waitingForTermination = true; + } + + // + // Notify GM that we're done + // + request.Attach(MakeSetStatusRequest(exitOnCompletion, false, + notifyWaiters)); + m_currentStatus->StoreInRequestMessage(request); + } + + // + // Send status update + // + DrLogI( "Sending status update. Vertex %u.%u notifyWaiters %s", + m_currentStatus->GetProcessStatus()->GetVertexId(), + m_currentStatus->GetProcessStatus()->GetVertexInstanceVersion(), + (notifyWaiters) ? "true" : "false"); + SendSetStatusRequest(request); +} + +void DVertexPnController::SendAssertStatus(const char* assertString) +{ + DrRef request; + + if (m_baseDR.TryEnter()) + { + /* we got the lock, so we can keep the current status and + just add info that there's an assert failure. 
*/ + m_waitingForTermination = true; + + DVertexProcessStatus* pStatus = + m_currentStatus->GetProcessStatus(); + + DryadMetaDataRef metaData = pStatus->GetVertexMetaData(); + if (metaData.Ptr() == NULL) + { + DryadMetaData::Create(&metaData); + pStatus->SetVertexMetaData(metaData); + } + + metaData->AppendString(Prop_Dryad_AssertFailure, + assertString, true); + + m_currentStatus->SetVertexState(DryadError_AssertFailure); + + request.Attach(MakeSetStatusRequest(DrExitCode_StillActive, + true, true)); + if (request != NULL) + { + m_currentStatus->StoreInRequestMessage(request); + } + } + else + { + /* we don't want to risk deadlock, so just send the assert + failure as a "raw" status message */ + DrRef rawStatus; + rawStatus.Attach(new DVertexStatus()); + + DVertexProcessStatus* pStatus = rawStatus->GetProcessStatus(); + + DryadMetaDataRef metaData; + DryadMetaData::Create(&metaData); + pStatus->SetVertexMetaData(metaData); + + metaData->AppendString(Prop_Dryad_AssertFailure, + assertString, true); + + rawStatus->SetVertexState(DryadError_AssertFailure); + + request.Attach(MakeSetStatusRequest(DrExitCode_StillActive, + true, true)); + if (request != NULL) + { + rawStatus->StoreInRequestMessage(request); + } + } + + if (request != NULL) + { + DrLogI( "Sending notification. AssertString=%s", assertString); + + SendSetStatusRequest(request); + + DrLogI( "Sent notification. 
AssertString=%s", assertString); + } +} + +void DVertexPnController:: + AssimilateChannelStatus(DryadChannelDescription* dst, + DryadChannelDescription* src) +{ + dst->SetChannelState(src->GetChannelState()); + dst->SetChannelMetaData(src->GetChannelMetaData()); + dst->SetChannelProcessedLength(src->GetChannelProcessedLength()); + dst->SetChannelTotalLength(src->GetChannelTotalLength()); +} + +void DVertexPnController::AssimilateNewStatus(DVertexProcessStatus* status, + bool sendUpdate, + bool notifyWaiters) +{ + { + AutoCriticalSection acs(&m_baseDR); + + UInt32 i; + + LogAssert(m_activeVertex == true); + + DVertexProcessStatus* currentPStatus = + m_currentStatus->GetProcessStatus(); + LogAssert(status->GetVertexId() == + currentPStatus->GetVertexId()); + LogAssert(status->GetVertexInstanceVersion() == + currentPStatus->GetVertexInstanceVersion()); + + currentPStatus->SetVertexMetaData(status->GetVertexMetaData()); + + LogAssert(status->GetInputChannelCount() == + currentPStatus->GetInputChannelCount()); + for (i=0; iGetInputChannelCount(); ++i) + { + AssimilateChannelStatus(&(currentPStatus->GetInputChannels()[i]), + &(status->GetInputChannels()[i])); + } + + LogAssert(status->GetOutputChannelCount() == + currentPStatus->GetOutputChannelCount()); + for (i=0; iGetOutputChannelCount(); ++i) + { + AssimilateChannelStatus(&(currentPStatus->GetOutputChannels()[i]), + &(status->GetOutputChannels()[i])); + } + } + + if (sendUpdate) + { + SendStatus(DrExitCode_StillActive, notifyWaiters); + } +} + +// +// Notify GM of completed vertex +// +void DVertexPnController::Terminate(DrError vertexState, + UInt32 exitCode) +{ + LogAssert(vertexState != DryadError_VertexRunning); + + if (vertexState == DrError_OK) + { + LogAssert(exitCode == DrExitCode_StillActive); + } + + // + // take Critical section to update vertex state + // + { + AutoCriticalSection acs(&m_baseDR); + + m_currentStatus->SetVertexState(vertexState); + } + + // + // Log vertex termination + // + DrLogI( 
"Terminating vertex. Vertex %u.%u exitCode %s vertexState %s", + m_currentStatus->GetProcessStatus()->GetVertexId(), + m_currentStatus->GetProcessStatus()->GetVertexInstanceVersion(), + DREXITCODESTRING(exitCode), DRERRORSTRING(vertexState)); + + // + // Send status to GM + // + SendStatus(exitCode, true); +} + +// +// Runs in new thread created and started by DVertexPnController::Start +// Executes vertex and terminates to report final status +// +unsigned DVertexPnController::ThreadFunc(void* arg) +{ + // + // Get Controller Reference, and then clean up unnecessary thread block + // + DVertexPnControllerThreadBlock* threadBlock = + (DVertexPnControllerThreadBlock *) arg; + DVertexPnController* self = threadBlock->m_parent; + DrRef startCommand = threadBlock->m_commandBlock; + + LogAssert(self != NULL,"Received NULL DVertexPnController pointer."); + + if(startCommand == NULL) + { + DrLogE("Received NULL DVertexCommandBlock pointer"); + self->Terminate(DrError_Fail, DrError_Fail); + return DrExitCode_Fail; + } + + + delete threadBlock; + threadBlock = NULL; + + // + // Execute dryad vertex - blocking + // + DrError vertexState = + self->m_vertex->RunDryadVertex(startCommand->GetProcessStatus(), + startCommand->GetArgumentCount(), + startCommand->GetArgumentVector()); + + // + // If vertex completed, then success, otherwise failure + // + UInt32 exitCode; + if (vertexState == DryadError_VertexCompleted) + { + exitCode = DrExitCode_OK; + } + else + { + exitCode = DrExitCode_Fail; + } + + // + // Enter critical section to turn off active vertex + // + { + AutoCriticalSection acs(&(self->m_baseDR)); + LogAssert(self->m_activeVertex == true); + self->m_activeVertex = false; + } + + // + // Notify GM of completed vertex + // + DrLogD( "About to terminate"); + self->Terminate(vertexState, exitCode); + + return DrExitCode_OK; +} + +// +// Create files which contain information used to restart the upcoming vertex command +// Used for post-mortem debugging. 
+//
+// commandBlock - command to serialize. Writes, into the current directory:
+//   vertex-<id>-<version>-rerun-data.dat          serialized command block
+//   vertex-<id>-<version>-rerun-originalInfo.txt  channel URIs and arguments
+//   vertex-<id>-<version>-rerun.cmd               command line to re-run
+// Failures are logged and the function returns; nothing is reported to the caller.
+//
+void DVertexPnController::DumpRestartCommand(DVertexCommandBlock* commandBlock)
+{
+    DrError err;
+
+    //
+    // Create temporary buffer
+    //
+    DrRef<DrSimpleHeapBuffer> buf;
+    buf.Attach(new DrSimpleHeapBuffer());
+
+    //
+    // Write command block into buffer
+    //
+    {
+        DrMemoryBufferWriter writer(buf);
+        err = commandBlock->Serialize(&writer);
+    }
+
+    //
+    // If write fails, log failure and return
+    //
+    if (err != DrError_OK)
+    {
+        DrLogE("Can't serialize command block for restart --- %s",
+            DRERRORSTRING(err));
+        return;
+    }
+
+    //
+    // Get data reference and byte count
+    //
+    const void* serializedData;
+    Size_t availableToRead;
+    serializedData = buf->GetReadAddress(0, &availableToRead);
+    LogAssert(availableToRead >= buf->GetAvailableSize());
+
+    //
+    // Get the process information
+    //
+    DVertexProcessStatus* ps = commandBlock->GetProcessStatus();
+
+    //
+    // Build file for data required for rerun, open it
+    //
+    DrStr64 restartBlockName;
+    restartBlockName.SetF("vertex-%u-%u-rerun-data.dat",
+                          ps->GetVertexId(), ps->GetVertexInstanceVersion());
+    FILE* fData = fopen(restartBlockName, "wb");
+    if (fData == NULL)
+    {
+        //
+        // If failed to open file, log and return
+        //
+        err = DrGetLastError();
+        DrLogE(
+            "Can't open re-run command block file '%s' --- %s",
+            restartBlockName.GetString(), DRERRORSTRING(err));
+        return;
+    }
+
+    //
+    // Build file for original information required for rerun, open it
+    //
+    DrStr64 originalInfoName;
+    originalInfoName.SetF("vertex-%u-%u-rerun-originalInfo.txt",
+                          ps->GetVertexId(), ps->GetVertexInstanceVersion());
+    FILE* fOriginalText = fopen(originalInfoName, "w");
+    if (fOriginalText == NULL)
+    {
+        //
+        // If failed to open file, log and return
+        //
+        err = DrGetLastError();
+        DrLogE(
+            "Can't open re-run original info file '%s' --- %s",
+            originalInfoName.GetString(), DRERRORSTRING(err));
+
+        //
+        // Close data file
+        //
+        fclose(fData);
+        return;
+    }
+
+    //
+    // Build file for rerun command line, open it
+    //
+    DrStr64 originalRestartCommand;
+    originalRestartCommand.SetF("vertex-%u-%u-rerun.cmd",
+                                ps->GetVertexId(),
+                                ps->GetVertexInstanceVersion());
+    FILE* fOriginalRestart = fopen(originalRestartCommand, "w");
+    if (fOriginalRestart == NULL)
+    {
+        //
+        // If failed to open file, log and return
+        //
+        err = DrGetLastError();
+        DrLogE(
+            "Can't open re-run original command file '%s' --- %s",
+            originalRestartCommand.GetString(), DRERRORSTRING(err));
+
+        //
+        // Close data and original text files
+        //
+        fclose(fData);
+        fclose(fOriginalText);
+        return;
+    }
+
+    //
+    // Open file for local info
+    //
+
+    /* BUG 16322: Do not create this for SP3, since it is currently broken.
+                  Consider fixing for v4.
+    DrStr64 localInfoName;
+    localInfoName.SetF("vertex-%u-%u-rerun-localInfo.txt",
+                       ps->GetVertexId(), ps->GetVertexInstanceVersion());
+    FILE* fLocalText = fopen(localInfoName, "w");
+    if (fLocalText == NULL)
+    {
+        //
+        // If failed to open file, log and return
+        //
+        err = DrGetLastError();
+        DrLogE(
+            "Can't open re-run local info file '%s' --- %s",
+            localInfoName.GetString(), DRERRORSTRING(err));
+
+        //
+        // Close data, cmd, and original text files
+        //
+        fclose(fData);
+        fclose(fOriginalText);
+        fclose(fOriginalRestart);
+        return;
+    }
+    */
+
+
+    //
+    // Open file for rerun with local inputs
+    //
+
+    /* BUG 16322: Do not create this for SP3, since it is currently broken.
+                  Consider fixing for v4.
+    DrStr64 localRestartCommand;
+    localRestartCommand.SetF("vertex-%u-%u-rerun-local-inputs.cmd",
+                             ps->GetVertexId(),
+                             ps->GetVertexInstanceVersion());
+    FILE* fLocalRestart = fopen(localRestartCommand, "w");
+    if (fLocalRestart == NULL)
+    {
+        //
+        // If failed to open file, log and return
+        //
+        err = DrGetLastError();
+        DrLogE(
+            "Can't open re-run local command file '%s' --- %s",
+            localRestartCommand.GetString(), DRERRORSTRING(err));
+
+        //
+        // Close data, cmd, original, and local text files
+        //
+        fclose(fData);
+        fclose(fOriginalText);
+        fclose(fOriginalRestart);
+        fclose(fLocalText);
+        return;
+    }
+    */
+
+    //
+    // Open file for fetching inputs
+    //
+    /* BUG 16322: Do not create this for SP3, since it is currently broken.
+                  Consider fixing for v4.
+    DrStr64 copyCommand;
+    copyCommand.SetF("vertex-%u-%u-rerun-fetch-inputs.cmd",
+                     ps->GetVertexId(), ps->GetVertexInstanceVersion());
+    FILE* fCopyCommand = fopen(copyCommand, "w");
+    if (fCopyCommand == NULL)
+    {
+        //
+        // If failed to open file, log and return
+        //
+        err = DrGetLastError();
+        DrLogE(
+            "Can't open re-run copy command file '%s' --- %s",
+            localRestartCommand.GetString(), DRERRORSTRING(err));
+
+        //
+        // Close data, original and localcmd, and original and local text files
+        //
+        fclose(fData);
+        fclose(fOriginalText);
+        fclose(fOriginalRestart);
+        fclose(fLocalText);
+        fclose(fLocalRestart);
+        return;
+    }
+    */
+
+    //
+    // Write out data to data file, then close it.
+    // This is the only place fData is closed on the success path.
+    //
+    size_t written = fwrite(serializedData, 1, buf->GetAvailableSize(), fData);
+    fclose(fData);
+    fData = NULL;
+    if (written != buf->GetAvailableSize())
+    {
+        //
+        // If failed to write all the data, log failure
+        //
+        err = DrGetLastError();
+        DrLogE(
+            "Failed to write re-run command block file '%s': only %Iu of %Iu bytes written --- %s",
+            restartBlockName.GetString(),
+            written, (size_t) (buf->GetAvailableSize()),
+            DRERRORSTRING(err));
+    }
+
+    //
+    // Write original restart command
+    //
+    fprintf(fOriginalRestart,
+            "%s --vertex --cmd -dump %s -overridetext %s\n",
+            m_parent->GetRunningExePathName(),
+            restartBlockName.GetString(),
+            originalInfoName.GetString());
+
+    //
+    // Write local restart command
+    //
+    /* BUG 16322: Do not create this for SP3, since it is currently broken.
+                  Consider fixing for v4.
+    fprintf(fLocalRestart,
+            "%s --vertex --cmd -dump %s -overridetext %s\n",
+            m_parent->GetRunningExePathName(),
+            restartBlockName.GetString(),
+            localInfoName.GetString());
+    */
+
+    //
+    // Record number of input files
+    //
+    fprintf(fOriginalText, "%u # input files\n", ps->GetInputChannelCount());
+
+    /* BUG 16322: Do not create this for SP3, since it is currently broken.
+                  Consider fixing for v4.
+    fprintf(fLocalText, "%u # input files\n", ps->GetInputChannelCount());
+    */
+
+    //
+    // Get the input channels and foreach channel, add copy command to copy script
+    //
+    DryadInputChannelDescription* inputs = ps->GetInputChannels();
+    for (UInt32 i=0; i<ps->GetInputChannelCount(); ++i)
+    {
+        const char* uri = inputs[i].GetChannelURI();
+
+        /* BUG 16322: Do not create this for SP3, since it is currently broken.
+                      Consider fixing for v4.
+        if (::_strnicmp(uri, "file://", 7) == 0)
+        {
+            //
+            // If reading from file, copy command doesn't want "file://" prefix
+            // todo: remove reference to cosmos
+            //
+            fprintf(fCopyCommand, "cosmos.exe copy %s v%u.%u-i%u\n",
+                    uri+7,
+                    ps->GetVertexId(), ps->GetVertexInstanceVersion(), i);
+        }
+        else if (::_strnicmp(uri, "cosmos://", 9) == 0)
+        {
+            //
+            // If reading from cosmos path, copy directly
+            // todo: remove cosmos code
+            //
+            fprintf(fCopyCommand, "cosmos.exe copy %s v%u.%u-i%u\n",
+                    uri,
+                    ps->GetVertexId(), ps->GetVertexInstanceVersion(), i);
+        }
+        else
+        {
+            //
+            // Otherwise, unable to copy
+            //
+            fprintf(fCopyCommand, "echo can't copy URI %s to v%u.%u-i%u\n",
+                    uri,
+                    ps->GetVertexId(), ps->GetVertexInstanceVersion(), i);
+        }
+        */
+
+        //
+        // Add reference to this URI to original and relative reference to local
+        //
+        fprintf(fOriginalText, "%s\n", uri);
+
+        /* BUG 16322: Do not create this for SP3, since it is currently broken.
+                      Consider fixing for v4.
+        fprintf(fLocalText, "file://v%u.%u-i%u\n",
+                ps->GetVertexId(), ps->GetVertexInstanceVersion(), i);
+        */
+    }
+
+    //
+    // Record number of output files
+    //
+    fprintf(fOriginalText, "%u # output files\n", ps->GetOutputChannelCount());
+
+    /* BUG 16322: Do not create this for SP3, since it is currently broken.
+                  Consider fixing for v4.
+    fprintf(fLocalText, "%u # output files\n", ps->GetOutputChannelCount());
+    */
+
+    //
+    // Get the output channels and record each one
+    //
+    DryadOutputChannelDescription* outputs = ps->GetOutputChannels();
+    for (UInt32 i=0; i<ps->GetOutputChannelCount(); ++i)
+    {
+        const char* uri = outputs[i].GetChannelURI();
+
+        //
+        // Check if uri is writing to DSC partition.
+        // If it is, redirect to local temp file to avoid writing to sealed stream
+        //
+        DrStr uriMod("");
+        if(ConcreteRChannel::IsDscPartition(uri))
+        {
+            uriMod.AppendF("file://hpcdscpt_redirect_%d.dtf", i);
+            uri = uriMod.GetString();
+        }
+
+        fprintf(fOriginalText, "%s\n", uri);
+
+        /* BUG 16322: Do not create this for SP3, since it is currently broken.
+                      Consider fixing for v4.
+        fprintf(fLocalText, "%s\n", uri);
+        */
+    }
+
+    //
+    // Record number of arguments
+    //
+    fprintf(fOriginalText, "%u # arguments\n",
+            commandBlock->GetArgumentCount());
+
+    /* BUG 16322: Do not create this for SP3, since it is currently broken.
+                  Consider fixing for v4.
+    fprintf(fLocalText, "%u # arguments\n", commandBlock->GetArgumentCount());
+    */
+
+    //
+    // Foreach argument, record its value
+    //
+    for (UInt32 i=0; i<commandBlock->GetArgumentCount(); ++i)
+    {
+        DrStr64 arg = commandBlock->GetArgumentVector()[i];
+        fprintf(fOriginalText, "%s\n", arg.GetString());
+
+        /* BUG 16322: Do not create this for SP3, since it is currently broken.
+                      Consider fixing for v4.
+        fprintf(fLocalText, "%s\n", arg.GetString());
+        */
+    }
+
+    //
+    // Close the remaining files. fData was already closed right after the
+    // fwrite above; closing it a second time would be undefined behavior.
+    //
+    fclose(fOriginalText);
+    fclose(fOriginalRestart);
+
+    /* BUG 16322: Do not create this for SP3, since it is currently broken.
+                  Consider fixing for v4.
+    fclose(fLocalText);
+    fclose(fLocalRestart);
+    fclose(fCopyCommand);
+    */
+}
+
+//
+// Prepare to start user vertex code. Dump restart command, verify I/O channel health
+// and spin up new thread to handle vertex execution
+// Called from ActOnCommand when command is start.
+// +void DVertexPnController::Start(DVertexCommandBlock* commandBlock) +{ + // + // Associate controller and command block + // + DVertexPnControllerThreadBlock* threadBlock = + new DVertexPnControllerThreadBlock(); + + threadBlock->m_parent = this; + threadBlock->m_commandBlock = commandBlock; + + // + // Write out all files needed to restart this vertex later + // + DumpRestartCommand(commandBlock); + + // + // Enter critical section for updating vertex status and getting channels' state + // + { + AutoCriticalSection acs(&m_baseDR); + + // + // Update vertex state to "running" after verifying vertex ok + // + LogAssert(m_currentStatus->GetVertexState() == DrError_OK); + m_currentStatus->SetVertexState(DryadError_VertexRunning); + + // + // Get current vertex process status ref and update with command parameters + // + DVertexProcessStatus* currentPStatus = m_currentStatus->GetProcessStatus(); + currentPStatus->CopyFrom(commandBlock->GetProcessStatus(), false); + + // + // Get input channels and foreach make sure state is ok + // + DryadInputChannelDescription* inputs = currentPStatus->GetInputChannels(); + UInt32 i; + for (i=0; iGetInputChannelCount(); ++i) + { + LogAssert(inputs[i].GetChannelState() == DrError_OK); + } + + + // + // Get output channels and foreach make sure state is ok + // + DryadOutputChannelDescription* outputs = currentPStatus->GetOutputChannels(); + for (i=0; iGetOutputChannelCount(); ++i) + { + LogAssert(outputs[i].GetChannelState() == DrError_OK); + } + + // + // Set vertex to active + // + LogAssert(m_activeVertex == false); + m_activeVertex = true; + } + + // + // Update vertex service with current, active, status + // + SendStatus(DrExitCode_StillActive, true); + + // + // Prepare the vertex for execution + // + m_vertex->PrepareDryadVertex(commandBlock->GetProcessStatus(), + NULL, + commandBlock->GetArgumentCount(), + commandBlock->GetArgumentVector(), + commandBlock->GetRawSerializedBlockLength(), + 
commandBlock->GetRawSerializedBlock()); + + // + // Start thread to execute vertex code + // + unsigned threadAddr; + HANDLE threadHandle = + (HANDLE) ::_beginthreadex(NULL, + 0, + DVertexPnController::ThreadFunc, + threadBlock, + 0, + &threadAddr); + + // + // Verify thread creation and exit + // + LogAssert(threadHandle != 0); + BOOL bRet = ::CloseHandle(threadHandle); + LogAssert(bRet != 0); +} + +void DVertexPnController::ReOpenChannels(DVertexCommandBlock* reOpenCommand) +{ + DVertexProcessStatus* newStatus = reOpenCommand->GetProcessStatus(); + + { + AutoCriticalSection acs(&m_baseDR); + + DVertexProcessStatus* currentPStatus = + m_currentStatus->GetProcessStatus(); + + LogAssert(newStatus->GetVertexId() == + currentPStatus->GetVertexId()); + LogAssert(newStatus->GetVertexInstanceVersion() == + currentPStatus->GetVertexInstanceVersion()); + + LogAssert(newStatus->GetInputChannelCount() == + currentPStatus->GetInputChannelCount()); + LogAssert(newStatus->GetOutputChannelCount() == + currentPStatus->GetOutputChannelCount()); + } + + m_vertex->ReOpenChannels(newStatus); +} + +// +// Interpret command received from GM +// +DrError DVertexPnController::ActOnCommand(DVertexCommandBlock* commandBlock) +{ + // + // Break into debugger if command asks for it + // + if (commandBlock->GetDebugBreak()) + { + ::DebugBreak(); + } + + DVertexCommand command = commandBlock->GetVertexCommand(); + DrError err = DrError_OK; + + // + // Critical section to issue commands + // + { + AutoCriticalSection acs(&m_baseDR); + + switch (command) + { + case DVertexCommand_Start: + // + // Command is to start a new vertex + // + if (m_currentStatus->GetVertexState() != DrError_OK) + { + // + // If vertex is in an error state, can't start it + // + err = DryadError_InvalidCommand; + } + else + { + // + // If vertex ok, then start with command from GM + // this is non-blocking and will return after creating new thread + // + DrLogI("Start command received."); + Start(commandBlock); + } + 
break; + + case DVertexCommand_ReOpenChannels: + // + // If reopen channels command, then reopen channels + // todo: find out if this is ever used + // + DrLogI("Reopen Channels command received."); + ReOpenChannels(commandBlock); + break; + + case DVertexCommand_Terminate: + // + // Terminate command + // todo: find out from Victor if this is ever used + // + DrError currentState; + currentState = m_currentStatus->GetVertexState(); + DrLogI("Terminate command received."); + if (m_waitingForTermination == false) + { + // + // If not waiting for termination already, terminate. + // + DrLogD( "About to terminate"); + if (currentState == DryadError_VertexCompleted) + { + // + // If already done, report cause of finish + // + Terminate(m_currentStatus->GetVertexState(), + DrExitCode_OK); + } + else + { + // + // If not yet done, terminate because of this + // + Terminate(DryadError_VertexReceivedTermination, + DrExitCode_Killed); + } + } + + // + // if waiting for termination, we can just fall through to the command loop again + // here: the process will exit as soon as the sendstatus completes + // + break; + + default: + // + // Invalid command + // + DrLogE("Unknown command received."); + err = DryadError_InvalidCommand; + break; + } + } + + return err; +} + +// +// Starts thread for vertex command loop +// called from DVertexPnControllerOuter.run +// +void DVertexPnController::LaunchCommandLoop() +{ + unsigned threadAddr; + + // + // Create new thread executing the DVertexPnController::CommandLoopStatic function + // + HANDLE threadHandle = + (HANDLE) ::_beginthreadex(NULL, + 0, + DVertexPnController::CommandLoopStatic, + this, + 0, + &threadAddr); + + // todo: handle errors in thread creation better + LogAssert(threadHandle != 0); + + // + // Close handle to this thread + // + BOOL bRet = ::CloseHandle(threadHandle); + LogAssert(bRet != 0); +} + +// +// Started in new thread per vertex to process commands +// +unsigned DVertexPnController::CommandLoopStatic(void* 
arg) +{ + // + // Get provided controller + // + DVertexPnController* self = (DVertexPnController *) arg; + + // + // Run command loop on controller + // + return self->CommandLoop(); +} + +// +// Constructor. Initializes without environment or controller list +// +DVertexPnControllerOuter::DVertexPnControllerOuter() +{ + m_controllerArray = NULL; + m_environment = NULL; +} + +void DVertexPnControllerOuter::AssertCallback(void* cookie, const char* assertString) +{ + DVertexPnControllerOuter* self = (DVertexPnControllerOuter *) cookie; + + if (assertString[0] == 'a' || assertString[0] == 'A') + { + self->SendAssertStatus(assertString); + } +} + +void DVertexPnControllerOuter::SendAssertStatus(const char* assertString) +{ + LONG postIncrement = ::InterlockedIncrement(&m_assertCounter); + + if (postIncrement == 1) + { + /* this is the first assert we've seen, so try to do something + useful with it */ + + UInt32 i; + for (i=0; iSendAssertStatus(assertString); + } + + DrLogI( "Sleeping"); + + /* wait to give the PN a chance to receive the messages if + they are in a send queue */ + ::Sleep(2000); + + DrLogI( "Sleeping done"); + + /* once we return the assert failure will be processed and the + program will terminate */ + } +} + +// +// Decrement number of active verticies and close process if done or failed +// +void DVertexPnControllerOuter::VertexExiting(int exitCode) +{ + // + // Take critical section to update the number of active verticies + // and exit the process if all verticies are complete or a failure occurred + // + { + AutoCriticalSection acs(&m_baseDR); + + LogAssert(m_activeVertexCount > 0); + --m_activeVertexCount; + if (m_activeVertexCount == 0 || exitCode != 0) + { + DrExitProcess(exitCode); + } + } +} + +// +// Return current exe path +// +const char* DVertexPnControllerOuter::GetRunningExePathName() +{ + return m_exePathName; +} + +// +// Return current environment +// +DVertexEnvironment* DVertexPnControllerOuter::GetEnvironment() +{ + return 
m_environment; +} + +// +// called from dvertexmain.cpp +// +UInt32 DVertexPnControllerOuter::Run(int argc, char* argv[]) +{ + // + // Create virtex environment and initialize it from current environment + // + m_environment = MakeEnvironment(); + DrError cse = m_environment->InitializeFromEnvironment(); + if (cse != DrError_OK) + { + DrLogE("Couldn't initialise environment"); + return 1; + } + + // + // Make sure there are at least two arguments + // + if (argc < 2) + { + DrLogE("No vertex arguments specified to the PN controller"); + return 1; + } + + // + // Get path and num verticies + // + m_exePathName = argv[0]; + m_numberOfVertices = atoi(argv[1]); + + // + // Fail if number of verticies cannot be converted + // todo: also fail if INT_MAX or INT_MIN returned denoting invalid range + // + if (m_numberOfVertices == 0) + { + DrLogE("No vertices specified to the PN controller"); + return 1; + } + + // + // If number of arguments isn't 2*numVerticies + 2, then it doesn't make sense + // + if ((UInt32) argc != (2 + 2*m_numberOfVertices)) + { + DrLogE( "%u vertices specified to the PN controller need " + "%u not %d arguments to describe them", + m_numberOfVertices, 2 + 2*m_numberOfVertices, argc); + return 1; + } + + // + // Set up array for controllers for each vertex + // + LogAssert(m_controllerArray == NULL); + m_controllerArray = new DVertexPnController* [m_numberOfVertices]; + LogAssert(m_controllerArray != NULL); + + // + // Critical section to update the number of active verticies + // + { + AutoCriticalSection acs(&m_baseDR); + m_assertCounter = 0; + m_activeVertexCount = m_numberOfVertices; + } + + // + // Foreach vertex, get command line arguments and make a controller + // + UInt32 i; + for (i=0; i + // + UInt32 vertexId = atoi(argv[2 + i*2]); + UInt32 vertexVersion = atoi(argv[2 + i*2 + 1]); + + // + // Make a new controller for each vertex + // + m_controllerArray[i] = MakePnController(vertexId, vertexVersion); + } + + // todo: not sure if this 
matters for us + /* disable assertion notification for now until logging deadlock + is fixed */ +// Logger::AddApplicationLogCallback(AssertCallback, this); + + // + // foreach vertex, launch the command loop + // todo: code cleanup: any reason not to launch command loop as soon as it's created? + // + for (i=0; iLaunchCommandLoop(); + } + + // + // Sleep forever - commandloop will take down process when instructed to do so + // + ::Sleep(INFINITE); + + return 0; +} diff --git a/DryadVertex/VertexHost/system/dprocess/src/dvertexpncontrol.h b/DryadVertex/VertexHost/system/dprocess/src/dvertexpncontrol.h new file mode 100644 index 0000000..d5e355d --- /dev/null +++ b/DryadVertex/VertexHost/system/dprocess/src/dvertexpncontrol.h @@ -0,0 +1,111 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include +#include + +class DVertexPnController; +class DVertexPnControllerOuter; + +// +// Structure for storing controller/commandblock association +// +struct DVertexPnControllerThreadBlock +{ + DVertexPnController* m_parent; + DrRef m_commandBlock; +}; + +class DVertexPnController : public DryadVertexController +{ +public: + DVertexPnController(DVertexPnControllerOuter* parent, + UInt32 vertexId, UInt32 vertexVersion); + virtual ~DVertexPnController(); + + void AssimilateNewStatus(DVertexProcessStatus* status, + bool sendUpdate, bool notifyWaiters); + void LaunchCommandLoop(); + void SendAssertStatus(const char* assertString); + +protected: + void SetChannelStatus(DryadChannelDescription* dst, + DryadChannelDescription* src); + void AssimilateChannelStatus(DryadChannelDescription* dst, + DryadChannelDescription* src); + + void SendStatus(UInt32 exitOnCompletion, bool notifyWaiters); + void Start(DVertexCommandBlock* commandBlock); + void ReOpenChannels(DVertexCommandBlock* reOpenCommand); + void Terminate(DrError vertexState, UInt32 exitCode); + DrError ActOnCommand(DVertexCommandBlock* commandBlock); + static unsigned ThreadFunc(void* arg); + static unsigned CommandLoopStatic(void* arg); + void DumpRestartCommand(DVertexCommandBlock* commandBlock); + + virtual DryadPnProcessPropertyRequest* + MakeSetStatusRequest(UInt32 exitOnCompletion, + bool isAssert, + bool notifyWaiters) = 0; + virtual void SendSetStatusRequest(DryadPnProcessPropertyRequest* r) = 0; + virtual unsigned CommandLoop() = 0; + + DVertexPnControllerOuter* m_parent; + UInt32 m_vertexId; + UInt32 m_vertexVersion; + DryadVertexRef m_vertex; + UInt64 m_currentCommandVersion; + bool m_activeVertex; + bool m_waitingForTermination; + DrRef m_currentStatus; + CRITSEC m_baseDR; +}; + +class DVertexEnvironment; + +class DVertexPnControllerOuter +{ +public: + DVertexPnControllerOuter(); + + UInt32 Run(int argc, char* argv[]); + void VertexExiting(int exitCode); + const char* 
GetRunningExePathName(); + DVertexEnvironment* GetEnvironment(); + +private: + void SendAssertStatus(const char* assertString); + static void AssertCallback(void* cookie, + const char* assertString); + + virtual DVertexEnvironment* MakeEnvironment() = 0; + virtual DVertexPnController* MakePnController(UInt32 vertexId, + UInt32 vertexVersion) = 0; + + DVertexEnvironment* m_environment; + DVertexPnController** m_controllerArray; + volatile LONG m_assertCounter; + UInt32 m_numberOfVertices; + UInt32 m_activeVertexCount; + DrStr64 m_exePathName; + CRITSEC m_baseDR; +}; diff --git a/DryadVertex/VertexHost/system/dprocess/src/dvertexxcomputeenvironment.cpp b/DryadVertex/VertexHost/system/dprocess/src/dvertexxcomputeenvironment.cpp new file mode 100644 index 0000000..13add4f --- /dev/null +++ b/DryadVertex/VertexHost/system/dprocess/src/dvertexxcomputeenvironment.cpp @@ -0,0 +1,62 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include +#include +#include +#include + +// +// Use current environment to initialize xcompute environment +// +DrError DVertexXComputeEnvironment::InitializeFromEnvironment() +{ + // + // Get process handle from dryad standalone ini + // + XCPROCESSHANDLE handle = GetProcessHandle(); + if (handle == INVALID_XCPROCESSHANDLE) + { + DrLogE("XCompute Process Handle is invalid"); + return DrError_Fail; + } + + // + // Build an identifier around process handle and attach to it + // + DryadXComputeProcessIdentifier* process = new DryadXComputeProcessIdentifier(handle); + + m_process.Attach(process); + + // + // Get node running xcompute process. Fail if invalid + // + XCPROCESSNODEID node; + XCERROR err = XcGetProcessNodeId(process->GetHandle(), &node); + LogAssert(err == DrError_OK); + LogAssert(node != INVALID_XCPROCESSNODEID); + + // + // attach to machine + // + m_machine.Attach(new DryadXComputeMachineIdentifier(node)); + + return DrError_OK; +} diff --git a/DryadVertex/VertexHost/system/dprocess/src/dvertexxcomputepncontrol.cpp b/DryadVertex/VertexHost/system/dprocess/src/dvertexxcomputepncontrol.cpp new file mode 100644 index 0000000..059e041 --- /dev/null +++ b/DryadVertex/VertexHost/system/dprocess/src/dvertexxcomputepncontrol.cpp @@ -0,0 +1,596 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+
+*/
+
+#include "dvertexxcomputepncontrol.h"
+// NOTE(review): the operands of the eight #include directives below are
+// absent from this copy of the patch; restore them from the original
+// source file before building.
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+//
+// Maximum number of times a single set-status request is sent to the PN:
+// the initial send plus any resends after a communication failure.
+//
+static const UInt32 s_maxSetStatusSends = 4;
+
+//
+// One "set status" request for the PN. It packs a payload property (and,
+// when notifyWaiters is set, a second control property) into an
+// XC_SETANDGETPROCESSINFO_REQINPUT, counts how many times it has been sent
+// and reacts to the completion of each send in Process().
+//
+// m_label, m_string, m_controlLabel and m_block are not declared here;
+// presumably they are inherited from DryadXComputePnProcessPropertyRequest.
+//
+class DVertexXComputeSetStatus : public DryadXComputePnProcessPropertyRequest
+{
+public:
+    DVertexXComputeSetStatus(DVertexXComputePnController* parent,
+                             DVertexPnControllerOuter* parentOuter,
+                             UInt32 exitOnCompletion,
+                             bool isAssert,
+                             bool notifyWaiters);
+
+    //
+    // Fill in the property descriptors and return the request input block
+    //
+    PXC_SETANDGETPROCESSINFO_REQINPUT MarshalProperty();
+
+    //
+    // Address of the results pointer handed to XcSetAndGetProcessInfo
+    //
+    PXC_SETANDGETPROCESSINFO_REQRESULTS* GetResults();
+
+    bool IsAssert();
+
+    void IncrementSendCount();
+    UInt32 GetSendCount();
+
+    //
+    // Completion handler: err is the final state of the send
+    //
+    void Process(DrError err);
+
+private:
+    XC_SETANDGETPROCESSINFO_REQINPUT m_info;
+
+    PXC_PROCESSPROPERTY_INFO m_propertyArray[2];
+    XC_PROCESSPROPERTY_INFO m_payloadProperty;
+    XC_PROCESSPROPERTY_INFO m_controlProperty;
+
+    //
+    // NOTE(review): CommandLoop releases the results of its synchronous
+    // call with XcFreeMemory; nothing in this class releases m_results.
+    // Confirm whether the asynchronous path needs the same treatment.
+    //
+    PXC_SETANDGETPROCESSINFO_REQRESULTS m_results;
+
+    DVertexXComputePnController* m_parent;
+    DVertexPnControllerOuter* m_parentOuter;
+    UInt32 m_exitOnCompletion;
+    bool m_notifyWaiters;
+    bool m_isAssert;
+
+    UInt32 m_sendCount;
+};
+
+//
+// Constructor. Create Set Status request.
+//
+// parent           controller used to resend the request on retry
+// parentOuter      wrapper told when the vertex is exiting
+// exitOnCompletion exit code to report once the send succeeds, or
+//                  DrExitCode_StillActive to keep running
+// isAssert         a failed send is only logged as a warning when set
+// notifyWaiters    also set the control property when set
+//
+DVertexXComputeSetStatus::
+    DVertexXComputeSetStatus(DVertexXComputePnController* parent,
+                             DVertexPnControllerOuter* parentOuter,
+                             UInt32 exitOnCompletion,
+                             bool isAssert,
+                             bool notifyWaiters) :
+    m_results(NULL),
+    m_parent(parent),
+    m_parentOuter(parentOuter),
+    m_exitOnCompletion(exitOnCompletion),
+    m_notifyWaiters(notifyWaiters),
+    m_isAssert(isAssert),
+    m_sendCount(0)
+{
+    //
+    // Build m_info property. The control property is only counted when
+    // waiters are to be notified.
+    //
+    memset(&m_info, 0, sizeof(m_info));
+    m_info.Size = sizeof(m_info);
+    m_info.NumberOfProcessPropertiesToSet = (notifyWaiters) ? 2 : 1;
+    m_info.ppPropertiesToSet = m_propertyArray;
+
+    //
+    // Assign payload and control properties to array
+    //
+    m_propertyArray[0] = &m_payloadProperty;
+    m_propertyArray[1] = &m_controlProperty;
+
+    memset(&m_payloadProperty, 0, sizeof(m_payloadProperty));
+    m_payloadProperty.Size = sizeof(m_payloadProperty);
+    m_payloadProperty.PropertyVersion = MAX_UINT64;
+
+    memset(&m_controlProperty, 0, sizeof(m_controlProperty));
+    m_controlProperty.Size = sizeof(m_controlProperty);
+    m_controlProperty.PropertyVersion = MAX_UINT64;
+}
+
+//
+// Point the property descriptors at the current label, string and data
+// block and return the request input. Called for every send.
+//
+PXC_SETANDGETPROCESSINFO_REQINPUT
+    DVertexXComputeSetStatus::MarshalProperty()
+{
+    DrLogI( "Marshaling property. Payload %s notify waiters %s",
+        m_label.GetString(), m_notifyWaiters ? "true" : "false");
+
+    m_payloadProperty.pPropertyLabel = m_label.GetString();
+    m_payloadProperty.pPropertyString = m_string.GetString();
+
+    if (m_notifyWaiters)
+    {
+        m_controlProperty.pPropertyLabel = m_controlLabel.GetString();
+        m_controlProperty.pPropertyString = "(Empty)";
+    }
+
+    Size_t propertySize = m_block->GetAvailableSize();
+
+    Size_t blockSize;
+    void* propertyData = m_block->GetDataAddress(0, &blockSize, NULL);
+
+    //
+    // The whole payload must be addressable through the returned pointer
+    //
+    LogAssert(blockSize >= propertySize);
+
+    m_payloadProperty.PropertyBlockSize = propertySize;
+    m_payloadProperty.pPropertyBlock = propertyData;
+
+    return &m_info;
+}
+
+PXC_SETANDGETPROCESSINFO_REQRESULTS*
+    DVertexXComputeSetStatus::GetResults()
+{
+    return &m_results;
+}
+
+//
+// Get assertion level
+//
+bool DVertexXComputeSetStatus::IsAssert()
+{
+    return m_isAssert;
+}
+
+//
+// Increment the number of sends attempted
+//
+void DVertexXComputeSetStatus::IncrementSendCount()
+{
+    ++m_sendCount;
+}
+
+//
+// Return the number of sends attempted
+//
+UInt32 DVertexXComputeSetStatus::GetSendCount()
+{
+    return m_sendCount;
+}
+
+//
+// Handle response from xcompute
+//
+void DVertexXComputeSetStatus::Process(DrError err)
+{
+    //
+    // If status successfully sent, log success and check on vertex status
+    //
+    if (err == DrError_OK)
+    {
+        DrLogI( "PN send succeeded. label %s", m_label.GetString());
+
+        if (m_exitOnCompletion != DrExitCode_StillActive)
+        {
+            //
+            // If vertex is not still active, report that it is exiting
+            // this may kill this process if all vertices are complete
+            //
+            m_parentOuter->VertexExiting(m_exitOnCompletion);
+        }
+
+        return;
+    }
+
+    //
+    // If there was a communication failure, resend until the request has
+    // been sent s_maxSetStatusSends times in total
+    //
+    const bool communicationFailure =
+        (err == DrError_RemoteDisconnected ||
+         err == DrError_LocalDisconnected ||
+         err == DrError_ConnectionFailed);
+
+    if (communicationFailure && m_sendCount < s_maxSetStatusSends)
+    {
+        DrLogW( "Retrying PN send. error %s", DRERRORSTRING(err));
+
+        m_parent->SendSetStatusRequest(this);
+
+        return;
+    }
+
+    //
+    // If m_isAssert, just report warning, otherwise log and fail.
+    // NOTE(review): presumably m_isAssert marks a request that is itself
+    // reporting an assertion, so failing here would assert a second time;
+    // the call sites that set it are not in this file.
+    //
+    if (m_isAssert)
+    {
+        DrLogW(
+            "Send to PN failed: not asserting again. done %u sends, error %s",
+            m_sendCount, DRERRORSTRING(err));
+    }
+    else
+    {
+        DrLogA(
+            "Send to PN failed. done %u sends, error %s",
+            m_sendCount, DRERRORSTRING(err));
+    }
+}
+
+//
+// Type of xcompute request handler that deals with "set status" operations
+//
+class XComputeSetStatusOverlapped : public DryadNativePort::HandlerBase
+{
+public:
+    //
+    // Constructor
+    //
+    XComputeSetStatusOverlapped(DVertexXComputeSetStatus* request);
+
+    //
+    // Property returning operation state of request
+    //
+    DrError* GetOperationState();
+
+private:
+    //
+    // Process a set status request
+    //
+    void ProcessIO(DrError retval, UInt32 numBytes);
+
+    //
+    // Set Status request and state
+    //
+    DrError m_operationState;
+    DrRef<DVertexXComputeSetStatus> m_request;
+};
+
+//
+// Assign associated request. The operation state starts out as a failure
+// so that a state nobody wrote is never read as indeterminate or success.
+//
+XComputeSetStatusOverlapped::
+    XComputeSetStatusOverlapped(DVertexXComputeSetStatus* request) :
+    m_operationState(DrError_Fail)
+{
+    m_request = request;
+}
+
+//
+// Returns a pointer to the operation state of the request
+//
+DrError* XComputeSetStatusOverlapped::GetOperationState()
+{
+    return &m_operationState;
+}
+
+//
+// Process a set status request
+//
+void XComputeSetStatusOverlapped::ProcessIO(DrError retval, UInt32 numBytes)
+{
+    //
+    // If there was an error, log and fail
+    //
+    if (retval != DrError_OK)
+    {
+        DrLogA( "Completion port returned error. error %s %u bytes", DRERRORSTRING(retval), numBytes);
+    }
+
+    //
+    // A set status completion is expected to carry no bytes; log and fail
+    // if any were reported
+    //
+    if (numBytes != 0)
+    {
+        DrLogA( "Completion port returned non-zero. %u bytes", numBytes);
+    }
+
+    //
+    // Handle other outcomes based on the operation state
+    //
+    m_request->Process(m_operationState);
+
+    delete this;
+}
+
+//
+// Constructor. Calls parent constructor only.
+//
+DVertexXComputePnController::
+    DVertexXComputePnController(DVertexPnControllerOuter* parent,
+                                UInt32 vertexId, UInt32 vertexVersion) :
+    DVertexPnController(parent, vertexId, vertexVersion)
+{
+}
+
+//
+// Create a new set status request. The caller receives a freshly
+// allocated DVertexXComputeSetStatus bound to this controller and its
+// wrapper.
+//
+DryadPnProcessPropertyRequest*
+    DVertexXComputePnController::MakeSetStatusRequest(UInt32 exitOnCompletion,
+                                                      bool isAssert,
+                                                      bool notifyWaiters)
+{
+    return new DVertexXComputeSetStatus(this, m_parent,
+                                        exitOnCompletion, isAssert,
+                                        notifyWaiters);
+}
+
+//
+// Send updated status to vertex service. r must be a request created by
+// MakeSetStatusRequest. The send is asynchronous: its completion is
+// delivered to XComputeSetStatusOverlapped::ProcessIO through the
+// completion port.
+//
+void DVertexXComputePnController::
+    SendSetStatusRequest(DryadPnProcessPropertyRequest* r)
+{
+    //
+    // Cast request to required type and make sure it's valid
+    //
+    DVertexXComputeSetStatus* request =
+        dynamic_cast<DVertexXComputeSetStatus*>(r);
+    LogAssert(request != NULL);
+
+    //
+    // Wrap request in XComputeSetStatusOverlapped
+    //
+    XComputeSetStatusOverlapped* overlapped =
+        new XComputeSetStatusOverlapped(request);
+
+    //
+    // Create asynchronous execution information
+    //
+    XC_ASYNC_INFO asyncInfo;
+    memset(&asyncInfo, 0, sizeof(asyncInfo));
+    asyncInfo.cbSize = sizeof(asyncInfo);
+    asyncInfo.pOperationState = overlapped->GetOperationState();
+    asyncInfo.IOCP = g_dryadNativePort->GetCompletionPort();
+    asyncInfo.pOverlapped = overlapped->GetOverlapped();
+
+    //
+    // Update request counters
+    //
+    request->IncrementSendCount();
+    g_dryadNativePort->IncrementOutstandingRequests();
+
+    //
+    // Update process info
+    //
+    XCERROR err =
+        XcSetAndGetProcessInfo(GetProcessHandle(),
+                               request->MarshalProperty(),
+                               request->GetResults(),
+                               &asyncInfo);
+
+    //
+    // The asynchronous call is expected to come back as pending or as an
+    // error, never as an immediate success
+    //
+    LogAssert(err != DrError_OK);
+
+    if (err != HRESULT_FROM_WIN32(ERROR_IO_PENDING))
+    {
+        //
+        // If failed (other than due to pending IO) log failure and update request counter
+        //
+        g_dryadNativePort->DecrementOutstandingRequests();
+
+        //
+        // If request assertion true, report errors as warnings, otherwise report as error and fail.
+        // NOTE(review): presumably an assert request is already reporting
+        // a failure, so it must not fail again here - confirm the rationale.
+        // Retries are handled by the request itself on completion, not here.
+        //
+        if (request->IsAssert())
+        {
+            DrLogW(
+                "Status request send failed synchronously during assert: not asserting again. done %u send tries, error %s",
+                request->GetSendCount(), DRERRORSTRING(err));
+        }
+        else
+        {
+            DrLogA(
+                "Status request send failed synchronously. done %u send tries, error %s",
+                request->GetSendCount(), DRERRORSTRING(err));
+        }
+
+        //
+        // No completion will arrive for this wrapper, so release it now
+        //
+        delete overlapped;
+    }
+}
+
+//
+// Number of consecutive disconnection errors at which CommandLoop stops
+// retrying and terminates the controller.
+//
+static const UInt32 s_maxConsecutiveGetFailures = 4;
+
+//
+// Run in thread for each vertex. Repeatedly blocks on the vertex's command
+// property, acts on every new version of it and terminates the controller
+// on the first error that is not retried. Does not return in practice: it
+// sleeps forever after calling Terminate.
+//
+unsigned DVertexXComputePnController::CommandLoop()
+{
+    DrError err;
+    UInt32 retries = 0;
+
+    //
+    // Get the vertex label
+    //
+    DrStr64 label;
+    DVertexCommandBlock::GetPnPropertyLabel(&label,
+                                            m_vertexId,
+                                            m_vertexVersion);
+
+    //
+    // Wait for communication until error
+    //
+    do
+    {
+        //
+        // Create request to get vertex version property
+        //
+        XC_SETANDGETPROCESSINFO_REQINPUT request;
+        memset(&request, 0, sizeof(request));
+        request.Size = sizeof(request);
+        request.pBlockOnPropertyLabel = label.GetString();
+        request.BlockOnPropertyversionLastSeen = m_currentCommandVersion;
+        request.MaxBlockTime = XCTIMEINTERVAL_MINUTE;
+        // The request field is not const-qualified, hence the cast
+        request.pPropertyFetchTemplate =
+            const_cast<char*>(label.GetString());
+
+        //
+        // Send the request and check for errors
+        //
+        PXC_SETANDGETPROCESSINFO_REQRESULTS pResults = NULL;
+        err = XcSetAndGetProcessInfo(GetProcessHandle(),
+                                     &request,
+                                     &pResults,
+                                     NULL);
+        if (err == DrError_OK)
+        {
+            //
+            // If request successfully sent, store process status and exit code
+            //
+            DrLogI( "Got command property");
+            retries = 0;
+            DrError processStatus = pResults->pProcessInfo->ProcessStatus;
+            DrExitCode exitCode = pResults->pProcessInfo->ExitCode;
+
+            if (processStatus == DrError_OK || exitCode != DrExitCode_StillActive)
+            {
+                //
+                // The PN thinks we have exited, so better make it so
+                //
+                DrLogI( "PN reports process is no longer active. status %s",
+                    DRERRORSTRING(processStatus));
+                err = DrError_Fail;
+            }
+        }
+
+        //
+        // If request was successful and other process doesn't think we're done
+        //
+        if (err == DrError_OK)
+        {
+            if (pResults->pProcessInfo->NumberofProcessProperties != 0)
+            {
+                //
+                // Make sure there's only one property and it's the version
+                //
+                LogAssert(pResults->pProcessInfo->
+                          NumberofProcessProperties == 1);
+                PXC_PROCESSPROPERTY_INFO property =
+                    pResults->pProcessInfo->ppProperties[0];
+                LogAssert(::strcmp(property->pPropertyLabel,
+                                   label.GetString()) == 0);
+
+                //
+                // Update vertex version
+                //
+                UInt64 newVersion = property->PropertyVersion;
+                if (newVersion < m_currentCommandVersion)
+                {
+                    //
+                    // If vertex version is less than the current version, fail (logic error)
+                    //
+                    DrLogE(
+                        "Property version went back in time. Property %s old version %I64u new version %I64u",
+                        label.GetString(),
+                        m_currentCommandVersion, newVersion);
+                    err = DrError_ProcessPropertyVersionMismatch;
+                }
+                else if (newVersion == m_currentCommandVersion)
+                {
+                    //
+                    // If version the same, report version the same
+                    //
+                    DrLogI(
+                        "Command timeout with same version. Property %s version %I64u",
+                        label.GetString(), m_currentCommandVersion);
+                }
+                else
+                {
+                    //
+                    // New vertex version: read the new command and act on it
+                    //
+                    DrLogI(
+                        "Property got new version. Property %s old version %I64u new version %I64u",
+                        label.GetString(),
+                        m_currentCommandVersion, newVersion);
+
+                    m_currentCommandVersion = newVersion;
+
+                    DrRef<DVertexCommandBlock> newCommand;
+                    newCommand.Attach(new DVertexCommandBlock());
+
+                    DrRef<DryadXComputePnProcessPropertyResponse> response;
+                    response.Attach(new DryadXComputePnProcessPropertyResponse(pResults->pProcessInfo));
+
+                    //
+                    // Get new vertex command
+                    //
+                    err = newCommand->ReadFromResponseMessage(response, m_vertexId, m_vertexVersion);
+
+                    //
+                    // If no errors in getting command, act on it. A failure
+                    // here ends the loop through the while condition below
+                    //
+                    if (err == DrError_OK)
+                    {
+                        err = ActOnCommand(newCommand);
+                    }
+                }
+            }
+        }
+        else
+        {
+            //
+            // Log error and continue
+            //
+            DrLogE( "XcSetAndGetProcessInfo got error: %s", DRERRORSTRING(err));
+        }
+
+        //
+        // If the error is related to disconnection, retry until
+        // s_maxConsecutiveGetFailures such errors have occurred in a row
+        //
+        if (err == DrError_RemoteDisconnected ||
+            err == DrError_LocalDisconnected ||
+            err == DrError_ConnectionFailed ||
+            err == DrError_ResponseDisconnect)
+        {
+            ++retries;
+            if (retries < s_maxConsecutiveGetFailures)
+            {
+                DrLogW( "Retrying get");
+                err = DrError_OK;
+            }
+        }
+
+        //
+        // If result was allocated, free it before next iteration
+        //
+        if (pResults != NULL)
+        {
+            XCERROR freeError = XcFreeMemory(pResults);
+            LogAssert(freeError == DrError_OK);
+        }
+    } while (err == DrError_OK);
+
+    //
+    // Close this controller and take no more requests
+    //
+    DrLogD( "About to terminate");
+    Terminate(err, DrExitCode_Fail);
+
+    //
+    // Sleep forever, waiting for vertices to complete and take down the process
+    //
+    Sleep(INFINITE);
+
+    return 0;
+}
+
+//
+// Create and return an XCompute environment
+//
+DVertexEnvironment* DVertexXComputePnControllerOuter::MakeEnvironment()
+{
+    return new DVertexXComputeEnvironment();
+}
+
+//
+// Create and return an XCompute PN controller owned by this wrapper
+//
+DVertexPnController* DVertexXComputePnControllerOuter::
+    MakePnController(UInt32 vertexId,
+                     UInt32 vertexVersion)
+{
+    return new DVertexXComputePnController(this, vertexId, vertexVersion);
+}
diff --git a/DryadVertex/VertexHost/system/dprocess/src/dvertexxcomputepncontrol.h b/DryadVertex/VertexHost/system/dprocess/src/dvertexxcomputepncontrol.h
new file mode 100644
index 0000000..39e71b0
--- /dev/null
+++ b/DryadVertex/VertexHost/system/dprocess/src/dvertexxcomputepncontrol.h
@@ -0,0 +1,59 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+#pragma once
+
+// NOTE(review): the operands of the two #include directives below are
+// absent from this copy of the patch.
+#include
+#include
+
+class DVertexXComputePnControllerOuter;
+
+//
+// PN controller for one vertex that talks to XCompute. The implementation
+// in dvertexxcomputepncontrol.cpp sends status with XcSetAndGetProcessInfo
+// and receives commands in CommandLoop.
+//
+class DVertexXComputePnController : public DVertexPnController
+{
+public:
+    //
+    // Forwards all three arguments to the DVertexPnController constructor
+    //
+    DVertexXComputePnController(DVertexPnControllerOuter* parent,
+                                UInt32 vertexId, UInt32 vertexVersion);
+
+    //
+    // Send a request created by MakeSetStatusRequest asynchronously
+    //
+    void SendSetStatusRequest(DryadPnProcessPropertyRequest* r);
+
+private:
+    //
+    // Allocate a new set status request bound to this controller
+    //
+    DryadPnProcessPropertyRequest*
+        MakeSetStatusRequest(UInt32 exitOnCompletion,
+                             bool isAssert,
+                             bool notifyWaiters);
+
+    //
+    // Per-vertex thread body: waits for commands until an error occurs
+    //
+    unsigned CommandLoop();
+
+    //
+    // Get wrapper
+    // NOTE(review): no definition of this method appears in
+    // dvertexxcomputepncontrol.cpp as shown in this patch.
+    //
+    DVertexXComputePnControllerOuter* GetParent();
+};
+
+//
+// Wrapper for PN controller and environment
+// Used by dvertextmain.cpp:DryadVertexMain
+//
+class DVertexXComputePnControllerOuter : public DVertexPnControllerOuter
+{
+private:
+    // Returns a new DVertexXComputeEnvironment
+    DVertexEnvironment* MakeEnvironment();
+    // Returns a new DVertexXComputePnController whose parent is this wrapper
+    DVertexPnController* MakePnController(UInt32 vertexId,
+                                          UInt32 vertexVersion);
+};
diff --git a/DryadVertex/VertexHost/system/dprocess/src/dvertexyarnpncontrol.h b/DryadVertex/VertexHost/system/dprocess/src/dvertexyarnpncontrol.h
new file mode 100644
index 0000000..dc2bf57
--- /dev/null
+++ b/DryadVertex/VertexHost/system/dprocess/src/dvertexyarnpncontrol.h
@@ -0,0 +1,58 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+#pragma once
+
+#include "dvertexpncontrol.h"
+
+// NOTE(review): this forward declaration names the XCompute wrapper, not
+// the DVertexYarnPnControllerOuter defined at the end of this header; it
+// looks carried over from dvertexxcomputepncontrol.h - confirm.
+class DVertexXComputePnControllerOuter;
+
+//
+// PN controller for one vertex, YARN flavour. Declares the same members
+// as DVertexXComputePnController; the implementation is not part of the
+// visible patch text.
+//
+class DVertexYarnPnController : public DVertexPnController
+{
+public:
+    DVertexYarnPnController(DVertexPnControllerOuter* parent,
+                            UInt32 vertexId, UInt32 vertexVersion);
+
+    //
+    // Send a set status request
+    //
+    void SendSetStatusRequest(DryadPnProcessPropertyRequest* r);
+
+private:
+    //
+    // Create a new set status request
+    //
+    DryadPnProcessPropertyRequest*
+        MakeSetStatusRequest(UInt32 exitOnCompletion,
+                             bool isAssert,
+                             bool notifyWaiters);
+
+    //
+    // Command-processing loop for the vertex
+    //
+    unsigned CommandLoop();
+
+    //
+    // Get wrapper
+    // NOTE(review): the return type is the XCompute wrapper although this
+    // controller is created by DVertexYarnPnControllerOuter - verify
+    // against the definition before relying on it.
+    //
+    DVertexXComputePnControllerOuter* GetParent();
+};
+
+//
+// Wrapper for PN controller and environment
+// Used by dvertextmain.cpp:DryadVertexMain
+//
+class DVertexYarnPnControllerOuter : public DVertexPnControllerOuter
+{
+private:
+    // Factory for the environment used with this wrapper
+    DVertexEnvironment* MakeEnvironment();
+    // Factory for the per-vertex PN controller
+    DVertexPnController* MakePnController(UInt32 vertexId,
+                                          UInt32 vertexVersion);
+};
diff --git a/DryadVertex/VertexHost/system/dprocess/src/subgraphvertex.cpp b/DryadVertex/VertexHost/system/dprocess/src/subgraphvertex.cpp
new file mode 100644
index 0000000..94f05ab
--- /dev/null
+++ b/DryadVertex/VertexHost/system/dprocess/src/subgraphvertex.cpp
@@ -0,0 +1,1408 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include +#include +#include +#include +#include +#include +#include +#include + +static const UInt32 s_invalidVertex = (UInt32) ((Int32) -1); + +static const UInt32 s_internalWorkQueueThreads = 4; +static const UInt32 s_internalWorkQueueConcurrentThreads = 2; +static const UInt32 s_internalFifoLength = 20; + +class SGVDummyHandler : public RChannelItemWriterHandler +{ +public: + void ProcessWriteCompleted(RChannelItemType status, + RChannelItem* marshalFailureItemArray); +}; + +void SGVDummyHandler::ProcessWriteCompleted(RChannelItemType /*status*/, + RChannelItem* marshalFailureItem) +{ + if (marshalFailureItem != NULL) + { + marshalFailureItem->DecRef(); + } +} + + + +DryadSubGraphVertex::EdgeInfo::EdgeInfo() +{ + m_sourceVertex = 0; + m_sourcePort = 0; + m_destinationVertex = 0; + m_destinationPort = 0; +} + +DryadSubGraphVertex::EdgeInfo::~EdgeInfo() +{ +} + +void DryadSubGraphVertex::EdgeInfo::SetInfo(UInt32 sourceVertex, + UInt32 sourcePort, + UInt32 destinationVertex, + UInt32 destinationPort) +{ + m_sourceVertex = sourceVertex; + m_sourcePort = sourcePort; + m_destinationVertex = destinationVertex; + m_destinationPort = destinationPort; +} + +UInt32 DryadSubGraphVertex::EdgeInfo::GetSourceVertex() +{ + return m_sourceVertex; +} + +UInt32 DryadSubGraphVertex::EdgeInfo::GetSourcePort() +{ + return m_sourcePort; +} + +UInt32 DryadSubGraphVertex::EdgeInfo::GetDestinationVertex() +{ + return m_destinationVertex; +} + +UInt32 DryadSubGraphVertex::EdgeInfo::GetDestinationPort() +{ + return 
m_destinationPort; +} + +void DryadSubGraphVertex::EdgeInfo:: + MakeFifo(UInt32 fifoLength, WorkQueue* workQueue) +{ + LogAssert(m_reader == NULL); + LogAssert(m_writer == NULL); + + UInt32 uniquifier = RChannelFactory::GetUniqueFifoId(); + + DrStr64 fifoName; + fifoName.SetF("fifo://%u/internal-%u-%u.%u--%u.%u", + fifoLength, uniquifier, + m_sourceVertex, m_sourcePort, + m_destinationVertex, m_destinationPort); + + DVErrorReporter errorReporter; + RChannelFactory::OpenReader(fifoName, NULL, NULL, 1, NULL, 0, 0, workQueue, + &errorReporter, &m_reader, NULL); + LogAssert(errorReporter.NoError()); + RChannelFactory::OpenWriter(fifoName, NULL, NULL, 1, NULL, 0, NULL, + &errorReporter, &m_writer); + LogAssert(errorReporter.NoError()); + + m_reader->GetReader()->Start(NULL); + m_writer->GetWriter()->Start(); +} + +RChannelReader* DryadSubGraphVertex::EdgeInfo::GetReader() +{ + return m_reader->GetReader(); +} + +RChannelWriter* DryadSubGraphVertex::EdgeInfo::GetWriter() +{ + return m_writer->GetWriter(); +} + +void DryadSubGraphVertex::EdgeInfo::DiscardFifo() +{ + m_reader = NULL; + m_writer = NULL; +} + + +DryadSubGraphVertex::VertexInfo::VertexInfo() +{ + m_parent = NULL; + m_inputPortCount = 0; + m_inputEdge = NULL; + m_inputExternal = NULL; + m_inputChannel = NULL; + m_outputPortCount = 0; + m_outputEdge = NULL; + m_outputExternal = NULL; + m_outputChannel = NULL; + m_argumentCount = 0; + m_argument = NULL; + m_serializedBlockLength = 0; + m_serializedBlock = NULL; + m_canShareWorkQueue = false; + m_workQueue = NULL; + m_virtual = false; +} + +DryadSubGraphVertex::VertexInfo::~VertexInfo() +{ + delete [] m_inputEdge; + delete [] m_inputExternal; + delete [] m_inputChannel; + delete [] m_outputEdge; + delete [] m_outputExternal; + delete [] m_outputChannel; + delete [] m_argument; + delete [] m_serializedBlock; + if (m_workQueue != NULL) + { + m_workQueue->Stop(); + delete m_workQueue; + } +} + +DrError DryadSubGraphVertex::VertexInfo:: + MakeProgram(UInt32 
numberOfWorkQueueThreads, + UInt32 concurrentWorkQueueThreads, + DryadMetaDataRef* pErrorData) +{ + LogAssert(m_vertexProgram == NULL); + + DrError err = VertexFactoryRegistry::MakeVertex(0, + GetVersion(), + m_inputPortCount, + m_outputPortCount, + NULL, + NULL, + m_metaData, + 0, 0, + m_argumentCount, + m_argument, + m_serializedBlockLength, + m_serializedBlock, + pErrorData, + &m_vertexProgram); + + if (err != DrError_OK) + { + return err; + } + + DrLogI( "Made vertex program. Vertex ID %u", m_id); + + LogAssert(m_workQueue == NULL); + if (m_canShareWorkQueue == false) + { + m_workQueue = new WorkQueue(numberOfWorkQueueThreads, + concurrentWorkQueueThreads); + m_workQueue->Start(); + + DrLogI( "Added private work queue. Vertex ID %u", m_id); + } + else + { + DrLogI( "Using shared work queue. Vertex ID %u", m_id); + } + + return err; +} + +DryadVertexProgramBase* + DryadSubGraphVertex::VertexInfo::GetVertexProgram() +{ + return m_vertexProgram; +} + +WorkQueue* DryadSubGraphVertex::VertexInfo::GetWorkQueue() +{ + return m_workQueue; +} + +void DryadSubGraphVertex::VertexInfo:: + SetInputPortCount(UInt32 inputPortCount) +{ + LogAssert(m_inputEdge == NULL); + LogAssert(m_inputExternal == NULL); + LogAssert(m_inputChannel == NULL); + m_inputPortCount = inputPortCount; + m_inputEdge = new EdgeInfo* [m_inputPortCount]; + LogAssert(m_inputEdge != NULL); + m_inputExternal = new bool [m_inputPortCount]; + LogAssert(m_inputExternal != NULL); + m_inputChannel = new RChannelReader* [m_inputPortCount]; + LogAssert(m_inputChannel != NULL); + UInt32 i; + for (i=0; iVertexCompleted(DryadError_VertexCompleted, NULL); + } + else + { + WorkQueue* workQueue = (m_workQueue == NULL) ? 
+ (sharedWorkQueue) : (m_workQueue); + + const char** argv = new const char* [m_argumentCount]; + LogAssert(argv != NULL); + UInt32 i; + for (i=0; iMainAsync(workQueue, + m_inputPortCount, + m_inputChannel, + m_outputPortCount, + m_outputChannel, + this); + + delete [] argv; + } +} + +void DryadSubGraphVertex::VertexInfo::NotifyChannelsOfCompletion() +{ + if (m_virtual) + { + return; + } + + LogAssert(m_vertexProgram != NULL); + + DrLogI( + "Vertex telling channels about completion. Vertex ID %u status %s", m_id, + DRERRORSTRING(m_vertexProgram->GetErrorCode())); + + m_vertexProgram->NotifyChannelsOfCompletion(m_inputPortCount, + m_inputChannel, + m_outputPortCount, + m_outputChannel); +} + +void DryadSubGraphVertex::VertexInfo::DrainChannels() +{ + if (m_virtual || m_vertexProgram == NULL) + { + return; + } + + DrLogI( + "Vertex draining channels. Vertex ID %u", m_id); + + UInt32 i; + for (i=0; iDrain(); + DrLogI( + "Vertex drained internal input. Vertex ID %u port %u", m_id, i); + } + } + + bool errorCondition = false; + for (i=0; iDrain(DrTimeInterval_Zero, &writeTermination); + if (writeTermination != NULL) + { + if (writeTermination->GetType() != RChannelItem_EndOfStream) + { + DrLogI( + "Vertex internal output write error. Vertex ID %u port %u type %u", m_id, i, + writeTermination->GetType()); + errorCondition = true; + } + } + DrLogI( + "Vertex drained internal output. Vertex ID %u port %u", m_id, i); + } + } + + if (errorCondition && m_vertexProgram->NoError()) + { + m_vertexProgram->ReportError(DryadError_ChannelWriteError); + } +} + +void DryadSubGraphVertex::VertexInfo::ProgramCompleted() +{ + LogAssert(m_vertexProgram != NULL); + + DrLogI( + "Vertex completed. 
Vertex ID %u, status %s", m_id, + DRERRORSTRING(m_vertexProgram->GetErrorCode())); + + NotifyChannelsOfCompletion(); + + DrError err = m_vertexProgram->GetErrorCode(); + LogAssert(err != DryadError_VertexCompleted); + if (err == DrError_OK) + { + err = DryadError_VertexCompleted; + } + + m_parent->VertexCompleted(err, + m_vertexProgram->GetErrorMetaData()); +} + +DryadSubGraphVertex::DryadSubGraphVertex() +{ + m_virtualInput = s_invalidVertex; + m_virtualOutput = s_invalidVertex; + + m_numberOfVertices = 0; + m_vertex = NULL; + m_numberOfEdges = 0; + m_edge = NULL; + m_numberOfInputEdges = 0; + m_inputEdge = NULL; + m_numberOfOutputEdges = 0; + m_outputEdge = NULL; + + m_internalWorkQueueThreads = + s_internalWorkQueueThreads; + m_internalWorkQueueConcurrentThreads = + s_internalWorkQueueConcurrentThreads; + m_internalFifoLength = s_internalFifoLength; +} + +DryadSubGraphVertex::~DryadSubGraphVertex() +{ + delete [] m_edge; + delete [] m_vertex; + delete [] m_inputEdge; + delete [] m_outputEdge; +} + +// +// Prints out vertex identifier and warns about internal use +// todo: check if argument length >= 1 like in vertex.cpp +// +void DryadSubGraphVertex::Usage(FILE* f) +{ + fprintf(f, "%s: for internal use only\n\n", GetArgument(0)); +} + +void DryadSubGraphVertex::SetWorkQueueThreads(UInt32 workQueueThreads) +{ + m_internalWorkQueueThreads = workQueueThreads; +} + +void DryadSubGraphVertex:: + SetWorkQueueConcurrentThreads(UInt32 workQueueConcurrentThreads) +{ + m_internalWorkQueueConcurrentThreads = workQueueConcurrentThreads; +} + +void DryadSubGraphVertex::SetInternalFifoLength(UInt32 fifoLength) +{ + m_internalFifoLength = fifoLength; +} + +void DryadSubGraphVertex::ReadVertexInfo(DryadMetaData* vertexData, + UInt32 vertexIndex) +{ + LogAssert(NoError()); + + VertexInfo* vertex = &(m_vertex[vertexIndex]); + + UInt32 id; + if (vertexData->LookUpUInt32(Prop_Dryad_VertexId, &id) != DrError_OK) + { + ReportError(DryadError_VertexInitialization, "No VertexId Tag"); 
+ return; + } + vertex->SetId(id); + + UInt32 version; + if (vertexData->LookUpUInt32(Prop_Dryad_VertexVersion, + &version) != DrError_OK) + { + ReportError(DryadError_VertexInitialization, "No VertexVersion Tag"); + return; + } + vertex->SetVersion(version); + + bool canShareWorkQueue; + if (vertexData->LookUpBoolean(Prop_Dryad_CanShareWorkQueue, + &canShareWorkQueue) != DrError_OK) + { + ReportError(DryadError_VertexInitialization, + "No CanShareWorkQueue Tag"); + return; + } + vertex->SetCanShareWorkQueue(canShareWorkQueue); + + UInt32 inputCount; + if (vertexData->LookUpUInt32(Prop_Dryad_InputPortCount, + &inputCount) != DrError_OK) + { + ReportError(DryadError_VertexInitialization, "No InputPortCount Tag"); + return; + } + vertex->SetInputPortCount(inputCount); + + UInt32 outputCount; + if (vertexData->LookUpUInt32(Prop_Dryad_OutputPortCount, + &outputCount) != DrError_OK) + { + ReportError(DryadError_VertexInitialization, "No OutputPortCount Tag"); + return; + } + vertex->SetOutputPortCount(outputCount); + + DryadMetaDataRef metaData; + if (vertexData->LookUpMetaData(DryadTag_VertexMetaData, + &metaData) == DrError_OK) + { + vertex->SetMetaData(metaData); + } + + UInt32 argc; + if (vertexData->LookUpUInt32(Prop_Dryad_VertexArgumentCount, + &argc) != DrError_OK) + { + ReportError(DryadError_VertexInitialization, "No ArgumentCount Tag"); + return; + } + vertex->SetArgumentCount(argc); + + DryadMetaDataRef argvArray; + DrError err = vertexData->LookUpMetaData(DryadTag_ArgumentArray, + &argvArray); + if (err != DrError_OK) + { + ReportError(DryadError_VertexInitialization, "No ArgumentArray Tag"); + return; + } + + DryadMTagUnknown* uTag = + vertexData->LookUpUnknownTag(Prop_Dryad_VertexSerializedBlock); + if (uTag == NULL) + { + ReportError(DryadError_VertexInitialization, + "No VertexSerializedBlock Tag"); + return; + } + else + { + vertex->SetRawSerializedBlock(uTag->GetDataLength(), uTag->GetData()); + } + + DryadMetaData::TagListIter endArray; + 
DryadMetaData::TagListIter iter = + argvArray->LookUpInSequence(NULL, &endArray); + UInt32 argIndex = 0; + while (iter != endArray) + { + if (argIndex == vertex->GetArgumentCount()) + { + ReportError(DryadError_VertexInitialization, + "Too many arguments in array"); + return; + } + DryadMTag* tag = *iter; + if (tag->GetTagValue() != Prop_Dryad_VertexArgument) + { + ReportError(DryadError_VertexInitialization, + "Unexpected tag in array"); + return; + } + if (tag->GetType() != DrPropertyTagType_String) + { + ReportError(DryadError_VertexInitialization, + "Unexpected tag type in array"); + return; + } + DryadMTagString* argString = (DryadMTagString *) tag; + vertex->SetArgumentValue(argIndex, argString->GetString()); + ++iter; + ++argIndex; + } + + if (argIndex != vertex->GetArgumentCount()) + { + ReportError(DryadError_VertexInitialization, + "Too few arguments in array"); + return; + } + + if (vertex->GetArgumentCount() > 0 && + ::strcmp(vertex->GetArgumentValue(0), "__INPUT__") == 0) + { + if (m_virtualInput != s_invalidVertex) + { + ReportError(DryadError_VertexInitialization, + "Multiple virtual input vertices"); + return; + } + + if (vertex->GetInputPortCount() != 0) + { + ReportError(DryadError_VertexInitialization, + "Virtual input has input channels"); + return; + } + + if (vertex->GetArgumentCount() != 1) + { + ReportError(DryadError_VertexInitialization, + "Virtual input has extra arguments"); + return; + } + + m_virtualInput = vertexIndex; + m_numberOfInputEdges = vertex->GetOutputPortCount(); + m_inputEdge = new EdgeInfo* [m_numberOfInputEdges]; + LogAssert(m_inputEdge != NULL); + UInt32 i; + for (i=0; iSetVirtual(); + } + else if (vertex->GetArgumentCount() > 0 && + ::strcmp(vertex->GetArgumentValue(0), "__OUTPUT__") == 0) + { + if (m_virtualOutput != s_invalidVertex) + { + ReportError(DryadError_VertexInitialization, + "Multiple virtual output vertices"); + return; + } + + if (vertex->GetOutputPortCount() != 0) + { + 
ReportError(DryadError_VertexInitialization, + "Virtual output has output channels"); + return; + } + + if (vertex->GetArgumentCount() != 1) + { + ReportError(DryadError_VertexInitialization, + "Virtual output has extra arguments"); + return; + } + + m_virtualOutput = vertexIndex; + m_numberOfOutputEdges = vertex->GetInputPortCount(); + m_outputEdge = new EdgeInfo* [m_numberOfOutputEdges]; + UInt32 i; + for (i=0; iSetVirtual(); + } + else + { + DryadMetaDataRef errorData; + err = vertex->MakeProgram(m_internalWorkQueueThreads, + m_internalWorkQueueConcurrentThreads, + &errorData); + ReportError(err, errorData); + } +} + +void DryadSubGraphVertex::ReadVertices(DryadMetaData* graphData) +{ + LogAssert(m_virtualInput == s_invalidVertex); + LogAssert(m_virtualOutput == s_invalidVertex); + + LogAssert(NoError()); + + DrError err = graphData->LookUpUInt32(Prop_Dryad_NumberOfVertices, + &m_numberOfVertices); + if (err != DrError_OK) + { + ReportError(DryadError_VertexInitialization, + "No NumberOfVertices Tag"); + return; + } + + m_vertex = new VertexInfo [m_numberOfVertices]; + LogAssert(m_vertex != NULL); + + DryadMetaDataRef vertexArray; + err = graphData->LookUpMetaData(DryadTag_VertexArray, &vertexArray); + if (err != DrError_OK) + { + ReportError(DryadError_VertexInitialization, + "No VertexArray Tag"); + return; + } + + DryadMetaData::TagListIter endArray; + DryadMetaData::TagListIter iter = + vertexArray->LookUpInSequence(NULL, &endArray); + UInt32 vertexIndex = 0; + while (iter != endArray) + { + if (vertexIndex == m_numberOfVertices) + { + ReportError(DryadError_VertexInitialization, + "Too many vertices in array"); + return; + } + DryadMTag* tag = *iter; + if (tag->GetTagValue() != DryadTag_VertexInfo) + { + ReportError(DryadError_VertexInitialization, + "Unexpected tag in array"); + return; + } + if (tag->GetType() != DryadPropertyTagType_MetaData) + { + ReportError(DryadError_VertexInitialization, + "Unexpected tag type in array"); + return; + } + 
DryadMTagMetaData* vertexData = (DryadMTagMetaData *) tag; + ReadVertexInfo(vertexData->GetMetaData(), vertexIndex); + + if (NoError() == false) + { + return; + } + + ++iter; + ++vertexIndex; + } + + if (vertexIndex != m_numberOfVertices) + { + ReportError(DryadError_VertexInitialization, + "Too few vertices in array"); + return; + } + + if (m_virtualInput == s_invalidVertex) + { + ReportError(DryadError_VertexInitialization, + "No virtual input vertex"); + return; + } + + if (m_virtualOutput == s_invalidVertex) + { + ReportError(DryadError_VertexInitialization, + "No virtual output vertex"); + return; + } + + LogAssert(NoError()); +} + +void DryadSubGraphVertex::ReadEdgeInfo(DryadMetaData* edgeData, + UInt32 edgeIndex) +{ + EdgeInfo* edge = &(m_edge[edgeIndex]); + + LogAssert(NoError()); + + UInt32 sourceVertex; + DrError err = edgeData->LookUpUInt32(Prop_Dryad_SourceVertex, + &sourceVertex); + if (err != DrError_OK) + { + ReportError(DryadError_VertexInitialization, "No SourceVertex Tag"); + return; + } + + if (sourceVertex >= m_numberOfVertices) + { + ReportError(DryadError_VertexInitialization, + "Bad Source Vertex Index"); + return; + } + + UInt32 sourcePort; + err = edgeData->LookUpUInt32(Prop_Dryad_SourcePort, + &sourcePort); + if (err != DrError_OK) + { + ReportError(DryadError_VertexInitialization, "No SourcePort Tag"); + return; + } + + if (sourcePort >= m_vertex[sourceVertex].GetOutputPortCount()) + { + ReportError(DryadError_VertexInitialization, "Bad Source Port Index"); + return; + } + + if (sourceVertex == m_virtualInput) + { + LogAssert(sourcePort < m_numberOfInputEdges); + if (m_inputEdge[sourcePort] != NULL) + { + ReportError(DryadError_VertexInitialization, + "Duplicate InputEdge Port"); + return; + } + m_inputEdge[sourcePort] = edge; + } + + UInt32 destinationVertex; + err = edgeData->LookUpUInt32(Prop_Dryad_DestinationVertex, + &destinationVertex); + if (err != DrError_OK) + { + ReportError(DryadError_VertexInitialization, + "No 
DestinationVertex Tag"); + return; + } + + if (destinationVertex >= m_numberOfVertices) + { + ReportError(DryadError_VertexInitialization, + "Bad Destination Vertex Index"); + return; + } + + UInt32 destinationPort; + err = edgeData->LookUpUInt32(Prop_Dryad_DestinationPort, + &destinationPort); + if (err != DrError_OK) + { + ReportError(DryadError_VertexInitialization, + "No DestinationPort Tag"); + return; + } + + if (destinationPort >= m_vertex[destinationVertex].GetInputPortCount()) + { + ReportError(DryadError_VertexInitialization, + "Bad Destination Port Index"); + return; + } + + if (destinationVertex == m_virtualOutput) + { + LogAssert(destinationPort < m_numberOfOutputEdges); + if (m_outputEdge[destinationPort] != NULL) + { + ReportError(DryadError_VertexInitialization, + "Duplicate OutputEdge Port"); + return; + } + m_outputEdge[destinationPort] = edge; + } + + edge->SetInfo(sourceVertex, sourcePort, + destinationVertex, destinationPort); + + LogAssert(NoError()); +} + +void DryadSubGraphVertex::ReadEdges(DryadMetaData* graphData) +{ + LogAssert(NoError()); + + DrError err = graphData->LookUpUInt32(Prop_Dryad_NumberOfEdges, + &m_numberOfEdges); + if (err != DrError_OK) + { + ReportError(DryadError_VertexInitialization, + "No NumberOfEdges Tag"); + return; + } + + m_edge = new EdgeInfo [m_numberOfEdges]; + LogAssert(m_edge != NULL); + + DryadMetaDataRef edgeArray; + err = graphData->LookUpMetaData(DryadTag_EdgeArray, &edgeArray); + if (err != DrError_OK) + { + ReportError(DryadError_VertexInitialization, "No EdgeArray Tag"); + return; + } + + DryadMetaData::TagListIter endArray; + DryadMetaData::TagListIter iter = + edgeArray->LookUpInSequence(NULL, &endArray); + UInt32 edgeIndex = 0; + while (iter != endArray) + { + if (edgeIndex == m_numberOfEdges) + { + ReportError(DryadError_VertexInitialization, + "Too many edges in array"); + return; + } + DryadMTag* tag = *iter; + if (tag->GetTagValue() != DryadTag_EdgeInfo) + { + 
ReportError(DryadError_VertexInitialization, + "Unexpected tag in array"); + return; + } + if (tag->GetType() != DryadPropertyTagType_MetaData) + { + ReportError(DryadError_VertexInitialization, + "Unexpected tag type in array"); + return; + } + /* the tag value and the tag type were both checked above, before this downcast */ DryadMTagMetaData* edgeData = (DryadMTagMetaData *) tag; + ReadEdgeInfo(edgeData->GetMetaData(), edgeIndex); + + if (NoError() == false) + { + return; + } + ++iter; + ++edgeIndex; + } + + /* the array must hold exactly m_numberOfEdges entries; the too-many case is caught inside the loop */ if (edgeIndex != m_numberOfEdges) + { + ReportError(DryadError_VertexInitialization, + "Too few edges in array"); + return; + } + + LogAssert(NoError()); +} + +/* Vertex program initialization entry point. Fetches the metadata of this vertex, finds the DryadTag_GraphDescription entry in it, then reads the vertex array and, only if that left no error, the edge array. Every failure is reported as DryadError_VertexInitialization. The two channel-count parameters are not used in this function. */ void DryadSubGraphVertex::Initialize(UInt32 numberOfInputChannels, + UInt32 numberOfOutputChannels) +{ + DrError err; + + DryadMetaData* metaData = GetMetaData(); + + if (metaData == NULL) + { + ReportError(DryadError_VertexInitialization, + "No MetaData In Start Command"); + return; + } + + DryadMetaDataRef graphData; + err = metaData->LookUpMetaData(DryadTag_GraphDescription, &graphData); + if (err != DrError_OK) + { + ReportError(DryadError_VertexInitialization, + "No GraphDescription Tag"); + return; + } + + ReadVertices(graphData); + + if (NoError() == false) + { + return; + } + + ReadEdges(graphData); +} + +/* Obtains the parser for external input whichInput by delegating to the vertex program that the matching input edge feeds. The edge stored in m_inputEdge[whichInput] gives the destination vertex and port, and MakeInputParser is called on that vertex program with that port; its error code and error metadata are then passed to ReportError on this vertex. NOTE(review): m_inputEdge[whichInput] is dereferenced without a NULL check, which relies on every input port having had an edge registered during initialization - confirm that this is enforced. */ void DryadSubGraphVertex::MakeInputParser(UInt32 whichInput, + RChannelItemParserRef* pParser) +{ + if (m_virtualInput == s_invalidVertex) + { + ReportError(DryadError_VertexInitialization, + "MakeInputParser called before initialization"); + return; + } + + if (whichInput >= m_numberOfInputEdges) + { + ReportError(DryadError_VertexInitialization, + "MakeInputParser called with illegal port"); + return; + } + + EdgeInfo* edge = m_inputEdge[whichInput]; + LogAssert(edge->GetSourceVertex() == m_virtualInput); + LogAssert(edge->GetSourcePort() == whichInput); + + UInt32 inputVertex = edge->GetDestinationVertex(); + LogAssert(inputVertex < m_numberOfVertices); + + VertexInfo* input = &(m_vertex[inputVertex]); + UInt32 inputPort = edge->GetDestinationPort(); +
LogAssert(inputPort < input->GetInputPortCount()); + + DryadVertexProgramBase* subProgram = input->GetVertexProgram(); + + subProgram->MakeInputParser(inputPort, pParser); + + ReportError(subProgram->GetErrorCode(), + subProgram->GetErrorMetaData()); +} + +/* Output-side counterpart of MakeInputParser. Uses m_outputEdge[whichOutput] to find the vertex and port that produce external output whichOutput, asks that vertex program for its marshaler, and passes its error code and error metadata to ReportError on this vertex. NOTE(review): the illegal-port message below reads called will where called with is presumably intended; m_outputEdge[whichOutput] is also dereferenced without a NULL check - confirm. */ void DryadSubGraphVertex:: + MakeOutputMarshaler(UInt32 whichOutput, + RChannelItemMarshalerRef* pMarshaler) +{ + if (m_virtualOutput == s_invalidVertex) + { + ReportError(DryadError_VertexInitialization, + "MakeOutputMarshaler called before initialization"); + return; + } + + if (whichOutput >= m_numberOfOutputEdges) + { + ReportError(DryadError_VertexInitialization, + "MakeOutputMarshaler called will illegal port"); + return; + } + + EdgeInfo* edge = m_outputEdge[whichOutput]; + LogAssert(edge->GetDestinationVertex() == m_virtualOutput); + LogAssert(edge->GetDestinationPort() == whichOutput); + + UInt32 outputVertex = edge->GetSourceVertex(); + LogAssert(outputVertex < m_numberOfVertices); + + VertexInfo* output = &(m_vertex[outputVertex]); + UInt32 outputPort = edge->GetSourcePort(); + LogAssert(outputPort < output->GetOutputPortCount()); + + DryadVertexProgramBase* subProgram = output->GetVertexProgram(); + + subProgram->MakeOutputMarshaler(outputPort, pMarshaler); + + ReportError(subProgram->GetErrorCode(), + subProgram->GetErrorMetaData()); +} + +/* Wires each edge to its two endpoint vertices. In the visible part of the loop body an edge is registered on its destination vertex as an input edge and on its source vertex as an output edge; an edge from the virtual input vertex hands the external reader inputChannel[srcPort] to the destination vertex, an edge into the virtual output vertex hands the external writer outputChannel[dstPort] to the source vertex, and any other edge gets an internal fifo. NOTE(review): the loop header below is truncated in this copy of the patch, so the iteration bounds and the conditions under which false is returned cannot be confirmed from here. */ bool DryadSubGraphVertex::SetUpChannels(WorkQueue* sharedWorkQueue, + RChannelReader** inputChannel, + RChannelWriter** outputChannel) +{ + UInt32 i; + for (i=0; iGetSourceVertex(); + UInt32 dstVertex = edge->GetDestinationVertex(); + UInt32 srcPort = edge->GetSourcePort(); + UInt32 dstPort = edge->GetDestinationPort(); + + m_vertex[dstVertex].SetInputEdge(dstPort, edge); + m_vertex[srcVertex].SetOutputEdge(srcPort, edge); + + if (srcVertex == m_virtualInput) + { + LogAssert(dstVertex != m_virtualOutput); + m_vertex[dstVertex].SetInputChannel(dstPort, true, + inputChannel[srcPort]); + } + else if (dstVertex == m_virtualOutput) + { +
m_vertex[srcVertex].SetOutputChannel(srcPort, true, + outputChannel[dstPort]); + } + else + { + /* internal edge: the fifo uses the work queue of the destination vertex if it has one, otherwise the shared queue */ WorkQueue* workQueue = m_vertex[dstVertex].GetWorkQueue(); + if (workQueue == NULL) + { + workQueue = sharedWorkQueue; + } + edge->MakeFifo(m_internalFifoLength, workQueue); + m_vertex[dstVertex].SetInputChannel(dstPort, false, + edge->GetReader()); + m_vertex[srcVertex].SetOutputChannel(srcPort, false, + edge->GetWriter()); + } + } + + /* NOTE(review): text is missing from this copy of the patch between the loop header below and the assertion that follows it; the remaining statements decrement m_outstandingVertices and notify the handler once it reaches zero */ for (i=0; i 0); + --m_outstandingVertices; + + if (m_outstandingVertices == 0) + { + finished = true; + } + } + + if (finished) + { + /* the real status will be returned in AsyncPostCompletion + after we have exited all handlers. */ + m_handler->ProgramCompleted(); + } +} + +/* Asynchronous entry point. Stores the completion handler, then checks that initialization found both virtual vertices and that the channel counts passed in equal the output port count of the virtual input vertex and the input port count of the virtual output vertex. Any failed check, or a false return from SetUpChannels, reports DryadError_VertexInitialization and calls ProgramCompleted on the handler before returning; otherwise LaunchSubGraph is called with the work queue. */ void DryadSubGraphVertex:: + MainAsync(WorkQueue* workQueue, + UInt32 numberOfInputChannels, + RChannelReader** inputChannel, + UInt32 numberOfOutputChannels, + RChannelWriter** outputChannel, + DryadVertexProgramCompletionHandler* handler) +{ + m_handler = handler; + + if (m_virtualInput == s_invalidVertex || + m_virtualOutput == s_invalidVertex) + { + ReportError(DryadError_VertexInitialization, + "MainAsync called before initialization"); + m_handler->ProgramCompleted(); + return; + } + + VertexInfo* virtualInput = &(m_vertex[m_virtualInput]); + if (virtualInput->GetOutputPortCount() != numberOfInputChannels) + { + ReportError(DryadError_VertexInitialization, + "Wrong number of input channels for subgraph"); + m_handler->ProgramCompleted(); + return; + } + + VertexInfo* virtualOutput = &(m_vertex[m_virtualOutput]); + if (virtualOutput->GetInputPortCount() != numberOfOutputChannels) + { + ReportError(DryadError_VertexInitialization, + "Wrong number of output channels for subgraph"); + m_handler->ProgramCompleted(); + return; + } + + if (SetUpChannels(workQueue, inputChannel, outputChannel) == false) + { + ReportError(DryadError_VertexInitialization, + "Missing edge or bad virtual vertices in subgraph"); + /* the real status will be returned in AsyncPostCompletion +
after we have exited all handlers. */ + m_handler->ProgramCompleted(); + return; + } + + LaunchSubGraph(workQueue); +} + +/* Shuts down the internal fifos and, if no error code has been recorded on this vertex, records DryadError_VertexCompleted as its status. */ void DryadSubGraphVertex::AsyncPostCompletion() +{ + ShutDownFifos(); + + if (GetErrorCode() == DrError_OK) + { + ReportError(DryadError_VertexCompleted); + } +} + +/* File-level factory object constructed with the name Dryad.Core.Subgraph and exposed through g_subgraphFactory below. NOTE(review): the template argument of StdTypedVertexFactory appears to have been lost in this copy of the patch. */ static StdTypedVertexFactory + s_subgraphFactory("Dryad.Core.Subgraph"); + +DryadVertexFactoryBase* g_subgraphFactory = &s_subgraphFactory; diff --git a/DryadVertex/VertexHost/system/dprocess/src/subgraphvertex.h b/DryadVertex/VertexHost/system/dprocess/src/subgraphvertex.h new file mode 100644 index 0000000..181a0c2 --- /dev/null +++ b/DryadVertex/VertexHost/system/dprocess/src/subgraphvertex.h @@ -0,0 +1,202 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License.
+ +*/ + +#pragma once + +#include +#include +#include + +class RChannelFifo; +class WorkQueue; + +/* Vertex program that hosts a subgraph of other vertex programs. It holds an array of VertexInfo, an array of EdgeInfo, and two vertex indices (m_virtualInput, m_virtualOutput) naming the virtual vertices whose ports correspond to the external input and output channels of the subgraph. */ class DryadSubGraphVertex : public DryadVertexProgramBase +{ +public: + DryadSubGraphVertex(); + virtual ~DryadSubGraphVertex(); + + void Usage(FILE* f); + + void SetWorkQueueThreads(UInt32 workQueueThreads); + void SetWorkQueueConcurrentThreads(UInt32 workQueueConcurrentThreads); + void SetInternalFifoLength(UInt32 fifoLength); + + /* DryadVertexProgramBase interface */ + void Initialize(UInt32 numberOfInputChannels, + UInt32 numberOfOutputChannels); + void MainAsync(WorkQueue* workQueue, + UInt32 numberOfInputChannels, + RChannelReader** inputChannel, + UInt32 numberOfOutputChannels, + RChannelWriter** outputChannel, + DryadVertexProgramCompletionHandler* handler); + void MakeInputParser(UInt32 whichInput, + RChannelItemParserRef* pParser); + void MakeOutputMarshaler(UInt32 whichOutput, + RChannelItemMarshalerRef* pMarshaler); + void AsyncPostCompletion(); + +private: + /* One edge of the subgraph: source vertex and port, destination vertex and port, plus reader and writer holders that are presumably populated by MakeFifo - confirm against the implementation. */ class EdgeInfo + { + public: + EdgeInfo(); + ~EdgeInfo(); + + void SetInfo(UInt32 sourceVertex, UInt32 sourcePort, + UInt32 destinationVertex, UInt32 destinationPort); + UInt32 GetSourceVertex(); + UInt32 GetSourcePort(); + UInt32 GetDestinationVertex(); + UInt32 GetDestinationPort(); + + void MakeFifo(UInt32 fifoLength, WorkQueue* workQueue); + void DiscardFifo(); + RChannelReader* GetReader(); + RChannelWriter* GetWriter(); + + private: + UInt32 m_sourceVertex; + UInt32 m_sourcePort; + UInt32 m_destinationVertex; + UInt32 m_destinationPort; + RChannelReaderHolderRef m_reader; + RChannelWriterHolderRef m_writer; + }; + + /* Per-vertex state of the subgraph: the vertex program, its id and version, its input and output ports with their edges and channels, its arguments, metadata and serialized block. Derives from the completion-handler interface and declares ProgramCompleted. */ class VertexInfo : public DryadVertexProgramCompletionHandler + { + public: + VertexInfo(); + ~VertexInfo(); + + DrError MakeProgram(UInt32 numberOfWorkQueueThreads, + UInt32 concurrentWorkQueueThreads, + DryadMetaDataRef* pErrorData); + DryadVertexProgramBase* GetVertexProgram(); + WorkQueue* GetWorkQueue(); + + void SetId(UInt32 id); + UInt32 GetId(); + + void SetVersion(UInt32
id); + UInt32 GetVersion(); + + void SetInputPortCount(UInt32 inputPortCount); + UInt32 GetInputPortCount(); + bool SetInputChannel(UInt32 inputPort, bool isExternal, + RChannelReader* reader); + void SetInputEdge(UInt32 inputPort, EdgeInfo* edge); + + void SetOutputPortCount(UInt32 outputPortCount); + UInt32 GetOutputPortCount(); + bool SetOutputChannel(UInt32 outputPort, bool isExternal, + RChannelWriter* writer); + void SetOutputEdge(UInt32 outputPort, + EdgeInfo* edge); + + void SetArgumentCount(UInt32 outputPortCount); + UInt32 GetArgumentCount(); + void SetArgumentValue(UInt32 argument, const char* value); + const char* GetArgumentValue(UInt32 argument); + + void* GetRawSerializedBlock(); + UInt32 GetRawSerializedBlockLength(); + void SetRawSerializedBlock(UInt32 length, + const void* data); + + void SetMetaData(DryadMetaData* metaData); + DryadMetaData* GetMetaData(); + + void SetCanShareWorkQueue(bool canShareWorkQueue); + bool GetCanShareWorkQueue(); + + void SetVirtual(); + + bool Verify(); + + void Run(DryadSubGraphVertex* parent, + WorkQueue* sharedWorkQueue); + + /* DryadVertexProgramCompletionHandler interface */ + void ProgramCompleted(); + + void NotifyChannelsOfCompletion(); + void DrainChannels(); + + private: + DryadSubGraphVertex* m_parent; + bool m_virtual; + DryadVertexProgramRef m_vertexProgram; + WorkQueue* m_workQueue; + UInt32 m_id; + UInt32 m_version; + UInt32 m_inputPortCount; + EdgeInfo** m_inputEdge; + bool* m_inputExternal; + RChannelReader** m_inputChannel; + UInt32 m_outputPortCount; + EdgeInfo** m_outputEdge; + bool* m_outputExternal; + RChannelWriter** m_outputChannel; + UInt32 m_argumentCount; + DrStr64* m_argument; + DryadMetaDataRef m_metaData; + UInt32 m_serializedBlockLength; + char* m_serializedBlock; + bool m_canShareWorkQueue; + }; + + void ReadVertexInfo(DryadMetaData* vertexData, UInt32 vertexIndex); + void ReadVertices(DryadMetaData* graphData); + void ReadEdgeInfo(DryadMetaData* edgeData, UInt32 edgeIndex); + void 
ReadEdges(DryadMetaData* graphData); + bool SetUpChannels(WorkQueue* sharedWorkQueue, + RChannelReader** inputChannel, + RChannelWriter** outputChannel); + void LaunchSubGraph(WorkQueue* sharedWorkQueue); + void VertexCompleted(DrError status, + DryadMetaData* errorData); + void ShutDownFifos(); + + + UInt32 m_internalWorkQueueThreads; + UInt32 m_internalWorkQueueConcurrentThreads; + UInt32 m_internalFifoLength; + + UInt32 m_virtualInput; + UInt32 m_virtualOutput; + + UInt32 m_numberOfVertices; + VertexInfo* m_vertex; + UInt32 m_numberOfEdges; + EdgeInfo* m_edge; + UInt32 m_numberOfInputEdges; + EdgeInfo** m_inputEdge; + UInt32 m_numberOfOutputEdges; + EdgeInfo** m_outputEdge; + + DryadVertexProgramCompletionHandler* m_handler; + UInt32 m_outstandingVertices; + + CRITSEC m_baseDR; + DRREFCOUNTIMPL_BASE(DryadVertexProgramBase) +}; diff --git a/DryadVertex/VertexHost/system/dprocess/src/vertexfactory.cpp b/DryadVertex/VertexHost/system/dprocess/src/vertexfactory.cpp new file mode 100644 index 0000000..2cdb3c5 --- /dev/null +++ b/DryadVertex/VertexHost/system/dprocess/src/vertexfactory.cpp @@ -0,0 +1,403 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +// +// Includes +// +#include +#include +#include + +// +// Create a vertex factory and register it +// +DryadVertexFactoryBase::DryadVertexFactoryBase(const char* name) +{ + m_name = name; + VertexFactoryRegistry::RegisterFactory(this); +} + +// +// Destructor does nothing +// +DryadVertexFactoryBase::~DryadVertexFactoryBase() +{ +} + +const char* DryadVertexFactoryBase::GetName() +{ + return m_name; +} + +// +// The Register method does nothing, but can be used to pull in static +// factories from other compilation units. +// +void DryadVertexFactoryBase::Register() +{ +} + +// +// Create reference to new program +// +DryadVertexProgramRef DryadVertexFactoryBase::MakeUntyped() +{ + DryadVertexProgramRef p; + p.Attach(NewUntyped()); + return p; +} + +// +// Have factory create new program base +// +DryadVertexProgramBase* DryadVertexFactoryBase::NewUntyped() +{ + return NewUntypedInternal(); +} + +/* this is guaranteed to be set to zero before anyone's constructor is + called */ +static VertexFactoryRegistry* s_dryadVertexRegistry; + +// +// Create new empty factory registry +// +VertexFactoryRegistry::VertexFactoryRegistry() +{ + m_registeredNULL = false; +} + +// +// Register a factory +// +void VertexFactoryRegistry::RegisterFactory(DryadVertexFactoryBase* factory) +{ + // + // If there is no registry, create one + // + if (s_dryadVertexRegistry == 0) + { + s_dryadVertexRegistry = new VertexFactoryRegistry(); + } + + VertexFactoryRegistry* self = s_dryadVertexRegistry; + + const char* name = factory->GetName(); + if (name == NULL) + { + // + // If no factory name, this is an error. 
+ // RegisterFactory gets called during static initializations, + // so just save the errors and log properly in LookUpFactory + // or ShowAllVertexUsageMessages + // + self->m_registeredNULL = true; + } + else + { + FactoryMap::iterator existing = self->m_factories.find(name); + if (existing != self->m_factories.end()) + { + // + // If factory name already in registry, this is an error + // RegisterFactory gets called during static + // initializations, so just save the errors and log + // properly in LookUpFactory or ShowAllVertexUsageMessages + // + self->m_errorSet.insert(name); + } + else + { + // + // If factory name doesn't exist, add it to list of registered factories (key is name). + // + self->m_factories.insert(std::make_pair(name, factory)); + } + } +} + +// +// Get Vertex factory from name. Registration errors saved by RegisterFactory are logged first. +// +DryadVertexFactoryBase* VertexFactoryRegistry::LookUpFactory(const char* name) +{ + // + // Create a vertex registry if none exist + // + if (s_dryadVertexRegistry == 0) + { + s_dryadVertexRegistry = new VertexFactoryRegistry(); + } + + VertexFactoryRegistry* self = s_dryadVertexRegistry; + + // + // If NULL name was provided to the registry on initialization, log the error + // + if (self->m_registeredNULL) + { + DrLogE("Factory Registered With illegal NULL name"); + } + + // + // If there are any elements in m_errorSet, multiple factories of the same name + // were attempted to be registered. Log this error. + // + if (self->m_errorSet.empty() == false) + { + DuplicateSet::iterator i; + for (i = self->m_errorSet.begin(); i != self->m_errorSet.end(); ++i) + { + DrLogE("Duplicate Factory Registered.
Factory name: %s", i->c_str()); + } + } + + // + // If any errors, report that they exist at the assert level + // + if (self->m_registeredNULL || self->m_errorSet.empty() == false) + { + DrLogA("Factory Registration Errors"); + } + + // + // Find a factory with the provided name + // + FactoryMap::iterator factory = self->m_factories.find(name); + if (factory != self->m_factories.end()) + { + // + // If found factory, return reference to it + // + return factory->second; + } + else + { + // + // If no factory found, return null + // + return NULL; + } +} + +// +// Check factory registry and print out any immediate errors and usage instructions +// +void VertexFactoryRegistry::ShowAllVertexUsageMessages(FILE* f) +{ + // + // If factory registry not yet initialized, report failure and exit. + // todo: ensure we want to fprintf here and not write to log + // + if (s_dryadVertexRegistry == 0) + { + fprintf(f, "Factory Registry usage called before initialization\n\n"); + return; + } + + VertexFactoryRegistry* self = s_dryadVertexRegistry; + + // + // If NULL name was provided to the registry on initialization, log the error + // + if (self->m_registeredNULL) + { + DrLogE("Factory Registered With illegal NULL name"); + } + + // + // If there are any elements in m_errorSet, multiple factories of the same + // were attempted to be registered. Log this error. + // + if (self->m_errorSet.empty() == false) + { + DuplicateSet::iterator i; + for (i = self->m_errorSet.begin(); i != self->m_errorSet.end(); ++i) + { + DrLogE("Duplicate Factory Registered. 
Factory name: %s", i->c_str()); + } + } + + // + // If any errors, report that they exist at the assert level + // + if (self->m_registeredNULL || self->m_errorSet.empty() == false) + { + DrLogA("Factory Registration Errors"); + } + + // + // Foreach factory, make sure the argument count is 1 and print out any + // usage information + // + FactoryMap::iterator factory; + for (factory = self->m_factories.begin(); + factory != self->m_factories.end(); ++factory) + { + DryadVertexProgramRef program = factory->second->MakeUntyped(); + LogAssert(program->GetArgumentCount() == 1); + program->Usage(f); + } +} + +// +// Get factory and have it create a vertex program +// +DrError VertexFactoryRegistry::MakeVertex(UInt32 vertexId, + UInt32 vertexVersion, + UInt32 numberOfInputChannels, + UInt32 numberOfOutputChannels, + UInt64* expectedLength, + DryadVertexFactoryBase* factory, + DryadMetaData* metaData, + UInt32 maxInputChannels, + UInt32 maxOutputChannels, + UInt32 argumentCount, + DrStr64* argumentList, + UInt32 serializedBlockLength, + const void* serializedBlock, + DryadMetaDataRef* pErrorData, + DryadVertexProgramRef* pProgram) +{ + DryadVertexProgramBase* program; + + // + // If factory is not supplied, try to get it from factory registry. + // If still unable, fail with error. + // + if (factory == NULL) + { + // + // If no arguments, return vertex initialization error + // + if (argumentCount == 0) + { + DryadMetaData::Create(pErrorData); + (*pErrorData)->AddErrorWithDescription(DryadError_VertexInitialization, + "Factory Registry called with no arguments"); + return DryadError_VertexInitialization; + } + + // + // Get vertex factory. 
If one cannot be found, report initialization error + // + factory = LookUpFactory(argumentList[0]); + if (factory == NULL) + { + DrStr128 errorString; + errorString.SetF("Factory Registry called with unknown factory UID %s", + argumentList[0].GetString()); + DryadMetaData::Create(pErrorData); + (*pErrorData)->AddErrorWithDescription(DryadError_VertexInitialization, + errorString); + return DryadError_VertexInitialization; + } + + // + // report new vertex creation + // + DrLogI( "Factory making new vertex. Vertex %s with %u arguments %u inputs %u outputs", + argumentList[0].GetString(), argumentCount, + numberOfInputChannels, numberOfOutputChannels); + + // + // Make a program and ensure that program argument is first argument in list + // todo: figure out how first argument in program is specified/what it is + // + *pProgram = factory->MakeUntyped(); + program = *pProgram; + LogAssert(program->GetArgumentCount() == 1); + LogAssert(::strcmp(program->GetArgument(0), argumentList[0]) == 0); + + // + // adjust the argument count and list to throw away the first + // argument naming which vertex was to be created + // + ++argumentList; + --argumentCount; + } + else + { + // + // If a factory was supplied by the caller, then we have been told which factory to use and are + // not going to read it from the argument list + // + *pProgram = factory->MakeUntyped(); + program = *pProgram; + LogAssert(program->GetArgumentCount() == 1); + } + + // + // Update program with vertex id and version + // + program->SetVertexId(vertexId); + program->SetVertexVersion(vertexVersion); + + // + // Let program know how long the input channels are + // + if (expectedLength != NULL) + { + program->SetExpectedInputLength(numberOfInputChannels, expectedLength); + } + + // + // Add remaining arguments to arg list + // + UInt32 i; + for (i=0; iAddArgument(argumentList[i]); + + DrLogI( "Factory adding vertex argument.
Vertex %s arguments %u=%s", + program->GetArgument(0), i, argumentList[i].GetString()); + } + + // + // Set max number of channels and other metadata + // + program->SetMaxOpenInputChannelCount(maxInputChannels); + program->SetMaxOpenOutputChannelCount(maxOutputChannels); + program->SetMetaData(metaData); + + // + // Copy data into program + // + DrRef buffer; + buffer.Attach(new DrSimpleHeapBuffer()); + buffer->Append(serializedBlock, serializedBlockLength); + DrMemoryBufferReader reader(buffer); + program->DeSerialize(&reader); + + if (program->GetErrorCode() == DrError_OK) + { + // + // If still ok, initialize the program with the expected number of input and output channels + // todo: this doesn't do anything in base type. Figure out if base type or derived type. + // + program->Initialize(numberOfInputChannels, numberOfOutputChannels); + } + + // + // Report any errors + // + *pErrorData = program->GetErrorMetaData(); + return program->GetErrorCode(); +} diff --git a/DryadVertex/VertexHost/vertex/WrapperNativeInfo/FifoChannel.cpp b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/FifoChannel.cpp new file mode 100644 index 0000000..ecf3c98 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/FifoChannel.cpp @@ -0,0 +1,255 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include "stdafx.h" + +#include +#include +#include +#include + +#pragma unmanaged + + +/* Write-completion handler used when a termination item is written straight to a fifo writer; both of its members are currently empty. */ FifoChannelItemWriterHandler::~FifoChannelItemWriterHandler() +{ +// nothing needed for now. +} + +void FifoChannelItemWriterHandler::ProcessWriteCompleted(RChannelItemType status, + RChannelItem* marshalFailureItem) +{ +// NYI +} + + + +/* Builds the pump between reader and writer. transType selects the transform object: the null transform always, and the gzip or deflate compression and decompression transforms only when LINKWITHZLIB is defined; any other value logs an error and asserts. Also creates the shutdown event (created as auto-reset and initially unsignalled) and starts in state RS_Stopped with no items in flight. */ FifoChannel::FifoChannel(RChannelReader* reader, + RChannelWriter *writer, + DryadVertexProgram* vertex, + TransformType transType) +{ + m_initialHandlerSent = false; + m_vertex = vertex; + switch (transType) + { + case TT_NullTransform: + m_transform = new NullChannelTransform(vertex); + break; +#ifdef LINKWITHZLIB + case TT_GzipCompression: + m_transform = new GzipCompressionChannelTransform(vertex, true, false); + break; + case TT_GzipFastCompression: + m_transform = new GzipCompressionChannelTransform(vertex, true, true); + break; + case TT_GzipDecompression: + m_transform = new GzipDecompressionChannelTransform(vertex, true); + break; + case TT_DeflateCompression: + m_transform = new GzipCompressionChannelTransform(vertex, false, false); + break; + case TT_DeflateFastCompression: + m_transform = new GzipCompressionChannelTransform(vertex, false, true); + break; + case TT_DeflateDecompression: + m_transform = new GzipDecompressionChannelTransform(vertex, false); + break; +#endif + default: + DrLogE("Invalid compressionScheme."); + LogAssert(false); + } + m_reader = reader; + m_writer = writer; + m_shutdownEvent = CreateEvent(NULL, false, false, NULL); + LogAssert(m_shutdownEvent != NULL); + m_numItemsInFlight = 0; + m_state = RS_Stopped; + m_fifoWriterHandler = NULL; +} + +/* Requires the channel to be in state RS_Stopped (asserted). Frees the transform, closes the shutdown event, and deletes the termination write handler if one was allocated by Stop. */ FifoChannel::~FifoChannel() +{ + LogAssert(m_state == RS_Stopped); + + delete m_transform; + m_transform = NULL; + + BOOL bRet = ::CloseHandle(m_shutdownEvent); + LogAssert(bRet != 0); + + if (m_fifoWriterHandler) + { + delete m_fifoWriterHandler; + m_fifoWriterHandler = NULL; + } +} + +/* First half of shutdown. If the state is RS_OutstandingHandler an end-of-stream marker is written: through WriteTransformedItem when fifoWriter is the writer of this channel, otherwise directly to fifoWriter with a newly allocated FifoChannelItemWriterHandler. Returns true when the state was RS_OutstandingHandler or RS_Stopping; the caller passes that value to Drain, which then waits on the shutdown event. */ bool FifoChannel::Stop(RChannelWriter* fifoWriter) +{ + DrLogI( "Stopping
fifochannel"); + bool mustWait = false; + bool mustTerminate = false; + + { + AutoCriticalSection acs(&m_critsec); + + if (m_state == RS_OutstandingHandler) + { + mustTerminate = true; + mustWait = true; + } + + if (m_state == RS_Stopping) + { + mustWait = true; + } + } + + if (mustTerminate) + { + DrLogI( "Sending fifochannel termination"); + RChannelItemRef termination; + termination.Attach(RChannelMarkerItem::Create(RChannelItem_EndOfStream, + false)); + if (fifoWriter == m_writer) + { + this->WriteTransformedItem(termination.Ptr()); + } + else + { + m_fifoWriterHandler = new FifoChannelItemWriterHandler(); + fifoWriter->WriteItem(termination, false, m_fifoWriterHandler); + } + } + return mustWait; +} + +void FifoChannel::Drain(bool mustWait) +{ + DrLogI( "Waiting for fifochannel"); + + if (mustWait) + { + + WaitForSingleObject(m_shutdownEvent, INFINITE); + } + + DrLogI( "Waiting for fifochannel done"); + { + AutoCriticalSection acs(&m_critsec); + + m_state = RS_Stopped; + } +} + +void FifoChannel::ProcessItem(RChannelItem* deliveredItem) +{ + { + AutoCriticalSection acs(&m_critsec); + m_numItemsInFlight++; + } + + DrLogI( "In ProcessItem"); + + RChannelItemType itemType = deliveredItem->GetType(); + if (RChannelItem::IsTerminationItem(itemType)) + { + DrLogI( "Got termination item"); + m_transform->Finish(itemType == RChannelItem_EndOfStream); + WriteTransformedItem(deliveredItem); + { + AutoCriticalSection acs(&m_critsec); + m_state = RS_Stopping; + } + } + else + { + if (itemType == RChannelItem_Data) + { + DataBlockItem* itemPtr = + dynamic_cast(deliveredItem); + LogAssert(itemPtr != NULL); + m_transform->ProcessItem(itemPtr); + } + } + + { + AutoCriticalSection acs(&m_critsec); + LogAssert(m_numItemsInFlight > 0); + m_numItemsInFlight--; + MaybeSendHandler(); + } +} + +void FifoChannel::ProcessWriteCompleted(RChannelItemType status, + RChannelItem* marshalFailureItem) +{ + + { + AutoCriticalSection acs(&m_critsec); + DrLogD( "In Process Write Complete. 
Status: %d, m_numItemsInFlight: %d", status, + m_numItemsInFlight); + + LogAssert(m_numItemsInFlight > 0); + m_numItemsInFlight--; + + MaybeSendHandler(); + } +} + +void FifoChannel::MaybeSendHandler() +{ + { + AutoCriticalSection acs(&m_critsec); + if (m_numItemsInFlight == 0) + { + if (m_state == RS_Stopping) + { + SetEvent(m_shutdownEvent); + } + else + { + m_reader->SupplyHandler(this, NULL); + } + } + } +} + +void FifoChannel::Start() +{ + DrLogI( "Starting fifochannel"); + LogAssert(!m_initialHandlerSent); + m_initialHandlerSent = true; + m_transform->Start(this); + m_state = RS_OutstandingHandler; + MaybeSendHandler(); +} + + +void FifoChannel::WriteTransformedItem(RChannelItem* transformedItem) +{ + { + AutoCriticalSection acs(&m_critsec); + m_numItemsInFlight++; + m_writer->WriteItem(transformedItem, false, this); + } +} diff --git a/DryadVertex/VertexHost/vertex/WrapperNativeInfo/FifoInputChannel.cpp b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/FifoInputChannel.cpp new file mode 100644 index 0000000..2e80957 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/FifoInputChannel.cpp @@ -0,0 +1,143 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include "stdafx.h" + +#include + +#pragma unmanaged + +/* Wraps the original reader behind an internal fifo of length 3. A FifoChannel is created to move items from the original reader, through the transform selected by tType, into the fifo writer, and m_reader is set to the fifo reader. The FifoChannel is not started here; ReadDataBlock starts it on first use. */ FifoInputChannel::FifoInputChannel(UInt32 portNum, + DryadVertexProgram* vertex, WorkQueue *workQueue, + RChannelReader* channel, TransformType tType) : + InputChannel(portNum, vertex, channel) +{ + m_initialHandlerSent = false; + + m_fifoReader = NULL; + m_fifoWriter = NULL; + m_origReader = channel; + + MakeFifo(3, workQueue); + m_fifoChannel = new FifoChannel(m_origReader, m_fifoWriter->GetWriter(), vertex, tType); + m_reader = m_fifoReader->GetReader(); +} + +/* Shuts the wrapper down. If the FifoChannel was started it is asked to stop; otherwise an end-of-stream marker is written to the fifo writer directly. The fifo reader and writer are then drained, the FifoChannel is drained (only if it was started) and deleted, and the fifo references are cleared. NOTE(review): the FifoChannelItemWriterHandler allocated in the not-started branch is never deleted in this function - possible leak; confirm who owns it. */ void FifoInputChannel::Stop() +{ + bool mustWait = false; + + if (m_initialHandlerSent) { + /* this makes sure no more items will be sent down the fifo */ + mustWait = m_fifoChannel->Stop(m_fifoWriter->GetWriter()); + } + else + { + // we never started the FifoChannel, so we must send a term item + RChannelItemRef termination; + termination.Attach(RChannelMarkerItem::Create(RChannelItem_EndOfStream, + false)); +/* +// by construction, we've never written any items to the FIFO, so a + //WriteItemSync call should not block + RChannelItemType result = m_fifoWriter->GetWriter()->WriteItemSync( + termination.Ptr(), false, NULL); +*/ + FifoChannelItemWriterHandler *fifoWriterHandler = new FifoChannelItemWriterHandler(); + m_fifoWriter->GetWriter()->WriteItem(termination, false, fifoWriterHandler); + //LogAssert(result == RChannelItem_EndOfStream); + } + + m_fifoReader->GetReader()->Drain(); + + RChannelItemRef writeTermination; + m_fifoWriter->GetWriter()->Drain(DrTimeInterval_Zero, &writeTermination); + if (writeTermination != NULL) + { + if (writeTermination->GetType() != RChannelItem_EndOfStream) + { + // NYI: do something + } + } + + // as above, we've never called Start on the FifoChannel, + //so no need to call Drain + if (m_initialHandlerSent) { + m_fifoChannel->Drain(mustWait); + } + delete m_fifoChannel; + m_fifoChannel = NULL; + + m_fifoReader = NULL; + m_fifoWriter = NULL; +} + +/* Forwards to the original reader, not the internal fifo. */ bool FifoInputChannel::GetTotalLength(UInt64 *length) +{ +
return m_origReader->GetTotalLength(length); +} + +/* Forwards to the original reader, not the internal fifo. */ bool FifoInputChannel::GetExpectedLength(UInt64 *length) +{ + return m_origReader->GetExpectedLength(length); +} + +/* Opens the reader end and the writer end of a fifo whose name is built from fifoLength and a unique id, asserts that neither open reported an error, and starts both ends. Both fifo members must still be NULL on entry (asserted). */ void FifoInputChannel::MakeFifo(UInt32 fifoLength, WorkQueue* workQueue) +{ +// based on DryadSubGraphVertex::EdgeInfo::MakeFifo in subgraphvertex.cpp + LogAssert(m_fifoReader == NULL); + LogAssert(m_fifoWriter == NULL); + + UInt32 uniquifier = RChannelFactory::GetUniqueFifoId(); + + DrStr64 fifoName; + fifoName.SetF("fifo://%u/compressedchannel-%u", + fifoLength, uniquifier); + + DVErrorReporter errorReporter; + RChannelFactory::OpenReader(fifoName, NULL, NULL, 1, NULL, 0, 0, workQueue, + &errorReporter, &m_fifoReader, NULL); + LogAssert(errorReporter.NoError()); + RChannelFactory::OpenWriter(fifoName, NULL, NULL, 1, NULL, 0, NULL, + &errorReporter, &m_fifoWriter); + LogAssert(errorReporter.NoError()); + + m_fifoReader->GetReader()->Start(NULL); + m_fifoWriter->GetWriter()->Start(); +} + +/* Starts the FifoChannel on the first call only, then defers to the base-class ReadDataBlock. */ DataBlockItem* FifoInputChannel::ReadDataBlock(byte **ppDataBlock, + Int32 *ppDataBlockSize, + Int32 *pErrorCode) +{ + if (!m_initialHandlerSent) + { + m_initialHandlerSent = true; + m_fifoChannel->Start(); + } + return InputChannel::ReadDataBlock(ppDataBlock, ppDataBlockSize, + pErrorCode); +} + +/* Returns the URI of the original reader, not of the internal fifo. */ const char* FifoInputChannel::GetURI() +{ + return m_origReader->GetURI(); +} + diff --git a/DryadVertex/VertexHost/vertex/WrapperNativeInfo/FifoOutputChannel.cpp b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/FifoOutputChannel.cpp new file mode 100644 index 0000000..9f868db --- /dev/null +++ b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/FifoOutputChannel.cpp @@ -0,0 +1,101 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License.
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "stdafx.h" + +#include +#include +#include + + +#pragma unmanaged + + +FifoOutputChannel::FifoOutputChannel(UInt32 portNum, DryadVertexProgram* vertex, + WorkQueue *workQueue, RChannelWriter* outputChannel, TransformType tType): + OutputChannel(portNum, vertex, outputChannel) +{ + m_origWriter = outputChannel; + m_fifoReader = NULL; + m_fifoWriter = NULL; + MakeFifo(3, workQueue); + m_fifoChannel = new FifoChannel(m_fifoReader->GetReader(), m_origWriter, vertex, tType); + + m_writer = m_fifoWriter->GetWriter(); + m_fifoChannel->Start(); +} + +void FifoOutputChannel::Stop() +{ + /* this makes sure no more items will be sent down the fifo */ + bool mustWait = m_fifoChannel->Stop(m_fifoWriter->GetWriter()); + m_fifoChannel->Drain(mustWait); + delete m_fifoChannel; + m_fifoChannel = NULL; + + m_fifoReader->GetReader()->Drain(); + + RChannelItemRef writeTermination; + m_fifoWriter->GetWriter()->Drain(DrTimeInterval_Zero, &writeTermination); + if (writeTermination != NULL) + { + if (writeTermination->GetType() != RChannelItem_EndOfStream) + { + // do something + } + } + + m_fifoReader = NULL; + m_fifoWriter = NULL; +} + +void FifoOutputChannel::MakeFifo(UInt32 fifoLength, WorkQueue* workQueue) +{ +// based on DryadSubGraphVertex::EdgeInfo::MakeFifo in subgraphvertex.cpp + LogAssert(m_fifoReader == NULL); + LogAssert(m_fifoWriter == NULL); + + UInt32 uniquifier = RChannelFactory::GetUniqueFifoId(); + + DrStr64 fifoName; + fifoName.SetF("fifo://%u/compressedchannel-%u", + fifoLength, uniquifier); 
+ + DVErrorReporter errorReporter; + RChannelFactory::OpenReader(fifoName, NULL, NULL, 1, NULL, 0, 0, workQueue, + &errorReporter, &m_fifoReader, NULL); + LogAssert(errorReporter.NoError()); + RChannelFactory::OpenWriter(fifoName, NULL, NULL, 1, NULL, 0, NULL, + &errorReporter, &m_fifoWriter); + LogAssert(errorReporter.NoError()); + + m_fifoReader->GetReader()->Start(NULL); + m_fifoWriter->GetWriter()->Start(); +} + +void FifoOutputChannel::SetInitialSizeHint(UInt64 hint) +{ + m_origWriter->SetInitialSizeHint(hint); +} + +const char* FifoOutputChannel::GetURI() +{ + return m_origWriter->GetURI(); +} diff --git a/DryadVertex/VertexHost/vertex/WrapperNativeInfo/GzipCompressionChannelTransform.cpp b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/GzipCompressionChannelTransform.cpp new file mode 100644 index 0000000..bde2d94 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/GzipCompressionChannelTransform.cpp @@ -0,0 +1,296 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include "stdafx.h" + +#include +#include + +#pragma unmanaged + +#ifdef LINKWITHZLIB + +static const UInt32 s_defaultWriteSize = 256 * 1024; + +GzipCompressionChannelTransform::GzipCompressionChannelTransform(DryadVertexProgram* vertex, + bool gzipHeader, + bool optimizeForSpeed) +{ + // zlib structures + memset(&m_stream, 0, sizeof(m_stream)); + m_stream.next_in = NULL; + m_stream.avail_in = 0; + m_stream.next_out = NULL; + m_stream.avail_out = 0; + m_stream.zalloc = Z_NULL; + m_stream.zfree = Z_NULL; + m_stream.opaque = NULL; + + m_crc = crc32(0L, Z_NULL, 0); + m_crcStart = NULL; + m_channel = NULL; + m_writeSize = 0; /* try to infer this from the first item to process */ + m_zlibArg = Z_NO_FLUSH; + m_firstReadProcessed = false; + m_gzipHeader = gzipHeader; + m_optimizeForSpeed = optimizeForSpeed; + m_vertex = vertex; +} + +GzipCompressionChannelTransform::~GzipCompressionChannelTransform() +{ + + // zlib structures + memset(&m_stream, 0, sizeof(m_stream)); + m_stream.next_in = NULL; + m_stream.avail_in = 0; + m_stream.next_out = NULL; + m_stream.avail_out = 0; + m_stream.zalloc = Z_NULL; + m_stream.zfree = Z_NULL; + m_stream.opaque = NULL; + + m_crc = crc32(0L, Z_NULL, 0); + m_crcStart = NULL; + m_channel = NULL; +} + +void GzipCompressionChannelTransform::AllocateOutputBuffer() +{ + LogAssert(m_writeSize > 0); + m_outputBuffer.Attach(new DataBlockItem(m_writeSize)); + m_stream.next_out = (byte *) m_outputBuffer->GetDataAddress(); + LogAssert(m_outputBuffer.Ptr() != NULL); + m_stream.avail_out = (z_uInt) m_writeSize; +} + +inline void GzipCompressionChannelTransform::AppendOutputByte(byte value) +{ + //LogAssert(m_stream.avail_out > 0); + if (m_stream.avail_out == 0) + { + WriteOutputBuffer(); + AllocateOutputBuffer(); + } + m_stream.avail_out--; + *(m_stream.next_out) = value; + m_stream.next_out++; +} + +inline void GzipCompressionChannelTransform::IncrementOutputPosition(UInt32 increment) +{ + LogAssert(m_stream.avail_out >= increment); + 
m_stream.avail_out -= increment; + m_stream.next_out += increment; +} + +void GzipCompressionChannelTransform::SetOutputBufferSize(UInt32 bufferSize) +{ + WriteOutputBuffer(); + m_writeSize = bufferSize; + AllocateOutputBuffer(); +} + +// Process block should be called with m_critsec held +DrError GzipCompressionChannelTransform::ProcessBlock() +{ + int retVal = -1; + bool done = false; + while (!done) + { + // record the starting point for the crc +// DrLogI( "About to deflate", +// "avail_in %u total_in %u avail_out %u total_out %u arg %u", +// m_stream.avail_in, m_stream.total_in, +// m_stream.avail_out, m_stream.total_out, m_zlibArg); + m_crcStart = m_stream.next_in; + retVal = deflate(&m_stream, m_zlibArg); +// DrLogI( "Done deflate", +// "ret %u", retVal); + if (m_gzipHeader) + { + m_crc = crc32(m_crc, m_crcStart, + (UInt32) (m_stream.next_in - m_crcStart)); + } + + UInt32 *intPtr = NULL; + switch (retVal) { + case Z_OK: + // nothing needed here + break; + case Z_STREAM_END: + done = true; + if (m_gzipHeader) + { + if (m_stream.avail_out < 8) + { + WriteOutputBuffer(); + AllocateOutputBuffer(); + } + intPtr = (UInt32 *)m_stream.next_out; + *intPtr = m_crc; + intPtr++; + *intPtr = m_stream.total_in; + IncrementOutputPosition(8); + } + retVal = deflateEnd(&m_stream); + LogAssert(retVal == Z_OK); + WriteOutputBuffer(); + AllocateOutputBuffer(); + break; + default: + char *errorMsg = "NULL"; + if (m_stream.msg != NULL) { + errorMsg = m_stream.msg; + } + DrLogA( "Error in deflate. 
retVal: %d Message: %s", retVal, errorMsg); + } + if (m_stream.avail_out == 0) + { + WriteOutputBuffer(); + AllocateOutputBuffer(); + } + if (m_stream.avail_in == 0 && m_zlibArg != Z_FINISH) + { + done = true; + } + } + return DrError_OK; +} + +DrError GzipCompressionChannelTransform::ProcessItem(DataBlockItem *item) +{ + UInt32 inputSize = (UInt32) item->GetAvailableSize(); + if (inputSize == 0) + { + /* nothing to do here */ + return DrError_OK; + } + + { + AutoCriticalSection acs(&m_critsec); + + if (m_firstReadProcessed == false) + { + if (m_writeSize == 0) + { + /* try to infer the write size from the size of the + uncompressed data coming in to the compresser, + i.e. what the app thought it wanted to write */ + UInt32 itemSize = (UInt32) item->GetAllocatedSize(); + UInt32 inputPages = itemSize / (4 * 1024); + if (inputPages * 4 * 1024 == itemSize) + { + m_writeSize = itemSize; + } + else + { + /* if the block is a weird size then use the + default so we don't force the writer to use + buffered IO */ + m_writeSize = s_defaultWriteSize; + } + + if (m_writeSize > 4 * 1024 * 1024) + { + m_writeSize = 4 * 1024 * 1024; + } + } + + m_zlibArg = Z_NO_FLUSH; + + AllocateOutputBuffer(); + + if (m_gzipHeader) + { + WriteGzipHeader(); + } + + int level = (m_optimizeForSpeed) ? + (Z_BEST_SPEED) : (Z_DEFAULT_COMPRESSION); + DrLogI( "Set compression level. Level %s", (m_optimizeForSpeed) ? 
"faster" : "default"); + + int retVal = deflateInit2(&m_stream, level, Z_DEFLATED, + -MAX_WBITS, 9, Z_DEFAULT_STRATEGY); + LogAssert(retVal == Z_OK); + + m_firstReadProcessed = true; + } + + LogAssert(m_zlibArg == Z_NO_FLUSH); + m_stream.avail_in = (z_uInt) item->GetAvailableSize(); + m_stream.next_in = (z_Bytef *)item->GetDataAddress(); +// DrLogI( "In process item"); + return ProcessBlock(); + } +} + +DrError GzipCompressionChannelTransform::Finish(bool atEndOfStream) +{ + { + DrError retval = DrError_OK; + AutoCriticalSection acs(&m_critsec); + + if (m_firstReadProcessed) + { + m_zlibArg = Z_FINISH; +// DrLogI( "In finish"); + retval = ProcessBlock(); + + m_firstReadProcessed = false; + } + return retval; + } +} + +DrError GzipCompressionChannelTransform::Start(FifoChannel *channel) +{ + m_channel = channel; + return DrError_OK; +} + +void GzipCompressionChannelTransform::WriteGzipHeader() +{ + AppendOutputByte(0x1f); // ID1 + AppendOutputByte(0x8b); // ID2 + AppendOutputByte(8); // CM + for (int i = 0; i < 6; i++) { + AppendOutputByte(0); // FLAG, MTIME, and XFL + } + AppendOutputByte(11); // OS +} + + +void GzipCompressionChannelTransform::WriteOutputBuffer() +{ + DataBlockItem *dbi = (DataBlockItem *)m_outputBuffer.Ptr(); + int bytesToWrite = m_writeSize - m_stream.avail_out; +// DrLogI( "Writing output buffer", +// "bytes: %d", bytesToWrite); + if (bytesToWrite > 0) { + dbi->SetAvailableSize(bytesToWrite); + m_channel->WriteTransformedItem(dbi); + } + // for safety, zero the output pointers + m_stream.next_out = NULL; + m_stream.avail_out = 0; + m_outputBuffer = NULL; +} +#endif \ No newline at end of file diff --git a/DryadVertex/VertexHost/vertex/WrapperNativeInfo/GzipDecompressionChannelTransform.cpp b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/GzipDecompressionChannelTransform.cpp new file mode 100644 index 0000000..235ec82 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/GzipDecompressionChannelTransform.cpp @@ -0,0 +1,391 @@ +/* 
+Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "stdafx.h" + +#include +#include + +#pragma unmanaged + +#ifdef LINKWITHZLIB + +GzipDecompressionChannelTransform::GzipDecompressionChannelTransform(DryadVertexProgram* vertex, bool gzipHeader) +{ + m_firstReadProcessed = false; + m_vertex = vertex; + + // zlib structures + memset(&m_stream, 0, sizeof(m_stream)); + m_stream.next_in = NULL; + m_stream.avail_in = 0; + m_stream.next_out = NULL; + m_stream.avail_out = 0; + m_stream.zalloc = Z_NULL; + m_stream.zfree = Z_NULL; + m_stream.opaque = NULL; + + m_crc = crc32(0L, Z_NULL, 0); + m_crcStart = NULL; + m_channel = NULL; + m_writeSize = 256 * 1024; + m_trailerBytesFilled = 0; + m_trailerBytes = new byte[TrailerLength]; + + m_gzipHeader = gzipHeader; +} + +GzipDecompressionChannelTransform::~GzipDecompressionChannelTransform() +{ + // zlib structures + memset(&m_stream, 0, sizeof(m_stream)); + m_stream.next_in = NULL; + m_stream.avail_in = 0; + m_stream.next_out = NULL; + m_stream.avail_out = 0; + m_stream.zalloc = Z_NULL; + m_stream.zfree = Z_NULL; + m_stream.opaque = NULL; + + m_crc = 0; + m_crcStart = NULL; + m_channel = NULL; + m_outputBuffer = NULL; + delete m_trailerBytes; +} + +DrError GzipDecompressionChannelTransform::AllocateOutputBuffer() +{ + DataBlockItem * block = new DataBlockItem(m_writeSize); + if (m_writeSize > 0 && block == NULL) { 
+ m_vertex->ReportError(DryadError_ChannelRestart, "Failed to allocate enough memory for decompression; corrupted input stream?"); + DecompressionError(); + return DrError_OutOfMemory; + } + m_outputBuffer.Attach(block); + m_stream.next_out = (byte *) m_outputBuffer->GetDataAddress(); + m_stream.avail_out = (z_uInt) m_writeSize; + m_crcStart = m_stream.next_out; + return DrError_OK; +} + +DrError GzipDecompressionChannelTransform::Finish(bool atEndOfStream) +{ + if (atEndOfStream) + { + WriteOutputBuffer(); + } + return DrError_OK; +} + +inline DrError GzipDecompressionChannelTransform::IncrementInputPosition(UInt32 increment) +{ + if (m_stream.avail_in >= increment) { + m_stream.avail_in -= increment; + m_stream.next_in += increment; + return DrError_OK; + } + else { + DrLogE( "Buffer overrun during decompression; corrupted input stream?"); + m_vertex->ReportError(DryadError_ChannelRestart, "Buffer overrun during decompression; corrupted input stream?"); + DecompressionError(); + return DrError_Fail; + } +} + +DrError GzipDecompressionChannelTransform::ProcessItem(DataBlockItem *item) +{ + m_stream.next_in = (byte *) item->GetDataAddress(); + m_stream.avail_in = (Int32) item->GetAvailableSize(); + int retVal = 0; + DrError err = DrError_OK; + + if (m_trailerBytesFilled != 0) { // we must have a split trailer + int bytesToCopy = TrailerLength - m_trailerBytesFilled; + memcpy(m_trailerBytes + m_trailerBytesFilled, + m_stream.next_in, bytesToCopy); + m_trailerBytesFilled += bytesToCopy; + err = IncrementInputPosition(bytesToCopy); + if (err != DrError_OK) { + return DrError_IoReadWriteError; + } + if (m_trailerBytesFilled == TrailerLength) { + err = ProcessTrailer(); + if (err != DrError_OK) { + return err; + } + } + } + + while (m_stream.avail_in != 0) { + if (!m_firstReadProcessed) { + if (m_gzipHeader) + { + // need to read the gzip header + err = ReadGzipHeader(); + } + if (err == DrError_OK) { + m_crcStart = m_stream.next_out; + retVal = inflateInit2(&m_stream, 
-MAX_WBITS); + if (retVal != Z_OK) { + DecompressionError(); + err = DrError_IoReadWriteError; + m_vertex->ReportError(DryadError_ChannelRestart, "Error during stream decompression"); + } + } + m_firstReadProcessed = true; + } + + retVal = inflate(&m_stream, Z_SYNC_FLUSH); + if (m_gzipHeader) + { + m_crc = crc32(m_crc, m_crcStart, (UInt32)(m_stream.next_out - + m_crcStart)); + m_crcStart = m_stream.next_out; + } + + switch (retVal) { + case Z_OK: + // done with this input block wait for next one to arrive + break; + case Z_STREAM_END: + if (m_gzipHeader) + { + if (m_stream.avail_in >= TrailerLength) { + memcpy(m_trailerBytes, m_stream.next_in, TrailerLength); + err = IncrementInputPosition(TrailerLength); + ProcessTrailer(); + } else { + memcpy(m_trailerBytes + m_trailerBytesFilled, + m_stream.next_in, m_stream.avail_in); + m_trailerBytesFilled = m_stream.avail_in; + err = IncrementInputPosition(m_trailerBytesFilled); + } + } + else + { + /* if there's more data following, things will get + very confused, but that's a malformed stream */ + WriteOutputBuffer(); + err = AllocateOutputBuffer(); + } + break; + default: + char *errorMsg = "NULL"; + if (m_stream.msg != NULL) { + errorMsg = m_stream.msg; + } + DrLogE( "Error in stream inflation: %s", errorMsg); + err = DrError_IoReadWriteError; + DecompressionError(); + m_vertex->ReportError(DryadError_ChannelRestart, "Error in stream inflate %s", errorMsg); + } + + if (err == DrError_OK && m_stream.avail_out == 0) { + WriteOutputBuffer(); + err = AllocateOutputBuffer(); + } + + if (err != DrError_OK) + // out of the loop + break; + } + return err; +} + +DrError GzipDecompressionChannelTransform::ProcessTrailer() +{ + DrError err = DrError_OK; + + UInt32 *intPtr = (UInt32 *)m_trailerBytes; + if (m_crc != *intPtr) { + DecompressionError(); + m_vertex->ReportError(DryadError_ChannelRestart, "Corrupted input stream during decompression"); + DrLogE( "Corrupted input stream during decompression"); + err = 
DrError_IoReadWriteError; + } + intPtr++; + // NYI - check size against orig size, + // which is in intPtr + // IncrementInputPosition(8); no longer needed + int retVal = inflateEnd(&m_stream); + if (retVal != Z_OK) { + DecompressionError(); + m_vertex->ReportError(DryadError_ChannelRestart, "Corrupted input stream during decompression"); + DrLogE( "Corrupted input stream during decompression"); + err = DrError_IoReadWriteError; + } + m_firstReadProcessed = false; //deal with multiple streams + m_trailerBytesFilled = 0; + // in most cases, this is the end of the stream, so + // write the output + if (err == DrError_OK && m_stream.avail_out != m_writeSize) { + WriteOutputBuffer(); + err = AllocateOutputBuffer(); + } + return err; +} + +DrError GzipDecompressionChannelTransform::ReadGzipHeader() +{ + if (!(m_stream.avail_in > 10)) { + goto error; + } + bool flagExtra = false; + bool flagFilename = false; + bool flagComment = false; + bool flagFHCrc = false; + DrError err = DrError_OK; + + // check the 2 id bytes + if (*(m_stream.next_in) != 0x1f) + { + goto error; + } + err = IncrementInputPosition(1); + if (err != DrError_OK) + goto error; + if (*(m_stream.next_in) != 0x8b) { + goto error; + } + err = IncrementInputPosition(1); + if (err != DrError_OK) + goto error; + + // check the cm + if (*(m_stream.next_in) != 8) { + goto error; + } + err = IncrementInputPosition(1); + if (err != DrError_OK) + goto error; + + // check the flag bits + if (*(m_stream.next_in) & 0x2) + { + flagFHCrc = true; + } + if (*(m_stream.next_in) & 0x4) + { + flagExtra = true; + } + if (*(m_stream.next_in) & 0x8) + { + flagFilename = true; + } + if (*(m_stream.next_in) & 0x10) + { + flagComment = true; + } + // check the reserved bits (5-7) + if (*(m_stream.next_in) & 0xE0){ + //LogAssert(false, "Flag bits 5-7 in the gzip header must be 0"); + goto error; + } + err = IncrementInputPosition(1); + if (err != DrError_OK) + goto error; + // skip mtime(4), xfl(1), and OS(1) fields + err = 
IncrementInputPosition(6); + if (err != DrError_OK) + goto error; + + if (flagExtra) + { + short* fieldLen = (short *) m_stream.next_in; + err = IncrementInputPosition(UInt32(2 + *fieldLen)); + if (err != DrError_OK) + goto error; + } + if (flagFilename) + { + while (*(m_stream.next_in) != 0) { + err = IncrementInputPosition(1); + if (err != DrError_OK) + goto error; + } + // and skip the terminating 0 + err = IncrementInputPosition(1); + if (err != DrError_OK) + goto error; + } + if (flagComment) + { + while (*(m_stream.next_in) != 0) { + err = IncrementInputPosition(1); + if (err != DrError_OK) + goto error; + } + // and skip the terminating 0 + err = IncrementInputPosition(1); + if (err != DrError_OK) + goto error; + } + if (flagFHCrc) + { + // just skip the header crc for now + err = IncrementInputPosition(2); + if (err != DrError_OK) + goto error; + } + return DrError_OK; + + error: + m_vertex->ReportError(DryadError_ChannelRestart, "Compression header corrupted"); + DrLogE( "Compression header corrupted"); + DecompressionError(); + return DrError_IoReadWriteError; +} + +void GzipDecompressionChannelTransform::DecompressionError() +{ + RChannelItemRef termination; + termination.Attach(RChannelMarkerItem::Create(RChannelItem_MarshalError, false)); + m_channel->WriteTransformedItem(termination.Ptr()); +} + +void GzipDecompressionChannelTransform::SetOutputBufferSize(UInt32 bufferSize) +{ + WriteOutputBuffer(); + m_writeSize = bufferSize; + AllocateOutputBuffer(); +} + +DrError GzipDecompressionChannelTransform::Start(FifoChannel *channel) +{ + m_channel = channel; + AllocateOutputBuffer(); + return DrError_OK; +} + +void GzipDecompressionChannelTransform::WriteOutputBuffer() +{ + DataBlockItem *dbi = (DataBlockItem *)m_outputBuffer.Ptr(); + int bytesToWrite = (int) m_writeSize - m_stream.avail_out; + if (bytesToWrite > 0) { + dbi->SetAvailableSize(bytesToWrite); + m_channel->WriteTransformedItem(dbi); + } + // for safety, zero the output pointers + 
m_stream.next_out = NULL; + m_stream.avail_out = 0; + m_outputBuffer = NULL; +} +#endif \ No newline at end of file diff --git a/DryadVertex/VertexHost/vertex/WrapperNativeInfo/InputChannel.cpp b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/InputChannel.cpp new file mode 100644 index 0000000..7cc2304 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/InputChannel.cpp @@ -0,0 +1,164 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include "stdafx.h" + +#include + +#pragma unmanaged + + +InputChannel::InputChannel(UInt32 portNum, + DryadVertexProgram* vertex, + RChannelReader* channel) +{ + m_reader = channel; + m_vertex = vertex; + m_bytesRead = 0L; + m_portNum = portNum; + m_atEOC = false; +} + +InputChannel::~InputChannel() +{ + m_reader = NULL; + m_vertex = NULL; +} + +void InputChannel::Stop() +{ +} + +bool InputChannel::AtEndOfChannel() +{ + return m_atEOC; +} + +Int64 InputChannel::GetBytesRead() +{ + return m_bytesRead; +} + +RChannelReader* InputChannel::GetReader() +{ + return m_reader; +} + +bool InputChannel::GetTotalLength(UInt64 *length) +{ + return m_reader->GetTotalLength(length); +} + +bool InputChannel::GetExpectedLength(UInt64 *length) +{ + return m_reader->GetExpectedLength(length); +} + +const char* InputChannel::GetURI() +{ + return m_reader->GetURI(); +} + +DataBlockItem* InputChannel::ReadDataBlock(byte **ppDataBlock, + Int32 *ppDataBlockSize, + Int32 *pErrorCode) +{ + RChannelItemRef nextItem; + DataBlockItem *itemPtr = NULL; + bool result = m_reader->FetchNextItem(&nextItem, + DrTimeInterval_Infinite); + LogAssert(result); + LogAssert(nextItem != NULL, "FetchNextItem() returned a NULL item"); + RChannelItemType itemType = nextItem->GetType(); + + switch (itemType) { + case RChannelItem_Data: + itemPtr = (DataBlockItem*) nextItem.Ptr(); + itemPtr->IncRef(); + *ppDataBlock = (byte *) itemPtr->GetDataAddress(); + *ppDataBlockSize = (Int32) itemPtr->GetAvailableSize(); +#ifdef VERBOSE + fprintf(stdout, "Read %d bytes from channel %u.\n", + itemPtr->GetAvailableSize(), m_portNum); + fprintf(stdout, "MEM ReadDataBlock block has addr %p.\n", itemPtr); + fflush(stdout); +#endif + m_bytesRead += itemPtr->GetAvailableSize(); + LogAssert(itemPtr->GetAvailableSize() != 0); + *pErrorCode = 0; + break; + case RChannelItem_BufferHole: + DrLogE("Error: Received BufferHole Item from channel."); + if (m_vertex != NULL) { + m_vertex->ReportError(DryadError_BufferHole); + } + 
*pErrorCode = 1; + break; + case RChannelItem_ItemHole: + DrLogE("Received ItemHole Item from channel."); + if (m_vertex != NULL) { + m_vertex->ReportError(DryadError_ItemHole); + } + *pErrorCode = 1; + break; + case RChannelItem_EndOfStream: + DrLogD("ReadDataBlock received EndOfStream"); + m_atEOC= true; + *pErrorCode = 0; + break; + case RChannelItem_Restart: + DrLogE("Received Restart Item from channel."); + if (m_vertex != NULL) { + m_vertex->ReportError(DryadError_ChannelRestart); + } + *pErrorCode = 1; + break; + case RChannelItem_Abort: + DrLogE("Received Abort Item from channel."); + if (m_vertex != NULL) { + m_vertex->ReportError(DryadError_ChannelAbort); + } + *pErrorCode = 1; + break; + case RChannelItem_ParseError: + DrLogE("Received ParseError Item from channel."); + if (m_vertex != NULL) { + m_vertex->ReportError(DryadError_ItemParseError); + } + *pErrorCode = 1; + break; + case RChannelItem_MarshalError: + DrLogE("Received MarshalError Item from channel."); + if (m_vertex != NULL) { + m_vertex->ReportError(DryadError_ItemMarshalError); + } + *pErrorCode = 1; + break; + default: + DrLogE("Received Item of unknown type from channel."); + if (m_vertex != NULL) { + m_vertex->ReportError(DryadError_AssertFailure, + "Received Item of unknown type from channel."); + } + *pErrorCode = 1; + break; + } + return itemPtr; +} diff --git a/DryadVertex/VertexHost/vertex/WrapperNativeInfo/NullChannelTransform.cpp b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/NullChannelTransform.cpp new file mode 100644 index 0000000..e39eaa3 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/NullChannelTransform.cpp @@ -0,0 +1,64 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "stdafx.h" + +#include + +#pragma unmanaged + + +ChannelTransform::~ChannelTransform() +{ + // nothing needed here +} + +NullChannelTransform::NullChannelTransform(DryadVertexProgram* vertex) +{ + m_vertex = vertex; +} + +NullChannelTransform::~NullChannelTransform() +{ + m_channel = NULL; +} + +DrError NullChannelTransform::Start(FifoChannel *channel) +{ + m_channel = channel; + return DrError_OK; +} + +void NullChannelTransform::SetOutputBufferSize(UInt32 bufferSize) +{ + // nothing needed here +} + +DrError NullChannelTransform::ProcessItem(DataBlockItem *item) +{ + m_channel->WriteTransformedItem(item); + return DrError_OK; +} + +DrError NullChannelTransform::Finish(bool atEndOfStream) +{ + m_channel = NULL; + return DrError_OK; +} diff --git a/DryadVertex/VertexHost/vertex/WrapperNativeInfo/OutputChannel.cpp b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/OutputChannel.cpp new file mode 100644 index 0000000..45f34d1 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/OutputChannel.cpp @@ -0,0 +1,140 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "stdafx.h" + +#include + +#pragma unmanaged + + +OutputChannel::OutputChannel(UInt32 portNum, + DryadVertexProgram* vertex, + RChannelWriter* outputChannel) +{ + m_writer = outputChannel; + m_vertex = vertex; + m_bytesWritten = 0L; + m_portNum = portNum; +} + +OutputChannel::~OutputChannel() +{ + m_writer = NULL; + m_vertex = NULL; +} + + +void OutputChannel::Stop() +{ +} + +Int64 OutputChannel::GetBytesWritten() +{ + return m_bytesWritten; +} + +void OutputChannel::SetInitialSizeHint(UInt64 hint) +{ + m_writer->SetInitialSizeHint(hint); +} + +const char* OutputChannel::GetURI() +{ + return m_writer->GetURI(); +} + +RChannelWriter* OutputChannel::GetWriter() +{ + return m_writer; +} + +BOOL OutputChannel::WriteDataBlock(DataBlockItem *pItem, + Int32 numBytesToWrite) +{ + if (numBytesToWrite != (Int32) pItem->GetAvailableSize()) + { + pItem->SetAvailableSize(numBytesToWrite); + } + + m_bytesWritten += numBytesToWrite; + BOOL returnValue = true; + RChannelItemRef marshalFailureItem; + RChannelItemType result = m_writer->WriteItemSync(pItem, + false, &marshalFailureItem); + switch (result) { + case RChannelItem_Data: + // successful write + returnValue = true; + break; + case RChannelItem_EndOfStream: + DrLogE("Received EndOfStream Item from channel write."); + if (m_vertex != NULL) { + m_vertex->ReportError(DryadError_ChannelWriteError); + } + returnValue = false; + break; + case RChannelItem_Restart: + DrLogE("Received Restart Item from channel write."); + if (m_vertex != NULL) { + 
m_vertex->ReportError(DryadError_ChannelRestart); + } + returnValue = false; + break; + case RChannelItem_Abort: + DrLogE("Received Abort Item from channel write."); + if (m_vertex != NULL) { + m_vertex->ReportError(DryadError_ChannelAbort); + } + returnValue = false; + break; + default: + DrLogE("Received Item of unexpected type from channel."); + if (m_vertex != NULL) { + m_vertex->ReportError(DryadError_AssertFailure, + "Received Item of unexpected type from channel."); + } + returnValue = false; + break; + } + + // and make sure there was not a marshalling error + if (marshalFailureItem.Ptr() != NULL) + { + DrLogE("Received MarshalError Item from channel."); + if (m_vertex != NULL) { + m_vertex->ReportError(DryadError_ItemMarshalError); + } + returnValue = false; + } + +#ifdef VERBOSE + fprintf(stdout, "MEM WriteDataBlock block has addr %p.\n", pItem); + fflush(stdout); +#endif + + return returnValue; +} + +void OutputChannel::ProcessWriteCompleted(RChannelItemType status, + RChannelItem* marshalFailureItem) +{ + LogAssert(status == RChannelItem_Data); +} diff --git a/DryadVertex/VertexHost/vertex/WrapperNativeInfo/WrapperNativeInfo.cpp b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/WrapperNativeInfo.cpp new file mode 100644 index 0000000..4e80d83 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/WrapperNativeInfo.cpp @@ -0,0 +1,332 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "stdafx.h" + +#include +#include +#include +#include + +#pragma unmanaged + +WrapperNativeInfoBase::~WrapperNativeInfoBase() +{ +} + +//#define VERBOSE +WrapperNativeInfo::WrapperNativeInfo(UInt32 numberOfInputChannels, + RChannelReader** inputChannel, + UInt32 numberOfOutputChannels, + RChannelWriter** outputChannel, + DryadVertexProgram* vertex, + WorkQueue* workQueue) +{ +#ifdef VERBOSE + if (DrLogging::Enabled(LogLevel_Info)) + { + fprintf(stdout, "WrapperNativeInfo::WrapperNativeInfo: %p %u\n", this, numberOfInputChannels); + fflush(stdout); + } +#endif + DrLogI( "Creating WrapperNativeInfo"); + m_numberOfInputChannels = numberOfInputChannels; + m_inputChannels = new InputChannel*[m_numberOfInputChannels]; + + m_numberOfOutputChannels = numberOfOutputChannels; + m_outputChannels = new OutputChannel*[m_numberOfOutputChannels]; + + m_vertex = vertex; + m_workQueue = workQueue; + + for (UInt32 i = 0; i < m_numberOfInputChannels; i++) { + m_inputChannels[i] = new InputChannel(i, m_vertex, inputChannel[i]); + } + + for (UInt32 i = 0; i < m_numberOfOutputChannels; i++) + { + m_outputChannels[i] = new OutputChannel(i, m_vertex, outputChannel[i]); + } +} + +WrapperNativeInfo::~WrapperNativeInfo() +{ + m_vertex = NULL; + m_workQueue = NULL; + + for (UInt32 i = 0; i < m_numberOfInputChannels; i++) + { + delete m_inputChannels[i]; + } + m_inputChannels = NULL; + + for (UInt32 i = 0; i < m_numberOfOutputChannels; i++) + { + delete m_outputChannels[i]; + } + m_outputChannels = NULL; +} + +void WrapperNativeInfo::CleanUp() +{ + // we don't own the reader/writer reference, so we should not delete it + + DrLogI( "Cleaning up WrapperNativeInfo"); + + /* first call stop on all the channels to drain any transforms + that are still in flight */ + for (UInt32 i = 0; i < m_numberOfInputChannels; i++) { + m_inputChannels[i]->Stop(); + } + + for 
(UInt32 i = 0; i < m_numberOfOutputChannels; i++) + { + m_outputChannels[i]->Stop(); + } + + // enumerate which channels were not completely read + for (UInt32 i = 0; i < m_numberOfInputChannels; i++) + { + if (!m_inputChannels[i]->AtEndOfChannel()) + { + DrLogI("The client did not completely read channel: %u ", i); + } + } + + // output read/write stats for debugging + Int64 totalBytesRead = 0L; + Int64 totalBytesWritten = 0L; + for (UInt32 i = 0; i < m_numberOfInputChannels; i++) + { + Int64 channelBytesRead = m_inputChannels[i]->GetBytesRead(); + DrLogI( + "WrapperNativeInfo read %I64d bytes from channel %u.", + channelBytesRead, i); + + totalBytesRead += channelBytesRead; + } + + for (UInt32 i = 0; i < m_numberOfOutputChannels; i++) + { + Int64 channelBytesWritten = m_outputChannels[i]->GetBytesWritten(); + DrLogI("WrapperNativeInfo wrote %I64d bytes to channel %u.", + channelBytesWritten, i); + totalBytesWritten += channelBytesWritten; + } + + DrLogI("WrapperNativeInfo read %I64d bytes from all channels.", totalBytesRead); + DrLogI("WrapperNativeInfo wrote %I64d bytes to all channels.", totalBytesWritten); +} + +Int64 WrapperNativeInfo::GetTotalLength(UInt32 portNum) +{ + UInt64 len; + bool isKnown = m_inputChannels[portNum]->GetTotalLength(&len); + return (isKnown) ? len : -1; +} + +Int64 WrapperNativeInfo::GetExpectedLength(UInt32 portNum) +{ + UInt64 len; + bool isKnown = m_inputChannels[portNum]->GetExpectedLength(&len); + return (isKnown) ? 
len : -1; +} + +Int64 WrapperNativeInfo::GetVertexId() +{ + return m_vertex->GetVertexId(); +} + +void WrapperNativeInfo::SetInitialSizeHint(UInt32 portNum, UInt64 hint) +{ + m_outputChannels[portNum]->SetInitialSizeHint(hint); +} + +UInt32 WrapperNativeInfo::GetNumOfInputs() +{ +#ifdef VERBOSE + if (DrLogging::Enabled(LogLevel_Info)) + { + fprintf(stdout, "WrapperNativeInfo::GetNumOfInputs: %p %u\n", this, m_numberOfInputChannels); + fflush(stdout); + } +#endif + return m_numberOfInputChannels; +} + +UInt32 WrapperNativeInfo::GetNumOfOutputs() +{ + return m_numberOfOutputChannels; +} + +const char* WrapperNativeInfo::GetInputChannelURI(UInt32 portNum) +{ + return m_inputChannels[portNum]->GetURI(); +} + +const char* WrapperNativeInfo::GetOutputChannelURI(UInt32 portNum) +{ + return m_outputChannels[portNum]->GetURI(); +} + +DataBlockItem* WrapperNativeInfo::AllocateDataBlock(Int32 dataBlockSize, + byte **pDataBlock) +{ + DrLogD("AllocateDataBlock(): dataBlockSize = %d", dataBlockSize); + + DataBlockItem *dbi = new DataBlockItem(dataBlockSize); + *pDataBlock = (byte *) dbi->GetDataAddress(); +#ifdef VERBOSE + if (DrLogging::Enabled(LogLevel_Info)) + { + fprintf(stdout, "MEM AllocateDataBlock block has addr %p.\n", dbi); + fflush(stdout); + } +#endif + return dbi; +} + +void WrapperNativeInfo::ReleaseDataBlock(DataBlockItem *pItem) +{ +#ifdef VERBOSE + if (DrLogging::Enabled(LogLevel_Info)) + { + fprintf(stdout, "MEM ReleaseDataBlock block has addr %p.\n", pItem); + fflush(stdout); + } +#endif + if (pItem != NULL) + { + pItem->DecRef(); + } +} + +DataBlockItem* WrapperNativeInfo::ReadDataBlock(UInt32 portNum, byte **ppDataBlock, Int32 *ppDataBlockSize, Int32* pErrorCode) +{ + LogAssert(portNum < m_numberOfInputChannels); + + DrLogD("ReadDataBlock() entering: portNum = %d", portNum); + + DataBlockItem *dbi = m_inputChannels[portNum]->ReadDataBlock(ppDataBlock, ppDataBlockSize, pErrorCode); + + DrLogD("ReadDataBlock() returning: portNum = %d, dataBlockSize = %d, 
errorCode = %d", portNum, *ppDataBlockSize, *pErrorCode); + + return dbi; +} + +BOOL WrapperNativeInfo::WriteDataBlock(UInt32 portNum, + DataBlockItem *pItem, + Int32 numBytesToWrite) +{ + DrLogD("WriteDataBlock() entering: portNum = %d", portNum); + + LogAssert(portNum < m_numberOfOutputChannels); + + BOOL retVal = m_outputChannels[portNum]->WriteDataBlock(pItem, numBytesToWrite); + + DrLogD("WriteDataBlock() returning: portNum = %d, success = %d ", portNum, retVal); + + return retVal; +} + + +void WrapperNativeInfo::EnableFifoInputChannel(Int32 compressionScheme, + UInt32 channel) +{ + DrLogI( "Enabling fifo for input channel. Channel %u scheme %d", channel, compressionScheme); + LogAssert(channel < m_numberOfInputChannels); + LogAssert((compressionScheme >= 0) && ( compressionScheme <= 6)); + TransformType tType = TT_NullTransform; + if (compressionScheme == 1) + { + tType = TT_GzipDecompression; + } + else if (compressionScheme == 2) + { + tType = TT_GzipDecompression; + } + else if (compressionScheme == 3) + { + tType = TT_DeflateDecompression; + } + else if (compressionScheme == 4) + { + tType = TT_DeflateDecompression; + } + + /* Xpress removed, but left in comments as an example of + * supporting an alternate compression scheme. + else if (compressionScheme == 5) + { + tType = TT_XpressDecompression; + } + else if (compressionScheme == 6) + { + tType = TT_XpressDecompression; + } + */ + + InputChannel *origChannel = m_inputChannels[channel]; + m_inputChannels[channel] = new FifoInputChannel(channel, + m_vertex, m_workQueue, origChannel->GetReader(), tType); + delete origChannel; +} + +void WrapperNativeInfo::EnableFifoOutputChannel(Int32 compressionScheme, + UInt32 channel) +{ + DrLogI( "Enabling fifo for output channel. 
Channel %u scheme %d", channel, compressionScheme); + LogAssert(channel < m_numberOfOutputChannels); + LogAssert((compressionScheme >= 0) && ( compressionScheme <= 6)); + TransformType tType = TT_NullTransform; + if (compressionScheme == 1) + { + tType = TT_GzipCompression; + } + else if (compressionScheme == 2) + { + tType = TT_GzipFastCompression; + } + else if (compressionScheme == 3) + { + tType = TT_DeflateCompression; + } + else if (compressionScheme == 4) + { + tType = TT_DeflateFastCompression; + } + + /* Xpress removed, but left in comments as an example of + * supporting an alternate compression scheme. + else if (compressionScheme == 5) + { + tType = TT_XpressCompression; + } + else if (compressionScheme == 6) + { + tType = TT_XpressFastCompression; + } + */ + + OutputChannel *origChannel = m_outputChannels[channel]; + m_outputChannels[channel] = new FifoOutputChannel(channel, + m_vertex, m_workQueue, origChannel->GetWriter(), tType); + delete origChannel; +} diff --git a/DryadVertex/VertexHost/vertex/WrapperNativeInfo/WrapperNativeInfo.vcxproj b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/WrapperNativeInfo.vcxproj new file mode 100644 index 0000000..79f0bb0 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/WrapperNativeInfo.vcxproj @@ -0,0 +1,151 @@ + + + + + Debug + Win32 + + + Debug + x64 + + + Release + Win32 + + + Release + x64 + + + + {AB9EA66C-5811-49A7-B002-24203AEB9083} + WrapperNativeInfo + Win32Proj + + + + StaticLibrary + + + StaticLibrary + + + StaticLibrary + Unicode + true + + + StaticLibrary + Unicode + true + + + + + + + + + + + + + + + + + + + <_ProjectFileVersion>10.0.40219.1 + Debug\ + Debug\ + $(Platform)\$(Configuration)\ + $(Platform)\$(Configuration)\ + Release\ + Release\ + $(Platform)\$(Configuration)\ + $(Platform)\$(Configuration)\ + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + + + + Disabled + WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions) + true + 
EnableFastChecks + MultiThreadedDebugDLL + + + Level3 + EditAndContinue + + + + + X64 + + + Disabled + ..\include;..\..\system\classlib\include;..\..\system\channel\include;..\..\system\dprocess\include;..\..\system\common\include;..\zlib;%(AdditionalIncludeDirectories) + WIN32;_DEBUG;_LIB;WIN32_LEAN_AND_MEAN;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + false + Default + MultiThreadedDebugDLL + Create + Level3 + ProgramDatabase + + + + + WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + + + + + X64 + + + WIN32;NDEBUG;_LIB;WIN32_LEAN_AND_MEAN;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + MultiThreadedDLL + Create + Level3 + ProgramDatabase + ..\include;..\..\system\classlib\include;..\..\system\channel\include;..\..\system\dprocess\include;..\..\system\common\include;..\zlib;%(AdditionalIncludeDirectories) + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/DryadVertex/VertexHost/vertex/WrapperNativeInfo/WrapperNativeInfo.vcxproj.filters b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/WrapperNativeInfo.vcxproj.filters new file mode 100644 index 0000000..cdb0c92 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/WrapperNativeInfo.vcxproj.filters @@ -0,0 +1,51 @@ + + + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hpp;hxx;hm;inl;inc;xsd + + + {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} + rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + + + Header Files + + + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + \ No newline at end of file diff --git a/DryadVertex/VertexHost/vertex/WrapperNativeInfo/stdafx.h b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/stdafx.h new file mode 100644 index 0000000..b39f805 --- /dev/null +++ 
b/DryadVertex/VertexHost/vertex/WrapperNativeInfo/stdafx.h @@ -0,0 +1,25 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include "DrExecution.h" +#include "XCompute.h" + diff --git a/DryadVertex/VertexHost/vertex/WrapperNativeInfoDll/DryadLINQNativeChannels.def b/DryadVertex/VertexHost/vertex/WrapperNativeInfoDll/DryadLINQNativeChannels.def new file mode 100644 index 0000000..7839c61 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/WrapperNativeInfoDll/DryadLINQNativeChannels.def @@ -0,0 +1,19 @@ +LIBRARY "DryadLINQNativeChannels.dll" + +EXPORTS + GetTotalLength + GetExpectedLength + GetVertexId + GetNumOfInputs + GetNumOfOutputs + SetInitialSizeHint + GetInputChannelURI + GetOutputChannelURI + Flush + Close + ReadDataBlock + WriteDataBlock + AllocateDataBlock + ReleaseDataBlock + EnableFifoInputChannel + EnableFifoOutputChannel diff --git a/DryadVertex/VertexHost/vertex/WrapperNativeInfoDll/WrapperNativeInfoDll.vcxproj b/DryadVertex/VertexHost/vertex/WrapperNativeInfoDll/WrapperNativeInfoDll.vcxproj new file mode 100644 index 0000000..6ba7e8d --- /dev/null +++ b/DryadVertex/VertexHost/vertex/WrapperNativeInfoDll/WrapperNativeInfoDll.vcxproj @@ -0,0 +1,182 @@ + + + + + Debug + Win32 + + + Debug + x64 + + + Release + Win32 + + + Release + x64 + + + + {3EE0920C-0607-4569-9EC3-5C12BB6EF244} + WrapperNativeInfoDll + Win32Proj + + + + 
DynamicLibrary + + + DynamicLibrary + + + DynamicLibrary + Unicode + + + DynamicLibrary + Unicode + false + + + + + + + + + + + + + + + + + + + <_ProjectFileVersion>10.0.40219.1 + Debug\ + Debug\ + true + ..\..\..\..\bin\$(Configuration)\ + $(Platform)\$(Configuration)\ + false + Release\ + Release\ + true + ..\..\..\..\bin\$(Configuration)\ + $(Platform)\$(Configuration)\ + false + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + DryadLINQNativeChannels + DryadLINQNativeChannels + + + + Disabled + WIN32;_DEBUG;_WINDOWS;_USRDLL;WRAPPERNATIVEINFODLL_EXPORTS;%(PreprocessorDefinitions) + true + EnableFastChecks + MultiThreadedDebugDLL + + + Level3 + EditAndContinue + + + true + Windows + MachineX86 + + + + + X64 + + + Disabled + ..\include;..\..\system\classlib\include;..\..\system\dprocess\include;..\..\system\common\include;..\..\system\channel\include;%(AdditionalIncludeDirectories) + WIN32;_DEBUG;_WINDOWS;_USRDLL;WRAPPERNATIVEINFODLL_EXPORTS;WIN32_LEAN_AND_MEAN;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + false + Default + MultiThreadedDebugDLL + + + Level3 + ProgramDatabase + + + $(OutDir)$(TargetName)$(TargetExt) + DryadLINQNativeChannels.def + true + true + Windows + MachineX64 + + + + + WIN32;NDEBUG;_WINDOWS;_USRDLL;WRAPPERNATIVEINFODLL_EXPORTS;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + + + true + Windows + true + true + MachineX86 + + + + + X64 + + + WIN32;NDEBUG;_WINDOWS;_USRDLL;WRAPPERNATIVEINFODLL_EXPORTS;WIN32_LEAN_AND_MEAN;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + ..\include;..\..\system\classlib\include;..\..\system\dprocess\include;..\..\system\common\include;..\..\system\channel\include;%(AdditionalIncludeDirectories) + + + true + Windows + true + true + MachineX64 + $(OutDir)$(TargetName)$(TargetExt) + DryadLINQNativeChannels.def + + + + + + + + + + + + + + + \ No newline at end of file diff --git 
a/DryadVertex/VertexHost/vertex/WrapperNativeInfoDll/WrapperNativeInfoDll.vcxproj.filters b/DryadVertex/VertexHost/vertex/WrapperNativeInfoDll/WrapperNativeInfoDll.vcxproj.filters new file mode 100644 index 0000000..a49afcc --- /dev/null +++ b/DryadVertex/VertexHost/vertex/WrapperNativeInfoDll/WrapperNativeInfoDll.vcxproj.filters @@ -0,0 +1,32 @@ + + + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hpp;hxx;hm;inl;inc;xsd + + + {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} + rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + + + Header Files + + + + + Source Files + + + + + Source Files + + + \ No newline at end of file diff --git a/DryadVertex/VertexHost/vertex/WrapperNativeInfoDll/stdafx.h b/DryadVertex/VertexHost/vertex/WrapperNativeInfoDll/stdafx.h new file mode 100644 index 0000000..b39f805 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/WrapperNativeInfoDll/stdafx.h @@ -0,0 +1,25 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include "DrExecution.h" +#include "XCompute.h" + diff --git a/DryadVertex/VertexHost/vertex/WrapperNativeInfoDll/wrappernativeinfostubs.cpp b/DryadVertex/VertexHost/vertex/WrapperNativeInfoDll/wrappernativeinfostubs.cpp new file mode 100644 index 0000000..ce43a5c --- /dev/null +++ b/DryadVertex/VertexHost/vertex/WrapperNativeInfoDll/wrappernativeinfostubs.cpp @@ -0,0 +1,168 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include "stdafx.h" + +#include +#ifdef TIDYFS +#include +#endif + +#pragma unmanaged + +Int64 GetTotalLength(WrapperNativeInfoBase *info, UInt32 portNum) +{ + return info->GetTotalLength(portNum); +} + +Int64 GetExpectedLength(WrapperNativeInfoBase *info, UInt32 portNum) +{ + return info->GetExpectedLength(portNum); +} + +Int64 GetVertexId(WrapperNativeInfoBase *info) +{ + return info->GetVertexId(); +} + +void SetInitialSizeHint(WrapperNativeInfoBase *info, UInt32 portNum, UInt64 hint) +{ + return info->SetInitialSizeHint(portNum, hint); +} + +UInt32 GetNumOfInputs(WrapperNativeInfoBase *info) +{ + return info->GetNumOfInputs(); +} + +UInt32 GetNumOfOutputs(WrapperNativeInfoBase *info) +{ + return info->GetNumOfOutputs(); +} + +const char* GetInputChannelURI(WrapperNativeInfoBase *info, UInt32 portNum) +{ + return info->GetInputChannelURI(portNum); +} + +const char* GetOutputChannelURI(WrapperNativeInfoBase *info, UInt32 portNum) +{ + return info->GetOutputChannelURI(portNum); +} + +DataBlockItem* AllocateDataBlock(WrapperNativeInfoBase *info, + Int32 dataBlockSize, + byte **pDataBlock) +{ + return info->AllocateDataBlock(dataBlockSize, pDataBlock); +} + +void ReleaseDataBlock(WrapperNativeInfoBase *info, + DataBlockItem *pItem) +{ + info->ReleaseDataBlock(pItem); +} + +DataBlockItem* ReadDataBlock(WrapperNativeInfoBase *info, + UInt32 portNum, + byte **ppDataBlock, + Int32 *ppDataBlockSize, + Int32 *pErrorCode) +{ + return info->ReadDataBlock(portNum, ppDataBlock, ppDataBlockSize, pErrorCode); +} + +BOOL WriteDataBlock(WrapperNativeInfoBase *info, + UInt32 portNum, + DataBlockItem *pItem, + Int32 numBytesToWrite) +{ +#ifdef VERBOSE + fprintf(stdout, "Writing %d bytes to channel %u.\n", + numBytesToWrite, portNum); + fflush(stdout); +#endif + + return info->WriteDataBlock(portNum, pItem, numBytesToWrite); +} + +void Flush(WrapperNativeInfoBase *info, UInt32 portNum) +{ + // NYI +} + +void Close(WrapperNativeInfoBase *info, UInt32 portNum) +{ + // NYI +} 
+ +/* JC +bool GetDryadStreamInfo(char *streamName, + UInt32 *numExtents, + UInt64 *streamLength) +{ +#ifdef VERBOSE + fprintf(stderr, "Getting Stream Info for stream '%s'\n", streamName); +#endif + // initialize the cosmos libraries + DrError err = DrInitialize(DR_CURRENT_VERSION, NULL); + if (err != DrError_OK) + { + return false; + } + + DR_STREAM *pStreamInfo; + err = DrGetStreamInformation(streamName, 0, 0, DR_ALL_EXTENTS, + &pStreamInfo, NULL); + if (err != DrError_OK) + { + return false; + } + + *numExtents = pStreamInfo->TotalNumberOfExtents; + *streamLength = pStreamInfo->Length; + + err = DrFreeMemory(pStreamInfo); + LogAssert(err == DrError_OK); + + err = DrUninitialize(); + LogAssert(err == DrError_OK); + +#ifdef VERBOSE + fprintf(stderr, "Returning stream info for stream '%s' of %u extents and %I64u bytes.\n", + streamName, *numExtents, *streamLength); +#endif + return true; +} +*/ + +void EnableFifoInputChannel(WrapperNativeInfo *info, + Int32 compresionScheme, + UInt32 channel) +{ + info->EnableFifoInputChannel(compresionScheme, channel); +} + +void EnableFifoOutputChannel(WrapperNativeInfo *info, + Int32 compressionScheme, + UInt32 channel) +{ + info->EnableFifoOutputChannel(compressionScheme, channel); +} diff --git a/DryadVertex/VertexHost/vertex/include/ChannelTransform.h b/DryadVertex/VertexHost/vertex/include/ChannelTransform.h new file mode 100644 index 0000000..3ef7077 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/include/ChannelTransform.h @@ -0,0 +1,39 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include +#include +#include +class FifoChannel; + +class ChannelTransform +{ +public: + virtual ~ChannelTransform() = 0; + virtual DrError Start(FifoChannel *channel) = 0; + virtual void SetOutputBufferSize(UInt32 bufferSize) = 0; + virtual DrError ProcessItem(DataBlockItem *item) = 0; + virtual DrError Finish(bool atEndOfStream) = 0; +protected: + DryadVertexProgram *m_vertex; // used for error reporting +}; + diff --git a/DryadVertex/VertexHost/vertex/include/CompressionVertex.h b/DryadVertex/VertexHost/vertex/include/CompressionVertex.h new file mode 100644 index 0000000..ec31dea --- /dev/null +++ b/DryadVertex/VertexHost/vertex/include/CompressionVertex.h @@ -0,0 +1,46 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include +#include + +/* A CompressionVertex is a vertex that will either compress or decompress it's input using + the gzip format. The vertex takes one argument, which is either compress or decompress. + The vertex can only be called with one output channel, but may be called with multiple + input channels. The requested operation is applied to each input channel in order. */ +class CompressionVertex : public DryadVertexProgram +{ +public: + CompressionVertex(); + + void Main(WorkQueue* workQueue, + UInt32 numberOfInputChannels, + RChannelReader** inputChannel, + UInt32 numberOfOutputChannels, + RChannelWriter** outputChannel); + +private: + +}; + +typedef StdTypedVertexFactory FactoryCompressionVertexWrapper; + diff --git a/DryadVertex/VertexHost/vertex/include/DataBlockItem.h b/DryadVertex/VertexHost/vertex/include/DataBlockItem.h new file mode 100644 index 0000000..fa44571 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/include/DataBlockItem.h @@ -0,0 +1,73 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include "DrCommon.h" +#include +#include +#include + +class DataBlockItem : public RChannelDataItem +{ +public: + DataBlockItem(DrMemoryBuffer* buf); + DataBlockItem(Size_t size); + ~DataBlockItem(); + Size_t GetAllocatedSize(); + Size_t GetAvailableSize(); + void SetAvailableSize(Size_t size); + void * GetDataAddress(); + DrMemoryBuffer * GetData(); + +private: + DrRef m_buf; +}; + +class DataBlockParser : public RChannelItemParser +{ +public: + DataBlockParser() {}; + + DataBlockParser(DObjFactoryBase* factory); + + RChannelItem* ParseNextItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + Size_t* pOutLength); + RChannelItem* ParsePartialItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + RChannelBufferMarker* + markerBuffer); + +}; + +typedef StdParserFactory DataBlockParserFactory; + +class DataBlockMarshaler : public RChannelItemMarshaler +{ +public: + DrError MarshalItem(ChannelMemoryBufferWriter* writer, + RChannelItem* item, + bool flush, + RChannelItemRef* pFailureItem); +}; +typedef StdMarshalerFactory DataBlockMarshalerFactory; + + diff --git a/DryadVertex/VertexHost/vertex/include/FifoChannel.h b/DryadVertex/VertexHost/vertex/include/FifoChannel.h new file mode 100644 index 0000000..cf60682 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/include/FifoChannel.h @@ -0,0 +1,69 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once +#include +#include +#include + +enum FifoReaderState { + RS_Stopped, + RS_OutstandingHandler, + RS_Stopping +}; + +class FifoChannelItemWriterHandler : public RChannelItemWriterHandler +{ +public: + virtual ~FifoChannelItemWriterHandler(); + virtual void ProcessWriteCompleted(RChannelItemType status, + RChannelItem* marshalFailureItem); +}; + +class FifoChannel : public RChannelItemWriterHandler, + public RChannelItemReaderHandlerQueued +{ +public: + FifoChannel(RChannelReader* reader, + RChannelWriter *writer, + DryadVertexProgram* vertex, + TransformType transType); + virtual ~FifoChannel(); + virtual void ProcessItem(RChannelItem* deliveredItem); + virtual void ProcessWriteCompleted(RChannelItemType status, + RChannelItem* marshalFailureItem); + void MaybeSendHandler(); + void Start(); + bool Stop(RChannelWriter* fifoWriter); + void Drain(bool mustWait); + void WriteTransformedItem(RChannelItem* transformedItem); +private: + + bool m_initialHandlerSent; + ChannelTransform *m_transform; + RChannelReader *m_reader; + RChannelWriter *m_writer; + CRITSEC m_critsec; + HANDLE m_shutdownEvent; + int m_numItemsInFlight; + FifoReaderState m_state; + FifoChannelItemWriterHandler *m_fifoWriterHandler; + DryadVertexProgram *m_vertex; +}; diff --git a/DryadVertex/VertexHost/vertex/include/FifoInputChannel.h b/DryadVertex/VertexHost/vertex/include/FifoInputChannel.h new file mode 100644 index 0000000..4353679 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/include/FifoInputChannel.h @@ -0,0 +1,47 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include +#include + +class FifoInputChannel : public InputChannel +{ +public: + FifoInputChannel(UInt32 portNum, DryadVertexProgram* vertex, + WorkQueue *workQueue, RChannelReader* channel, + TransformType tType); + void Stop(); + virtual DataBlockItem* ReadDataBlock(byte **ppDataBlock, + Int32 *ppDataBlockSize, + Int32 *pErrorCode); + bool GetTotalLength(UInt64 *length); + bool GetExpectedLength(UInt64 *length); + const char* GetURI(); +private: + void MakeFifo(UInt32 fifoLength, WorkQueue* workQueue); + + bool m_initialHandlerSent; + FifoChannel *m_fifoChannel; + RChannelReaderHolderRef m_fifoReader; + RChannelWriterHolderRef m_fifoWriter; + RChannelReader* m_origReader; +}; diff --git a/DryadVertex/VertexHost/vertex/include/FifoOutputChannel.h b/DryadVertex/VertexHost/vertex/include/FifoOutputChannel.h new file mode 100644 index 0000000..5a01016 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/include/FifoOutputChannel.h @@ -0,0 +1,42 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once +#include +#include + +class FifoOutputChannel: public OutputChannel +{ +public: + FifoOutputChannel(UInt32 portNum, DryadVertexProgram* vertex, + WorkQueue *workQueue, RChannelWriter* outputChannel, + TransformType tType); + void Stop(); + + void SetInitialSizeHint(UInt64 hint); + const char* GetURI(); + +private: + FifoChannel *m_fifoChannel; + RChannelReaderHolderRef m_fifoReader; + RChannelWriterHolderRef m_fifoWriter; + RChannelWriter* m_origWriter; + void MakeFifo(UInt32 fifoLength, WorkQueue* workQueue); +}; diff --git a/DryadVertex/VertexHost/vertex/include/GzipCompressionChannelTransform.h b/DryadVertex/VertexHost/vertex/include/GzipCompressionChannelTransform.h new file mode 100644 index 0000000..85183a0 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/include/GzipCompressionChannelTransform.h @@ -0,0 +1,62 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once +#include +#include +#include +#ifndef Z_PREFIX +#define Z_PREFIX +#endif +#ifdef LINKWITHZLIB +#include "zlib.h" + + +class GzipCompressionChannelTransform : public ChannelTransform +{ +public: + GzipCompressionChannelTransform(DryadVertexProgram* vertex, bool gzipHeader, bool optimizeForSpeed); + virtual ~GzipCompressionChannelTransform(); + virtual DrError Start(FifoChannel *channel); + virtual void SetOutputBufferSize(UInt32 bufferSize); + virtual DrError ProcessItem(DataBlockItem *item); + virtual DrError Finish(bool atEndOfStream); + + private: + void AllocateOutputBuffer(); + void AppendOutputByte(byte value); + void IncrementOutputPosition(UInt32 increment); + DrError ProcessBlock(); + void WriteOutputBuffer(); + void WriteGzipHeader(); + + bool m_firstReadProcessed; + UInt32 m_crc; + int m_zlibArg; + byte *m_crcStart; + FifoChannel *m_channel; + UInt32 m_writeSize; + z_stream m_stream; + DrRef m_outputBuffer; + bool m_gzipHeader; + bool m_optimizeForSpeed; + CRITSEC m_critsec; +}; +#endif \ No newline at end of file diff --git a/DryadVertex/VertexHost/vertex/include/GzipDecompressionChannelTransform.h b/DryadVertex/VertexHost/vertex/include/GzipDecompressionChannelTransform.h new file mode 100644 index 0000000..d209555 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/include/GzipDecompressionChannelTransform.h @@ -0,0 +1,62 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once +#include +#include +#include +#ifndef Z_PREFIX +#define Z_PREFIX +#endif + +#ifdef LINKWITHZLIB +#include "zlib.h" + +class GzipDecompressionChannelTransform : public ChannelTransform +{ +public: + GzipDecompressionChannelTransform(DryadVertexProgram * vertex, bool gzipHeader); + virtual ~GzipDecompressionChannelTransform(); + virtual DrError Start(FifoChannel *channel); + virtual void SetOutputBufferSize(UInt32 bufferSize); + virtual DrError ProcessItem(DataBlockItem *item); + virtual DrError Finish(bool atEndOfStream); + + private: + inline DrError IncrementInputPosition(UInt32 increment); + DrError ReadGzipHeader(); + DrError AllocateOutputBuffer(); + void WriteOutputBuffer(); + DrError ProcessTrailer(); + void DecompressionError(); + + static const int TrailerLength = 8; + byte *m_trailerBytes; + UInt32 m_trailerBytesFilled; + bool m_firstReadProcessed; + UInt32 m_crc; + byte *m_crcStart; + FifoChannel *m_channel; + Size_t m_writeSize; + z_stream m_stream; + DrRef m_outputBuffer; + bool m_gzipHeader; +}; +#endif diff --git a/DryadVertex/VertexHost/vertex/include/InputChannel.h b/DryadVertex/VertexHost/vertex/include/InputChannel.h new file mode 100644 index 0000000..5b1d19f --- /dev/null +++ b/DryadVertex/VertexHost/vertex/include/InputChannel.h @@ -0,0 +1,52 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once +#include +#include +#include + +class InputChannel +{ +public: + InputChannel(UInt32 portNum, + DryadVertexProgram* vertex, + RChannelReader* channel); + virtual ~InputChannel(); + + virtual void Stop(); + + virtual DataBlockItem* ReadDataBlock(byte **ppDataBlock, + Int32 *ppDataBlockSize, + Int32 *pErrorCode); + virtual bool AtEndOfChannel(); + Int64 GetBytesRead(); + RChannelReader* GetReader(); + virtual bool GetTotalLength(UInt64 *length); + virtual bool GetExpectedLength(UInt64 *length); + virtual const char* GetURI(); + +protected: + RChannelReader* m_reader; + DryadVertexProgram* m_vertex; + Int64 m_bytesRead; + UInt32 m_portNum; + bool m_atEOC; +}; diff --git a/DryadVertex/VertexHost/vertex/include/ManagedWrapper.h b/DryadVertex/VertexHost/vertex/include/ManagedWrapper.h new file mode 100644 index 0000000..305e770 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/include/ManagedWrapper.h @@ -0,0 +1,49 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include +#include +#include +#include + +class ManagedWrapperVertex : public DryadVertexProgram +{ +public: + ManagedWrapperVertex(); + + void Main(WorkQueue* workQueue, + UInt32 numberOfInputChannels, + RChannelReader** inputChannel, + UInt32 numberOfOutputChannels, + RChannelWriter** outputChannel); + +private: + static ICLRRuntimeHost *pClrHost; // we only need one per process + static DrCriticalSection m_atomic; // loading the CLR should be atomic + + // Use DryadLinqLog to log messages from the DryadLinqRuntime + // These messages will be sent both to the logging infrastructure and to the vertex standard error. + void DryadLinqLog(LogLevel level, const char* title, const char* fmt, va_list args); +}; + +typedef StdTypedVertexFactory FactoryMWrapper; + diff --git a/DryadVertex/VertexHost/vertex/include/NullChannelTransform.h b/DryadVertex/VertexHost/vertex/include/NullChannelTransform.h new file mode 100644 index 0000000..cdb7fe5 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/include/NullChannelTransform.h @@ -0,0 +1,36 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License.
+ +*/ + +#pragma once +#include +#include + +class NullChannelTransform : public ChannelTransform +{ +public: + NullChannelTransform(DryadVertexProgram* vertex); + virtual ~NullChannelTransform(); + virtual DrError Start(FifoChannel *channel); + virtual void SetOutputBufferSize(UInt32 bufferSize); + virtual DrError ProcessItem(DataBlockItem *item); + virtual DrError Finish(bool atEndOfStream); +private: + FifoChannel *m_channel; +}; diff --git a/DryadVertex/VertexHost/vertex/include/OutputChannel.h b/DryadVertex/VertexHost/vertex/include/OutputChannel.h new file mode 100644 index 0000000..2e33b87 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/include/OutputChannel.h @@ -0,0 +1,51 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once +#include +#include +#include + +class OutputChannel : public RChannelItemWriterHandler +{ +public: + OutputChannel(UInt32 portNum, + DryadVertexProgram* vertex, + RChannelWriter* outputChannel); + virtual ~OutputChannel(); + + virtual void Stop(); + + BOOL WriteDataBlock(DataBlockItem *pItem, + Int32 numBytesToWrite); + void ProcessWriteCompleted(RChannelItemType status, + RChannelItem* marshalFailureItem); + Int64 GetBytesWritten(); + virtual void SetInitialSizeHint(UInt64 hint); + RChannelWriter* GetWriter(); + virtual const char* GetURI(); + +protected: + RChannelWriter* m_writer; + DryadVertexProgram* m_vertex; + Int64 m_bytesWritten; + UInt32 m_portNum; +}; + diff --git a/DryadVertex/VertexHost/vertex/include/wrappernativeinfo.h b/DryadVertex/VertexHost/vertex/include/wrappernativeinfo.h new file mode 100644 index 0000000..d306264 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/include/wrappernativeinfo.h @@ -0,0 +1,150 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include +#include +#include +#include +#include + +class WrapperNativeInfoBase +{ +public: + virtual ~WrapperNativeInfoBase(); + + virtual void CleanUp() = 0; + + virtual DataBlockItem* AllocateDataBlock(Int32 dataBlockSize, + byte **pDataBlock) = 0; + virtual void ReleaseDataBlock(DataBlockItem *pItem) = 0; + + virtual DataBlockItem* ReadDataBlock(UInt32 portNum, + byte **ppDataBlock, + Int32 *ppDataBlockSize, + Int32 *pErrorCode) = 0; + virtual BOOL WriteDataBlock(UInt32 portNum, + DataBlockItem *pData, + Int32 numBytesToWrite) = 0; + virtual Int64 GetTotalLength(UInt32 portNum) = 0; + virtual Int64 GetExpectedLength(UInt32 portNum) = 0; + virtual Int64 GetVertexId() = 0; + virtual void SetInitialSizeHint(UInt32 portNum, UInt64 hint) = 0; + virtual UInt32 GetNumOfInputs() = 0; + virtual UInt32 GetNumOfOutputs() = 0; + virtual const char* GetInputChannelURI(UInt32 portNum) = 0; + virtual const char* GetOutputChannelURI(UInt32 portNum) = 0; + virtual void EnableFifoInputChannel(int compresionScheme, + UInt32 channel) = 0; + virtual void EnableFifoOutputChannel(int compresionScheme, + UInt32 channel) = 0; +}; + +class WrapperNativeInfo : public WrapperNativeInfoBase +{ +public: + WrapperNativeInfo(UInt32 numberOfInputChannels, + RChannelReader** inputChannel, + UInt32 numberOfOutputChannels, + RChannelWriter** outputChannel, + DryadVertexProgram* vertex, + WorkQueue* workQueue); + + ~WrapperNativeInfo(); + void CleanUp(); + + DataBlockItem* AllocateDataBlock(Int32 dataBlockSize, + byte **pDataBlock); + void ReleaseDataBlock(DataBlockItem *pItem); + + DataBlockItem* ReadDataBlock(UInt32 portNum, + byte **ppDataBlock, + Int32 *ppDataBlockSize, + Int32 *pErrorCode); + BOOL WriteDataBlock(UInt32 portNum, + DataBlockItem *pData, + Int32 numBytesToWrite); + Int64 GetTotalLength(UInt32 portNum); + Int64 GetExpectedLength(UInt32 portNum); + Int64 GetVertexId(); + void SetInitialSizeHint(UInt32 portNum, UInt64 hint); + const char* 
GetInputChannelURI(UInt32 portNum); + const char* GetOutputChannelURI(UInt32 portNum); + UInt32 GetNumOfInputs(); + UInt32 GetNumOfOutputs(); + void EnableFifoInputChannel(int compresionScheme, UInt32 channel); + void EnableFifoOutputChannel(int compresionScheme, UInt32 channel); + +private: + UInt32 m_numberOfInputChannels; + InputChannel** m_inputChannels; + UInt32 m_numberOfOutputChannels; + OutputChannel** m_outputChannels; + DryadVertexProgram* m_vertex; + WorkQueue* m_workQueue; +}; + +extern "C" { + UInt32 GetNumOfInputs(WrapperNativeInfoBase *info); + UInt32 GetNumOfOutputs(WrapperNativeInfoBase *info); + Int64 GetTotalLength(WrapperNativeInfoBase *info, UInt32 portNum); + Int64 GetExpectedLength(WrapperNativeInfoBase *info, UInt32 portNum); + Int64 GetVertexId(WrapperNativeInfoBase *info); + void SetInitialSizeHint(WrapperNativeInfoBase *info, UInt32 portNum, UInt64 hint); + const char* GetInputChannelURI(WrapperNativeInfoBase *info, UInt32 portNum); + const char* GetOutputChannelURI(WrapperNativeInfoBase *info, UInt32 portNum); + void Flush(WrapperNativeInfoBase *info, UInt32 portNum); + void Close(WrapperNativeInfoBase *info, UInt32 portNum); + + // The compressionScheme argument allows selection of + // three different transformations of the channel data + // The values are: + // 0 - No transform, just passthrough + // 1 - gzip compression or decompression + void EnableFifoInputChannel(WrapperNativeInfoBase *info, + Int32 compresionScheme, UInt32 channel); + void EnableFifoOutputChannel(WrapperNativeInfoBase *info, + Int32 compresionScheme, UInt32 channel); + + DataBlockItem* ReadDataBlock(WrapperNativeInfoBase *info, + UInt32 portNum, + byte **ppDataBlock, + Int32 *ppDataBlockSize, + Int32 *pErrorCode); + + // The data block should be considered read only after WriteDataBlock + // has been called. The data block will not be reclaimed until the client + // explicitly releases it.
+ BOOL WriteDataBlock(WrapperNativeInfoBase *info, + UInt32 portNum, + DataBlockItem *pItem, + Int32 numBytesToWrite); + + DataBlockItem* AllocateDataBlock(WrapperNativeInfoBase *info, + Int32 dataBlockSize, + byte **pDataBlock); + + void ReleaseDataBlock(WrapperNativeInfoBase *info, + DataBlockItem *pItem); + + bool GetDryadStreamInfo(char *streamName, UInt32 *numExtents, UInt64 *streamLength); + +} diff --git a/DryadVertex/VertexHost/vertex/managedwrappervertex/DataBlockItem.cpp b/DryadVertex/VertexHost/vertex/managedwrappervertex/DataBlockItem.cpp new file mode 100644 index 0000000..5d68ddb --- /dev/null +++ b/DryadVertex/VertexHost/vertex/managedwrappervertex/DataBlockItem.cpp @@ -0,0 +1,122 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include "stdafx.h" + +#include + +DataBlockItem::DataBlockItem(DrMemoryBuffer* buf) +{ + m_buf = buf; +} + +DataBlockItem::DataBlockItem(Size_t size) +{ + m_buf.Attach(new DrSimpleHeapBuffer(size)); +} + +DataBlockItem::~DataBlockItem() +{ + m_buf.Release(); +} + +Size_t DataBlockItem::GetAllocatedSize() +{ + LogAssert(m_buf.Ptr() != NULL); + return m_buf->GetAllocatedSize(); +} + + +Size_t DataBlockItem::GetAvailableSize() +{ + LogAssert(m_buf.Ptr() != NULL); + return m_buf->GetAvailableSize(); +} + +void DataBlockItem::SetAvailableSize(Size_t size) +{ + LogAssert(m_buf.Ptr() != NULL); + m_buf->SetAvailableSize(size); +} + +void * DataBlockItem::GetDataAddress() +{ + Size_t size = 0; + LogAssert(m_buf.Ptr() != NULL); + void *addr = m_buf->GetDataAddress(0, &size, NULL); + return addr; +} + +DrMemoryBuffer* DataBlockItem::GetData() +{ + LogAssert(m_buf.Ptr() != NULL); + return m_buf.Ptr(); +} + +DataBlockParser::DataBlockParser(DObjFactoryBase* factory) : + RChannelItemParser() +{ + // nothing needed now +} + + + +RChannelItem* DataBlockParser::ParseNextItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + Size_t* pOutLength) +{ + + // startOffset should always be zero + LogAssert(startOffset == 0); + + RChannelBufferData* headBuffer = + bufferList->CastOut(bufferList->GetHead()); + *pOutLength = headBuffer->GetData()->GetAvailableSize(); + + return new DataBlockItem(headBuffer->GetData()); +} + +RChannelItem* DataBlockParser::ParsePartialItem(ChannelDataBufferList* bufferList, + Size_t startOffset, + RChannelBufferMarker* + markerBuffer) +{ + + // NYI + return NULL; +} + +DrError DataBlockMarshaler::MarshalItem(ChannelMemoryBufferWriter* writer, + RChannelItem* item, + bool flush, + RChannelItemRef* pFailureItem) +{ + if (item->GetType() == RChannelItem_Data) + { + DataBlockItem* bufferItem = (DataBlockItem*) item; + + DrMemoryBuffer* buffer = bufferItem->GetData(); + return writer->WriteBytesFromBuffer(buffer, true); + } + else + { + return 
DrError_OK; + } +} diff --git a/DryadVertex/VertexHost/vertex/managedwrappervertex/ManagedWrapperVertex.cpp b/DryadVertex/VertexHost/vertex/managedwrappervertex/ManagedWrapperVertex.cpp new file mode 100644 index 0000000..519329b --- /dev/null +++ b/DryadVertex/VertexHost/vertex/managedwrappervertex/ManagedWrapperVertex.cpp @@ -0,0 +1,357 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "stdafx.h" + +#include +#include + +#include +#include +#include +#include + + +#pragma managed + +#pragma warning(disable:4947) // so that we can use Assembly::LoadWithPartialName() + +static DataBlockParserFactory s_FactoryDataBlockParser; +static DataBlockMarshalerFactory s_FactoryDataBlockMarshaler; + +DrCriticalSection ManagedWrapperVertex::m_atomic = DrCriticalSection("ManagedWrapperVertex"); + + +ManagedWrapperVertex::ManagedWrapperVertex() +{ + SetCommonParserFactory(&s_FactoryDataBlockParser); + SetCommonMarshalerFactory(&s_FactoryDataBlockMarshaler); +} + +FactoryMWrapper s_factoryHWrapper("MW"); + +// +// Convert ANSI string to WCHAR* +// +LPCWSTR makeWStr(const char * ansiStr) +{ + int lenA = lstrlenA(ansiStr); + int lenW; + LPWSTR unicodeStr; + + // + // Call MultiByteToWideChar once with 0 for last arg, to get wchar length of converted string. 
+ // + lenW = ::MultiByteToWideChar(CP_UTF8, 0, ansiStr, lenA, 0, 0); + + // + // Check conversion was successful. + // + LogAssert(lenW > 0); + unicodeStr = ::SysAllocStringLen(0, lenW); + ::MultiByteToWideChar(CP_UTF8, 0, ansiStr, lenA, unicodeStr, lenW); + return unicodeStr; +} + +// +// Converts the current local directory to a UNC path by correlating it to the actual nw share +// Assumes the current directory is under a NW share, specified by pwszNWShareName. +// +// Examples for conversion: +// pwszNWShareName = "HPCTEMP" (shared from c:\HPCTEMP) +// actual CWD = "c:\HPCTEMP\USERNAME\1234\56" +// returned value "\\hostname\HPCTEMP\USERNAME\1234\56" +// +//or +// pwszNWShareName = L"HPCTEMP" (shared from c:\FOO\BAR) +// actual CWD = "c:\FOO\BAR\USERNAME\1234\56" +// returned value "\\hostname\HPCTEMP\USERNAME\1234\56" +// +BOOL ConvertCurrentDirToUNCPath(WCHAR *pwszNWShareName, WCHAR *pwszUncPath, DWORD dwUncPathLen) +{ + BOOL bSuccess = FALSE; + PSHARE_INFO_2 pShareInfoBuf = NULL; + do + { + // First get the local path from which \\hostname\HPCTEMP is shared from + if( NetShareGetInfo(NULL, pwszNWShareName, 2, (LPBYTE*) &pShareInfoBuf) != ERROR_SUCCESS) + break; + + WCHAR *pwszShareLocalPath = pShareInfoBuf->shi2_path; + + WCHAR wszCurrentDir[MAX_PATH+1]; + ZeroMemory(wszCurrentDir, sizeof(wszCurrentDir)); + DWORD dwCurrentDirLen = GetCurrentDirectory(_countof(wszCurrentDir), wszCurrentDir); // this should give us something like "c:\hpctemp\\\" + if (dwCurrentDirLen == 0) + break; + + DWORD dwShareLocalPathLen = wcslen(pwszShareLocalPath); + + if( dwShareLocalPathLen >= dwCurrentDirLen) // current directory must be longer than share local path, otherwise something is off. 
+ break; + + // convert everything to upper case before comparisons + _wcsupr(pwszShareLocalPath); + _wcsupr(wszCurrentDir); + + // search for the share local path as a substring of current dir + WCHAR *pPos = wcsstr(wszCurrentDir, pwszShareLocalPath); + if( pPos != wszCurrentDir ) // current directory must contain share local path at its starting position, if not cwd is not under the share. + break; + + // now everything checks out, we can truncate the current dir to get the part that goes after the UNC share. + WCHAR *pwszTruncatedCurrentDir = wszCurrentDir + dwShareLocalPathLen; + + // uncomment for DNS hostname support + //WCHAR wszComputerName[DNS_MAX_LABEL_BUFFER_LENGTH]; + WCHAR wszComputerName[MAX_COMPUTERNAME_LENGTH+1]; + DrError drErr = DrGetComputerName(wszComputerName) ; + if(drErr != DrError_OK) + { + break; + } + + // and do the final formatting to produce the UNC path + if (_snwprintf(pwszUncPath, dwUncPathLen, L"\\\\%s\\%s%s", wszComputerName, pwszNWShareName, pwszTruncatedCurrentDir) <= 0) + break; + + bSuccess = TRUE; + } + while(FALSE); + + // we need to free all buffers returned by Net* APIs + if (pShareInfoBuf != NULL) + { + NetApiBufferFree(pShareInfoBuf); + } + + return bSuccess; +} + + + +// +// Invoke user vertex code +// +void ManagedWrapperVertex::Main(WorkQueue* workQueue, + UInt32 numberOfInputChannels, + RChannelReader** inputChannel, + UInt32 numberOfOutputChannels, + RChannelWriter** outputChannel) +{ + DrLogI("Starting ManagedWrapperVertex Main with %u arguments", GetArgumentCount()); + LogAssert(GetArgumentCount() >= 4); + + // + // Create an object encapsulating all the native stuff: + // + WrapperNativeInfo *nativeInfo = + new WrapperNativeInfo(numberOfInputChannels, + inputChannel, + numberOfOutputChannels, + outputChannel, + this, workQueue); + + // + // Set the compression mode for all the channels + // + for (UInt32 i = 0; i < numberOfInputChannels; i++) + { + int tt = static_cast(inputChannel[i]->GetTransformType()); + if 
(tt != 0) // only enable the FIFO channels if we actually have a transform + { + nativeInfo->EnableFifoInputChannel(tt, i); + } + } + + for (UInt32 i = 0; i < numberOfOutputChannels; i++) + { + int tt = static_cast(outputChannel[i]->GetTransformType()); + if (tt != 0) // only enable the FIFO channels if we actually have a transform + { + nativeInfo->EnableFifoOutputChannel(tt, i); + } + } + + DrLogI("ManagedWrapperVertex: %p %u %u", nativeInfo, numberOfInputChannels, numberOfOutputChannels); + DrLogI("ManagedWrapperVertex: Calling %s.%s", GetArgument(2), GetArgument(3)); + DrLogging::FlushLog(); + + + DrStr128 errorMsg; + DrError error; + + { + // + // Instead of invoking the vertex entry point directly from here, we delegate it to the bridge method in the Microsoft.Hpc.Linq assembly, specifically: + // static int Microsoft.Hpc.Linq.Internal.VertexEnv.VertexBridge(string vertexBridgeArgs) + // + // This indirect method of invoking the vertex entry point is used so that any type load / assembly load problems coming from user code + // can be caught and reported with full details using the same mechanism that other vertex failures go through (exception dumped into vertexexception.txt etc.) 
+ // + // The format of vertexBridgeArgs is simply a comma separated string packing vertex assembly, class, method name, and the *actual* vertex method args (==the native channel string) + // L"[vertex assembly],[class],[method name],[vertex method args]" + // + System::String ^bridgeAssemblyPartialName = gcnew System::String(L"Microsoft.Research.DryadLinq"); + System::String ^bridgeClassName = gcnew System::String(L"Microsoft.Research.DryadLinq.Internal.HpcLinqVertexEnv"); + System::String ^bridgeMethodName = gcnew System::String(L"VertexBridge"); + + // + // Construct the actual vertex methods args from the native information (the "native channel string") + // + System::Text::StringBuilder ^vertexMethodArgs = gcnew System::Text::StringBuilder(); + System::IntPtr ^nativeInfoIntPtr = gcnew System::IntPtr((void*) nativeInfo); + + vertexMethodArgs->Append(nativeInfoIntPtr->ToString("X")); //use hex format, because that's what the vertex env uses when converting it back to a handle + for (UInt32 i = 4; i < GetArgumentCount(); i++) + { + vertexMethodArgs->Append(L"|"); + DrStr64 arg(GetArgument(i)); + vertexMethodArgs->Append(gcnew System::String(arg.GetString())); + } + + + // + // Get assembly path, class name, and method name, and construct the vertex bridge args with the following format: + // "[assembly path],[class name],[method name],[vertex method args]" + // + System::Text::StringBuilder ^vertexBridgeArg = gcnew System::Text::StringBuilder(); + vertexBridgeArg->Append(gcnew System::String(GetArgument(1))); // path to vertex DLL as passed to the vertex host, e.g. L"c:\\HpcTemp\\user\\jobID\\Microsoft.Hpc.Linq0.dll"; + vertexBridgeArg->Append(","); + vertexBridgeArg->Append(gcnew System::String(GetArgument(2))); // full name of class that contains vertex entry method, e.g.
L"Microsoft.Hpc.Linq.HpcLinq__Vertex"; + vertexBridgeArg->Append(","); + vertexBridgeArg->Append(gcnew System::String(GetArgument(3))); // vertex entry method name L"Select__1"; + vertexBridgeArg->Append(","); + vertexBridgeArg->Append(vertexMethodArgs->ToString()); + + + DrLogI("ManagedWrapperVertex: Calling into Vertex Bridge to invoke Vertex Entry: %s", GetArgument(3)); + DrLogging::FlushLog(); + + HRESULT hr = S_OK; + + // + // Now that we have everything ready, we can invoke vertex bridge using reflection + // + try + { + System::Reflection::Assembly ^vertexBridgeAsm = System::Reflection::Assembly::LoadWithPartialName(bridgeAssemblyPartialName); + System::Type ^vertexBridgeType = vertexBridgeAsm->GetType(gcnew System::String(bridgeClassName)); + System::Reflection::MethodInfo ^vertexBridgeMethod = vertexBridgeType->GetMethod(gcnew System::String(bridgeMethodName), + static_cast(System::Reflection::BindingFlags::NonPublic | + System::Reflection::BindingFlags::Static)); + + cli::array ^invokeArgs = gcnew array(1); + invokeArgs[0] = vertexBridgeArg->ToString(); + + vertexBridgeMethod->Invoke(nullptr, invokeArgs); + } + catch(System::Exception ^ex) + { + hr = System::Runtime::InteropServices::Marshal::GetHRForException(ex); + + if (hr == S_OK) + { + // if for some reason GetHRForException() mistakenly returned S_OK we want to make sure we don't skip the failure handling path below + hr = E_FAIL; + } + } + + // + // Flush stdout to make sure all LINQ logs are written out + // + fflush(stdout); + + if (hr != S_OK) + { + // + // Log errors. + // + DrLogE("ManagedWrapperVertex: Assembly path = %s", GetArgument(1)); + DrLogE("ManagedWrapperVertex: Class name = %s", GetArgument(2)); + DrLogE("ManagedWrapperVertex: Method name = %s", GetArgument(3)); + + error = (DrError)hr; + errorMsg.Set("Error returned from managed runtime invocation, "); + errorMsg.Append(DRERRORSTRING(error)); + errorMsg.Append("\n"); + DrLogE( "Error returned from managed runtime invocation. 
%s (%d)", DRERRORSTRING(error), error); + + // + // Prepare the error message that will be sent over to the GM, and eventually displayed in the HPC console if this is the last vertex to fail the job + // + { + WCHAR *pwszHpcTempShare = L"HPCTEMP"; + + WCHAR wszUncPath[MAX_PATH+1]; + ZeroMemory(wszUncPath, sizeof(wszUncPath)); + + if(ConvertCurrentDirToUNCPath(pwszHpcTempShare, wszUncPath, _countof(wszUncPath)) == TRUE) + { + errorMsg.Append("For vertex logs, exception dump and rerun batch files see working directory for failed vertex:\n\n"); + errorMsg.Append(wszUncPath); + errorMsg.Append("\n\n"); + } + + // + // Open file containing the exception dump produced by vertex code. + // The reason we read it out of the file instead of extracting straight from the exception we caught above is that + // the HPCLINQ runtime (vertexbridge or other code paths in Microsoft.Hpc.Linq.DLL) will report the inner exception in some cases. + // So we'll trust it to find the exception layer which is most relevant for a user looking at the HPC console for initial diagnosis. + // + FILE* errorFile = fopen("VertexException.txt", "r"); + if (errorFile) + { + errorMsg.Append("The following callstack was reported as cause of failure:\n"); + char line[1024]; + while (fgets(line, _countof(line), errorFile) != NULL) + { + errorMsg.Append(line); + } + + fclose(errorFile); + errorMsg.Append("\n"); + } + else + { + // + // If the exception dump file could not be opened, report that + // + errorMsg.Append("No error stack trace reported by managed code.
See VertexHostLog.txt in vertex working directory for failure information.\n"); + } + } + + // + // Record error information + // todo: make sure nativeInfo cleaned up + // + ReportError(error, "%s", errorMsg.GetString()); + return; + } + } + + // + // Report success, clean up native Info + // + DrLogI("ManagedWrapperVertex: Cleaning up NativeInfo at %p", nativeInfo); + nativeInfo->CleanUp(); + delete nativeInfo; + return; +} diff --git a/DryadVertex/VertexHost/vertex/managedwrappervertex/ManagedWrapperVertex.vcxproj b/DryadVertex/VertexHost/vertex/managedwrappervertex/ManagedWrapperVertex.vcxproj new file mode 100644 index 0000000..125934c --- /dev/null +++ b/DryadVertex/VertexHost/vertex/managedwrappervertex/ManagedWrapperVertex.vcxproj @@ -0,0 +1,150 @@ + + + + + Debug + Win32 + + + Debug + x64 + + + Release + Win32 + + + Release + x64 + + + + {BDEDD3BB-C7E2-498F-A212-F99786C8E23C} + ManagedWrapperVertex + Win32Proj + + + + StaticLibrary + + + StaticLibrary + + + StaticLibrary + Unicode + true + + + StaticLibrary + Unicode + true + + + + + + + + + + + + + + + + + + + <_ProjectFileVersion>10.0.40219.1 + Debug\ + Debug\ + $(Platform)\$(Configuration)\ + $(Platform)\$(Configuration)\ + Release\ + Release\ + $(Platform)\$(Configuration)\ + $(Platform)\$(Configuration)\ + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + + + + Disabled + WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions) + true + EnableFastChecks + MultiThreadedDebugDLL + + + Level3 + EditAndContinue + + + + + X64 + + + Disabled + ..\include;..\..\system\classlib\include;..\..\system\channel\include;..\..\system\dprocess\include;..\..\system\common\include;..\zlib;%(AdditionalIncludeDirectories) + WIN32;_DEBUG;_LIB;WIN32_LEAN_AND_MEAN;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + false + Default + MultiThreadedDebugDLL + + + Level3 + ProgramDatabase + + + + + WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + + + 
+ + X64 + + + WIN32;NDEBUG;_LIB;WIN32_LEAN_AND_MEAN;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + ..\include;..\..\system\classlib\include;..\..\system\channel\include;..\..\system\dprocess\include;..\..\system\common\include;..\zlib;%(AdditionalIncludeDirectories) + true + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/DryadVertex/VertexHost/vertex/managedwrappervertex/ManagedWrapperVertex.vcxproj.filters b/DryadVertex/VertexHost/vertex/managedwrappervertex/ManagedWrapperVertex.vcxproj.filters new file mode 100644 index 0000000..9f9b3cc --- /dev/null +++ b/DryadVertex/VertexHost/vertex/managedwrappervertex/ManagedWrapperVertex.vcxproj.filters @@ -0,0 +1,35 @@ + + + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hpp;hxx;hm;inl;inc;xsd + + + {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} + rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + + + Header Files + + + + + Header Files + + + + + Source Files + + + Source Files + + + \ No newline at end of file diff --git a/DryadVertex/VertexHost/vertex/managedwrappervertex/stdafx.h b/DryadVertex/VertexHost/vertex/managedwrappervertex/stdafx.h new file mode 100644 index 0000000..b39f805 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/managedwrappervertex/stdafx.h @@ -0,0 +1,25 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include "DrExecution.h" +#include "XCompute.h" + diff --git a/DryadVertex/VertexHost/vertex/vertexHost/VertexHost.vcxproj b/DryadVertex/VertexHost/vertex/vertexHost/VertexHost.vcxproj new file mode 100644 index 0000000..67bdcc6 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/vertexHost/VertexHost.vcxproj @@ -0,0 +1,200 @@ + + + + + Debug + Win32 + + + Debug + x64 + + + Release + Win32 + + + Release + x64 + + + + {0CF3D1D5-9BBE-4175-979B-EC6138EF4F37} + VertexHost + Win32Proj + + + + Application + + + Application + + + Application + Unicode + true + + + Application + Unicode + true + + + + + + + + + + + + + + + + + + + <_ProjectFileVersion>10.0.40219.1 + Debug\ + Debug\ + true + ..\..\..\..\bin\$(Configuration)\ + $(Platform)\$(Configuration)\ + true + Release\ + Release\ + true + ..\..\..\..\bin\$(Configuration)\ + $(Platform)\$(Configuration)\ + false + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + + + + Disabled + WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + EnableFastChecks + MultiThreadedDebugDLL + + + Level3 + EditAndContinue + + + true + Console + MachineX86 + + + + + X64 + + + Disabled + ..\include;..\..\system\classlib\include;..\..\system\channel\include;..\..\system\dprocess\include;..\..\system\common\include;..\zlib;%(AdditionalIncludeDirectories) + WIN32;_DEBUG;_CONSOLE;WIN32_LEAN_AND_MEAN;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + false + Default + MultiThreadedDebugDLL + + + Level3 + ProgramDatabase + + + MSCorEE.lib;Netapi32.lib;oleaut32.lib;ws2_32.lib;%(AdditionalDependencies) + ..\..\system\common\$(Platform)\$(Configuration);..\..\system\dprocess\$(Platform)\$(Configuration);..\..\system\classlib\$(Platform)\$(Configuration);..\..\system\channel\$(Platform)\$(Configuration);%(AdditionalLibraryDirectories) + true + true + Console + 
MachineX64 + + + + + WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + + + true + Console + true + true + MachineX86 + + + + + X64 + + + WIN32;NDEBUG;_CONSOLE;WIN32_LEAN_AND_MEAN;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + ..\include;..\..\system\classlib\include;..\..\system\channel\include;..\..\system\dprocess\include;..\..\system\common\include;..\zlib;%(AdditionalIncludeDirectories) + + + true + Console + true + true + MachineX64 + MSCorEE.lib;Netapi32.lib;oleaut32.lib;ws2_32.lib;%(AdditionalDependencies) + + + + + + + + {95fbf9b7-9407-4554-a74a-3527839bd1b6} + + + {e092e2b9-d3c9-4ce2-8201-bda442574c97} + + + {482e0741-e244-4974-97d4-3a7167581e91} + + + {016e71d3-9a6f-425c-ab4f-8c5edeffe7fa} + + + {57663b94-e11b-431e-be4b-e2c61112dec5} + + + {aa529122-f51c-48d7-a8c1-c0b24f570885} + + + {bdedd3bb-c7e2-498f-a212-f99786c8e23c} + + + {ab9ea66c-5811-49a7-b002-24203aeb9083} + + + + + + \ No newline at end of file diff --git a/DryadVertex/VertexHost/vertex/vertexHost/VertexHost.vcxproj.filters b/DryadVertex/VertexHost/vertex/vertexHost/VertexHost.vcxproj.filters new file mode 100644 index 0000000..97a7c02 --- /dev/null +++ b/DryadVertex/VertexHost/vertex/vertexHost/VertexHost.vcxproj.filters @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/DryadVertex/VertexHost/vertex/vertexHost/vertexHost.cpp b/DryadVertex/VertexHost/vertex/vertexHost/vertexHost.cpp new file mode 100644 index 0000000..5cd8f0d --- /dev/null +++ b/DryadVertex/VertexHost/vertex/vertexHost/vertexHost.cpp @@ -0,0 +1,526 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// Includes +// +#include "dvertexmain.h" +#include "managedwrapper.h" +#include "recorditem.h" + +#pragma managed + +// +// Managed Wrapper vertex factory. +// +extern FactoryMWrapper s_factoryHWrapper; + +// +// Output stream files +// +FILE* g_oldStdout = NULL; +FILE* g_oldStderr = NULL; + +// +// Attempts to open stdout.txt and stderr.txt in current directory +// If successful, redirects output streams to these files. +// If unsucessful, no error, just retains original stdout/stderr streams +// _wfreopen_s locks the logs so that they cannot be read while the vertex is running. +// +#pragma warning (disable: 4996) // _wfreopen : This function may be unsafe, consider using _wfreopen_s +static void RedirectOutputStreams() +{ + WCHAR szCurrentDir[MAX_PATH + 1] = {0}; + WCHAR szStdout[MAX_PATH + 1] = {0}; + WCHAR szStderr[MAX_PATH + 1] = {0}; + + if (GetCurrentDirectory(MAX_PATH, szCurrentDir) != 0) + { + if (S_OK == StringCchPrintf(szStdout, MAX_PATH, L"%s\\stdout.txt", szCurrentDir)) + { + g_oldStdout = _wfreopen(szStdout, L"w", stdout); + } + if (S_OK == StringCchPrintf(szStderr, MAX_PATH, L"%s\\stderr.txt", szCurrentDir)) + { + g_oldStderr = _wfreopen(szStderr, L"w", stderr); + } + } +} + +// +// Eliminates arguments used by previous operations. +// Arguments left are those not yet used. +// +static void EliminateArguments(int* pArgc, char* argv[], + int startingLocation, int numberToRemove) +{ + // + // Get total number of arguments. Check if there number to remove makes sense. 
+ // + int argc = *pArgc; + LogAssert(argc >= startingLocation+numberToRemove); + + int i; + for (i=startingLocation+numberToRemove; i 0) + { + // + // Create argument storage and verify that it's correctly created + // + newArgv = new char *[(size_t)argc]; + LogAssert(newArgv != NULL); + + // + // Foreach argument + // + for (int i = 0; i < argc; i++) + { + if (wargv[i] == NULL) + { + // + // If NULL argument, store NULL + // + newArgv[i] = NULL; + } + else + { + // + // Get argument value and verify that it's a valid string + // + strArg.Set( wargv[i] ); + LogAssert(strArg.GetString() != NULL); + + // + // Copy argument value into list + // + newArgv[i] = new char[strArg.GetLength() + 1]; + LogAssert(newArgv[i] != NULL); + memcpy(newArgv[i], strArg.GetString(), strArg.GetLength() + 1); + } + } + } + + // + // Store argument list (can be null if no arguments) + // + *pargv = newArgv; +} + +// +// Sets the logging level based on the environment variable +// +void SetLoggingLevel() +{ + WCHAR traceLevel [MAX_PATH]; + HRESULT hr = DrGetEnvironmentVariable(L"CCP_DRYADTRACELEVEL", traceLevel); + if(hr == DrError_OK) + { + if(wcscmp(traceLevel, L"OFF") == 0) + { + DrLogging::SetLoggingLevel(LogLevel_Off); + } + else if(wcscmp(traceLevel, L"CRITICAL") == 0) + { + DrLogging::SetLoggingLevel(LogLevel_Assert); + } + else if(wcscmp(traceLevel, L"ERROR") == 0) + { + DrLogging::SetLoggingLevel(LogLevel_Error); + } + else if(wcscmp(traceLevel, L"WARN") == 0) + { + DrLogging::SetLoggingLevel(LogLevel_Warning); + } + else if(wcscmp(traceLevel, L"INFO") == 0) + { + DrLogging::SetLoggingLevel(LogLevel_Info); + } + else + { + DrLogging::SetLoggingLevel(LogLevel_Debug); + } + } + else + { + DrLogging::SetLoggingLevel(LogLevel_Debug); + } +} + +// +// if $HPCQUERY_DEBUGVERTEXHOST is defined, break into the debugger +// +void BreakForDebugger() +{ + WCHAR strDebugBreak [MAX_PATH]; + HRESULT hr = DrGetEnvironmentVariable(L"HPCQUERY_DEBUGVERTEXHOST", strDebugBreak); + if(hr == 
DrError_OK) + { + DrLogE("Waiting for debugger "); + DrLogging::FlushLog(); + + while (!IsDebuggerPresent()) + { + Sleep(2000); + } + + DebugBreak(); + } +} + +// +// Start up vertex host +// +#if defined(_AMD64_) +int wmain(int argc, wchar_t* wargv[]) +#else +int __cdecl wmain(int argc, wchar_t* wargv[]) +#endif +{ + // + // Set up std.out and std.err files in current directory and redirect to them + // + RedirectOutputStreams(); + + // + // Enable logging based on environment variable + // + SetLoggingLevel(); + + // + // trace for startup + // + DrLogI("Vertex Host starting"); + + // + // Get environment variable to know whether to break into debugger + // + BreakForDebugger(); + + // + // We call Register on the Managed Wrapper vertex factory to force its library to be linked. + // Registration actually occurs during static initialization. + // + s_factoryHWrapper.Register(); + + // + // Get command line arguments + // + char** argv; + DrGetUtf8CommandArgs(argc, wargv, &argv); + + // + // Initialize the dryad communication layer with the command line arguments + // + int nOpts; + DrError e; + e = DryadInitializeXCompute(NULL, NULL, argc, argv, &nOpts); + if (e != DrError_OK) + { + // + // Report error in initializing xcompute layter + // + DrLogE("Couldn't initialise XCompute"); + return 1; + } + + // + // Update the argument list to just those parameters that weren't used by xcompute init + // + EliminateArguments(&argc, argv, 1, nOpts); + + int exitCode; + + // + // Check if --vertex argument provided. 
+ // If it is, remove it from the argument list and call DryadVertexMain + // If it is not, report error + // + if (argc > 1 && ::_stricmp(argv[1], "--vertex") == 0) + { + EliminateArguments(&argc, argv, 1, 1); + + // + // Call main function to continue execution of vertex + // + exitCode = DryadVertexMain(argc, argv, NULL); + } + else + { + DrLogE("--vertex argument required - only vertex execution mode is supported."); + return 1; + } + + // + // Close the xcompute connection after dryadvertexmain returns + // + e = DryadShutdownXCompute(); + if (e == DrError_OK) + { + // + // Report success + // + DrLogI("Completed uninitialise xcompute"); + } + else + { + // + // Report failure + // + DrLogE("Couldn't uninitialise xcompute"); + } + + return exitCode; +} + +// +// Simple data class which contains the byte array and its length. +// +class DummyRecord { + /// used to copy arbitrary-sized items + size_t m_dummySize; + BYTE* m_dummyStuff; + +public: + // + // Create an empty record + // + DummyRecord() + { + m_dummySize = 0; + m_dummyStuff = NULL; + } + + // + // Clean up record + // + ~DummyRecord() + { + if (m_dummyStuff) + delete [] m_dummyStuff; + } + + // + // Copy constructor for record + // + DummyRecord(const DummyRecord& other) + { + m_dummySize = other.m_dummySize; + if (m_dummySize) + { + m_dummyStuff = new BYTE[m_dummySize]; + memcpy(m_dummyStuff, other.m_dummyStuff, m_dummySize); + } + } + + // + // Assignment operator overload to copy existing record + // + DummyRecord& operator=(const DummyRecord& other) + { + // + // Clean up existing record + // + if (m_dummyStuff) + { + // todo: why not 'delete []' like in destructor + delete m_dummyStuff; + } + + // + // Copy other record contents into this record + // + m_dummySize = other.m_dummySize; + if (m_dummySize) + { + m_dummyStuff = new BYTE[m_dummySize]; + memcpy(m_dummyStuff, other.m_dummyStuff, m_dummySize); + } + else + { + // + // If nothing in other record, set local record contents to NULL + // 
todo: use NULL rather than 0 + // + m_dummyStuff = 0; + } + + return *this; + } + + // + // Define deserialization of record + // + DrError DeSerialize(DrMemoryBufferReader* reader, + Size_t availableSize, bool lastRecordInStream) + { + // + // Read n bytes into record from memory buffer + // + m_dummySize = availableSize; + m_dummyStuff = new BYTE[m_dummySize]; + return reader->ReadBytes(m_dummyStuff, m_dummySize); + } + + // + // Define serialization of record + // + DrError Serialize(DrMemoryBufferWriter* writer) + { + // + // Write record into memory buffer + // + return writer->WriteBytes(m_dummyStuff, m_dummySize); + } + + // + // Move contents of another record into this record + // + void TransferFrom(DummyRecord& src) + { + // + // Get size and data + // + m_dummySize = src.m_dummySize; + m_dummyStuff = src.m_dummyStuff; + + // + // Clear other record's size and data + // + src.m_dummySize = 0; + src.m_dummyStuff = NULL; + } + + // + // Return the data size of this record + // + size_t GetSize() const + { + return m_dummySize; + } + + // + // Return a pointer to the data in this record + // + BYTE* GetData() const + { + return m_dummyStuff; + } +}; + +// +// Define a type for multiple records and create an instance of that type +// +typedef RecordBundle DummyBundle; +DummyBundle s_packedBundle; + +// +// Copy Vertex ('CP') +// Copies input to output. Used for broadcast. 
+// +class CopyVertex : public DryadVertexProgram +{ +public: + // + // Constructor - Ensures that parser factory is the factory associated with the global record bundle + // + CopyVertex() + { + SetCommonParserFactory(s_packedBundle.GetParserFactory()); + } + + // + // Run vertex which involves copying input channel to output channel + // + void Main(WorkQueue* workQueue, + UInt32 numberOfInputChannels, + RChannelReader** inputChannel, + UInt32 numberOfOutputChannels, + RChannelWriter** outputChannel) + { + // + // Ensure exactly one input and one output + // + LogAssert(numberOfInputChannels == 1 && numberOfOutputChannels == 1); + + // + // Associates reader with input channel and writer with output channel + // + DummyBundle::Reader input(inputChannel[0]); + DummyBundle::Writer output(&s_packedBundle, outputChannel[0]); + + // + // Log start of main. DrLogging will add reference to CopyVertex.main + // + DrLogI("Started"); + + // + // Reads from input channel until nothing left + // + while (input.Advance()) + { + //todo: Decide whether to remove this or log it + //printf("Transferring\n"); + + // + // Prepare output record array writer for additional input + // + output.MakeValid(); + + // + // Transfer contents of input into output + // + output->TransferFrom(*input); + } + } +}; + +// +// Factory for copy verticies +// +StdTypedVertexFactory s_factoryCopy("CP"); diff --git a/DryadVertex/service/DryadVertexService.csproj b/DryadVertex/service/DryadVertexService.csproj new file mode 100644 index 0000000..70f9b60 --- /dev/null +++ b/DryadVertex/service/DryadVertexService.csproj @@ -0,0 +1,150 @@ + + + + Debug + AnyCPU + {27D89037-8934-45BE-8A44-2561F9330EB7} + 9.0.30729 + 2.0 + Exe + false + DryadVertexService + v4.0 + 512 + DryadVertexService + + + 3.5 + + false + publish\ + true + Disk + false + Foreground + 7 + Days + false + false + true + 0 + 1.0.0.%2a + false + true + + + + + + true + full + false + ..\..\bin\Debug\ + DEBUG;TRACE + prompt + 4 + 
AllRules.ruleset + true + x64 + + + pdbonly + true + bin\Release\ + TRACE + prompt + 4 + AllRules.ruleset + x64 + + + + + + + + + + + 3.0 + + + + + + + Constants.cs + + + DryadTracing.cs + + + DryadVertexServiceAuthorizationManager.cs + + + ExecutionHelper.cs + + + IDryadVertexCallback.cs + + + IDryadVertexService.cs + + + NativeMethods.cs + + + NetShareWrapper.cs + + + ProcessPathHelper.cs + + + ProcessState.cs + + + QueryUtility.cs + + + RetryFramework.cs + + + SchedulerHelper.cs + + + + + + + + + + False + .NET Framework 3.5 SP1 Client Profile + false + + + False + .NET Framework 3.5 SP1 + true + + + False + Windows Installer 3.1 + true + + + + + + + + {09FB27C7-D1A5-4A59-B010-67D5886DD9A2} + DryadYarnBridge + + + + + + + \ No newline at end of file diff --git a/DryadVertex/service/ReplyDispatcher.cs b/DryadVertex/service/ReplyDispatcher.cs new file mode 100644 index 0000000..8f01c24 --- /dev/null +++ b/DryadVertex/service/ReplyDispatcher.cs @@ -0,0 +1,416 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +//------------------------------------------------------------------------------ +// +// Handles WCF communication for vertex service +// +//------------------------------------------------------------------------------ + +namespace Microsoft.Research.Dryad +{ + using System; + using System.Collections.Generic; + using System.Diagnostics; + using System.Net.Security; + using System.ServiceModel; + using Microsoft.Research.Dryad; + + class ReplyDispatcher + { + private static string graphMgrUri = String.Empty; + private static VertexCallbackServiceClient graphMgrClient; + + private static string vertexProcUri = String.Empty; + private static VertexCallbackServiceClient vertexProcClient; + + private static NetTcpBinding binding = null; + private static readonly int numRetries = 6; + private static readonly int retrySleepTime = 1000; + private static object vertexLock = new object(); + private static object graphMgrLock = new object(); + private static object syncRoot = new object(); + private static bool shuttingDown = false; + + /// + /// Flag to notify reply dispatcher that job is shutting down and any communication errors should be ignored + /// + internal static bool ShuttingDown + { + get + { + return shuttingDown; + } + + set + { + shuttingDown = value; + } + } + + /// + /// Initialize binding defaults + /// + static ReplyDispatcher() + { + using (ISchedulerHelper helper = SchedulerHelperFactory.GetInstance()) + { + binding = helper.GetVertexServiceBinding(); + } + } + + /// + /// Check URI to see if it is the graph manager by checking for vertex id /1/ + /// + /// URI where vertex is expecting a response + /// true if graph manager + public static bool IsGraphMrgUri(string uri) + { + Uri u = new Uri(uri); + return u.AbsolutePath.StartsWith("/1/"); + } + + /// + /// Creates a new WCF client to service listening at URI + /// + /// wcf service endpoint + /// client to WCF service + private static VertexCallbackServiceClient CreateClient(string uri) + { + 
VertexCallbackServiceClient client = new VertexCallbackServiceClient(binding, new EndpointAddress(uri)); + lock (syncRoot) + { + // + // If the graph manager URI is specified, store this client as the GM client, otherwise assume it's a vertex host + // + if (IsGraphMrgUri(uri)) + { + graphMgrUri = uri; + graphMgrClient = client; + } + else + { + vertexProcUri = uri; + vertexProcClient = client; + } + } + + return client; + } + + /// + /// Close an existing client + /// + /// client to dispose + private static void DisposeClient(ref VertexCallbackServiceClient client) + { + if (client != null) + { + try + { + client.Close(); + } + catch (Exception) + { + try + { + client.Abort(); + } + catch (Exception) + { + // If client cannot be aborted, just finish silently. + } + } + + client = null; + } + else + { + throw new ArgumentNullException("client"); + } + } + + /// + /// Returns client pointing to URI - create if needed + /// + /// WCF server address + /// Client to server listening at URI + private static VertexCallbackServiceClient GetClient(string uri) + { + if (graphMgrUri.Equals(uri, StringComparison.OrdinalIgnoreCase)) + { + return graphMgrClient; + } + else if (vertexProcUri.Equals(uri, StringComparison.OrdinalIgnoreCase)) + { + return vertexProcClient; + } + else + { + return CreateClient(uri); + } + } + + /// + /// Try to reopen client to WCF service + /// + /// Address of service + /// new client + private static VertexCallbackServiceClient ReopenClient(string uri) + { + lock (syncRoot) + { + // + // Get any existing client to this URI + // + VertexCallbackServiceClient client = GetClient(uri); + if (client != null) + { + // + // If a client exists, dispose it + // + DisposeClient(ref client); + } + + // + // Recreate the client + // + return CreateClient(uri); + } + } + + /// + /// Helper method to retry opening the client for use with state changes and property comm + /// + /// URI to respond to + /// Reason for retry + /// new client - may be null on 
failures + private static VertexCallbackServiceClient ReopenClientForRetry(string replyUri, Exception e) + { + VertexCallbackServiceClient client = null; + DryadLogger.LogError(0, e); + try + { + client = ReopenClient(replyUri); + } + catch (Exception reopenEx) + { + DryadLogger.LogError(0, reopenEx, "Unable to reopen client connection"); + } + + // + // If retrying, sleep briefly + // + System.Threading.Thread.Sleep(retrySleepTime); + + return client; + } + + /// + /// Notify URI of state change + /// + /// where to send state change notification + /// vertex process id + /// updated state + /// success/failure of state change notification + public static bool FireStateChange(string replyUri, int processId, ProcessState newState) + { + DryadLogger.LogMethodEntry(replyUri, processId, newState); + + bool result = false; + VertexCallbackServiceClient client = GetClient(replyUri); + + // + // Try to notify GM of state change up to numRetries times + // + for (int index = 0; index < numRetries; index++) + { + try + { + // + // If client is null, try reopening it + // + if (client == null) + { + client = CreateClient(replyUri); + } + + // + // Make FireStateChange WCF call, return success + // + client.FireStateChange(processId, newState); + result = true; + break; + } + catch (Exception e) + { + if (shuttingDown) + { + // if shutting down, just return + DisposeClient(ref client); + return true; + } + else + { + // + // If call failed, try reopening WCF client and calling again + // + client = ReopenClientForRetry(replyUri, e); + } + } + } + + // + // If failure occurs after X retry attempts, report error + // + DryadLogger.LogMethodExit(result); + return result; + } + + /// + /// Notify GM that vertex host process exited + /// + /// GM address + /// vertex process id + /// reason for vertex host exit + /// success/failure + public static bool ProcessExited(string replyUri, int processId, int exitCode) + { + DryadLogger.LogMethodEntry(replyUri, processId, exitCode); + + 
bool result = false; + + VertexCallbackServiceClient client = GetClient(replyUri); + + // + // Try to notify GM that the process has exited up to numRetries times + // + for(int index = 0; index < numRetries; index++) + { + try + { + // + // If client is null, try reopening it + // + if(client == null) + { + client = CreateClient(replyUri); + } + + // + // Make ProcessExited WCF call, return success + // + client.ProcessExited(processId, exitCode); + result = true; + break; + } + catch (Exception e) + { + if (shuttingDown) + { + // if shutting down, just return + DisposeClient(ref client); + return true; + } + else + { + // + // If call failed, try reopening WCF client and calling again + // + client = ReopenClientForRetry(replyUri, e); + } + } + } + + // + // If failure occurs after X retry attempts, report error + // + DryadLogger.LogMethodExit(result); + return result; + } + + /// + /// Attempt to call SetGetPropsComplete on specified WCF service. + /// + /// Service endpoint + /// + /// + /// + /// + /// + /// + public static bool SetGetPropsComplete(string replyUri, Process systemProcess, int processId, ProcessInfo info, string[] propertyLabels, ulong[] propertyVersions) + { + DryadLogger.LogMethodEntry(replyUri, processId); + + bool result = false; + + VertexCallbackServiceClient client = GetClient(replyUri); + + // + // Try to set/get properties up to numRetries times + // + for (int index = 0; index < numRetries; index++) + { + try + { + // + // If client is null, try reopening it + // + if (client == null) + { + client = CreateClient(replyUri); + } + + // + // Make SetGetPropsComplete WCF call, return success + // + client.SetGetPropsComplete(processId, info, propertyLabels, propertyVersions); + result = true; + break; + } + catch (Exception e) + { + if ((IsGraphMrgUri(replyUri) == false && systemProcess.HasExited) || shuttingDown) + { + // + // If trying to connect to non-running vertex or job is shutting down, don't retry and report success. 
+ // + DisposeClient(ref client); + return true; + } + else + { + // + // If call failed and talking to GM or running vertex process, try reopening WCF client and calling again + // + client = ReopenClientForRetry(replyUri, e); + } + } + } + + // + // If failed to connect X times, report error + // + DryadLogger.LogMethodExit(result); + return result; + } + } +} diff --git a/DryadVertex/service/VertexCallbackServiceClient.cs b/DryadVertex/service/VertexCallbackServiceClient.cs new file mode 100644 index 0000000..f1fecd3 --- /dev/null +++ b/DryadVertex/service/VertexCallbackServiceClient.cs @@ -0,0 +1,51 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + + +namespace Microsoft.Research.Dryad +{ + using System; + using System.ServiceModel; + using Microsoft.Research.Dryad; + + class VertexCallbackServiceClient : ClientBase, IDryadVertexCallback + { + public VertexCallbackServiceClient(System.ServiceModel.Channels.Binding binding, System.ServiceModel.EndpointAddress remoteAddress) : + base(binding, remoteAddress) + { + } + + public void FireStateChange(int processId, ProcessState newState) + { + Channel.FireStateChange(processId, newState); + } + + public void SetGetPropsComplete(int processId, ProcessInfo info, string[] propertyLabels, ulong[] propertyVersions) + { + Channel.SetGetPropsComplete(processId, info, propertyLabels, propertyVersions); + } + + public void ProcessExited(int processId, int exitCode) + { + Channel.ProcessExited(processId, exitCode); + } + + } +} diff --git a/DryadVertex/service/VertexProcess.cs b/DryadVertex/service/VertexProcess.cs new file mode 100644 index 0000000..a3c545c --- /dev/null +++ b/DryadVertex/service/VertexProcess.cs @@ -0,0 +1,1089 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +//------------------------------------------------------------------------------ +// +// Wrapper around vertex host +// +//------------------------------------------------------------------------------ + +namespace Microsoft.Research.Dryad +{ + using System; + using System.Diagnostics; + using System.Collections; + using System.Collections.Generic; + using System.Collections.Specialized; + using System.Threading; + using System.IO; + using Microsoft.Research.Dryad; + + class VertexProcess : IDisposable + { + #region members + + private int dryadProcessId; + private string graphManagerReplyUri; + public string commandLine; + private StringDictionary environment; + private string localAddress; + private bool failed = false; + private bool cancelled = false; + private bool exited = false; + private int exitCode = 0; + + private ManualResetEvent processStartEvent = new ManualResetEvent(false); + + private Process systemProcess; + private ProcessState state = ProcessState.AssignedToNode; + private ProcessInfo info = new ProcessInfo(); + private long peakWorkingSet = 0; + + private Dictionary> propertyWaitEvents = new Dictionary>(); + private int propertyWaiters = 0; + + private bool finalStatusMessageSent = false; + private ProcessPropertyInfo latestVertexStatusReceived; + private ProcessPropertyInfo latestVertexStatusSent; + + public Object syncRoot = new Object(); + private bool m_disposed = false; + + #endregion + + #region Constructors + + /// + /// Instantiates a new instance of the VertexProcess Class. Saves parameters. 
+ /// + /// GM URI + /// Vertex ID + /// Vertex cmd line args + /// Environment variables for vertex host + /// VS URI + public VertexProcess(string uri, int id, string cmd, StringDictionary env, string localAddr) + { + dryadProcessId = id; + commandLine = cmd; + environment = env; + graphManagerReplyUri = uri; + localAddress = localAddr; + } + + #endregion + + #region Public methods + + // + // Asynchronously start vertex process + // + public bool Start(ManualResetEvent serviceInitializedEvent) + { + DryadLogger.LogMethodEntry(this.DryadId); + + bool result = ThreadPool.QueueUserWorkItem(new WaitCallback(StartProcessThreadProc), serviceInitializedEvent); + + DryadLogger.LogMethodExit(result); + return result; + } + + /// + /// Set process state to cancelled and stop the vertex host process if possible + /// + public void Cancel(bool suppressNotifications) + { + DryadLogger.LogMethodEntry(this.DryadId); + + lock (syncRoot) + { + if (state == ProcessState.Completed) + { + // Process has already completed before cancelation made it here, do nothing + DryadLogger.LogInformation("Cancel process", "Process {0} has already exited", DryadId); + DryadLogger.LogMethodExit(); + return; + } + DryadLogger.LogInformation("Cancel process", "Process {0} has not already exited", DryadId); + state = ProcessState.Completed; + this.cancelled = true; + } + + // If the process started, kill it + if (systemProcess != null) + { + try + { + // Killing the process will trigger Process_Exited + DryadLogger.LogInformation("Cancel process", "Killing system process for process id {0}", DryadId); + + if (suppressNotifications) + { + // Remove the Exited event handler + systemProcess.Exited -= this.Process_Exited; + } + systemProcess.Kill(); + DryadLogger.LogMethodExit(); + return; + } + catch (Exception e) + { + // + // Failed to kill process - log exception + // + DryadLogger.LogError(0, e, "Failed to kill system process for process id {0}", DryadId); + } + } + else + { + 
DryadLogger.LogInformation("Cancel process", "Process {0} has not started yet", DryadId); + } + + // Process was either not running or failed to die, trigger Process_Exited ourself + if (!suppressNotifications) + { + Process_Exited(this, null); + } + DryadLogger.LogMethodExit(); + } + + /// + /// + /// + public void Dispose() + { + DryadLogger.LogMethodEntry(this.DryadId); + + Dispose(true); + GC.SuppressFinalize(this); + + DryadLogger.LogMethodExit(); + } + + /// + /// Creates a request for property update from one vertex to another (one vertex is usually GM) + /// + /// Address to send response to + /// + /// + /// + /// + /// + /// + /// Returns success/failure of thread startup + public bool SetGetProps(string replyEpr, ProcessPropertyInfo[] infos, string blockOnLabel, ulong blockOnVersion, long maxBlockTime, string getPropLabel, bool ProcessStatistics) + { + DryadLogger.LogMethodEntry(replyEpr, this.DryadId); + + // + // Check if graph manager is the one calling. If so, increment propertyWaiter count + // + if (ReplyDispatcher.IsGraphMrgUri(replyEpr)) + { + int n = Interlocked.Increment(ref propertyWaiters); + } + + // + // Create new property request with provided parameters and queue up request sending + // + PropertyRequest req = new PropertyRequest(replyEpr, infos, blockOnLabel, blockOnVersion, maxBlockTime, getPropLabel, ProcessStatistics); + + bool result = ThreadPool.QueueUserWorkItem(new WaitCallback(SetGetPropThreadProc), req); + + DryadLogger.LogMethodExit(result); + return result; + } + + #endregion + + #region Private methods + + /// + /// Adds specified property to property wait list and waits for it. 
+ /// + /// Property label to wait for + /// Version of property to wait for + /// Time to wait for property + /// False if property was requested but none was returned + private bool BlockOnProperty(string blockOnLabel, ulong blockOnVersion, long maxBlockTime) + { + DryadLogger.LogMethodEntry(); + + // + // Return true if no label is provided + // + if (String.IsNullOrEmpty(blockOnLabel)) + { + DryadLogger.LogMethodExit(true); + return true; + } + + DryadLogger.LogInformation("Block on property", "Label {0} Version {1} maxBlockTime {2}", blockOnLabel, blockOnVersion, maxBlockTime); + + ProcessPropertyInfo prop = null; + + // + // If the process already exited, don't bother adding a wait event for + // this property - if it's not already set it never will be. + // + + lock (syncRoot) + { + if (!exited) + { + // + // Add this label and version to the wait events list if needed + // + if (propertyWaitEvents.ContainsKey(blockOnLabel) == false) + { + propertyWaitEvents.Add(blockOnLabel, new Dictionary()); + } + + if (propertyWaitEvents[blockOnLabel].ContainsKey(blockOnVersion) == false) + { + propertyWaitEvents[blockOnLabel].Add(blockOnVersion, new ManualResetEvent(false)); + } + } + else + { + DryadLogger.LogInformation("Block on property", "Process {0} already exited, not adding waiter", this.DryadId); + } + } + + // todo: We still may want to implement timeouts to deal with deadlocks in the service / host but it hasn't been an issue yet. 
+ //if (propertyWaitEvents[blockOnLabel][blockOnVersion].WaitOne(new TimeSpan(maxBlockTime), false)) + + // + // Wait forever (or until process exits or is disposed) for the property to be set or interrupted + // + + while (!exited) + { + try + { + if (propertyWaitEvents[blockOnLabel][blockOnVersion].WaitOne(100, false)) + { + break; + } + } + catch (ObjectDisposedException) + { + DryadLogger.LogWarning("Block on property", "Process {0} disposed while waiting for label {1}, version {2}", DryadId, blockOnLabel, blockOnVersion); + DryadLogger.LogMethodExit(false); + return false; + } + } + + // Did we get the property, or did the process + // terminate? + int index; + if (TryGetProperty(blockOnLabel, out prop, out index)) + { + // + // If a property was successfully returned, return true + // + if ((blockOnVersion == 0) || (prop.propertyVersion > blockOnVersion)) + { + DryadLogger.LogMethodExit(true); + return true; + } + + if (state == ProcessState.Completed) + { + DryadLogger.LogInformation("Block on property", "Vertex completed (wait) requested version:{0} returned version:{1} of label {2}", blockOnVersion, prop.propertyVersion, blockOnLabel); + DryadLogger.LogMethodExit(true); + return true; + } + } + + // + // Return false if property was requested but none was found + // + DryadLogger.LogMethodExit(false); + return false; + } + + /// + /// + /// + /// + private void Dispose(bool disposing) + { + DryadLogger.LogMethodEntry(disposing); + if (!this.m_disposed) + { + lock (syncRoot) + { + if (!this.m_disposed) + { + if (disposing) + { + // Close start event handle + try + { + processStartEvent.Close(); + } + catch (Exception ex) + { + DryadLogger.LogError(0, ex); + } + + // Close any get/set property wait handles + foreach (KeyValuePair> label in propertyWaitEvents) + { + foreach (KeyValuePair version in label.Value) + { + try + { + version.Value.Close(); + } + catch (Exception ex) + { + DryadLogger.LogError(0, ex); + } + } + } + propertyWaitEvents.Clear(); + } + + 
m_disposed = true; + } + } + } + DryadLogger.LogMethodExit(); + } + + /// + /// Record memory and CPU statistics for vertex host process + /// + /// requested statistics + /// true on success + private bool GetStatistics(out ProcessStatistics stats) + { + stats = new ProcessStatistics(); + + // These are the only statistics returned by the xcexec implementation + UpdateMemoryStatistics(); + stats.processUserTime = systemProcess.UserProcessorTime.Ticks * 10 * TimeSpan.TicksPerMillisecond; + stats.processKernelTime = (systemProcess.TotalProcessorTime.Ticks * 10 * TimeSpan.TicksPerMillisecond) - stats.processUserTime; + stats.peakMemUsage = (ulong)peakWorkingSet; + return true; + } + + /// + /// Set the property wait events for all versions of all properties + /// + private void SetAllPropertyWaiters() + { + lock (syncRoot) + { + // signal all threads waiting for properties + foreach (KeyValuePair> label in this.propertyWaitEvents) + { + foreach (KeyValuePair version in label.Value) + { + version.Value.Set(); + } + } + } + } + + /// + /// Set all requested properties + /// + /// Properties to set + /// Output - labels of properties set + /// Output - version of properties set + private void SetProperties(ProcessPropertyInfo[] infos, out string[] labels, out ulong[] versions) + { + // + // Return null if infos list contains no properties + // + labels = null; + versions = null; + if (infos != null && infos.Length > 0) + { + versions = new ulong[infos.Length]; + labels = new string[infos.Length]; + + for (int i = 0; i < infos.Length; i++) + { + // + // Set each property and update version and label + // + ulong newVersion = 0; + SetProperty(infos[i], out newVersion); + versions[i] = newVersion; + labels[i] = infos[i].propertyLabel; + } + } + } + + /// + /// Set a property and record the version + /// + /// + /// + private void SetProperty(ProcessPropertyInfo property, out ulong newVersion) + { + DryadLogger.LogMethodEntry(property.propertyLabel); + ProcessPropertyInfo 
oldProperty = null; + lock (syncRoot) + { + int index; + if (TryGetProperty(property.propertyLabel, out oldProperty, out index)) + { + // + // If property found in local array, then we are setting a new version of existing property + // Copy the new property information into the array + // + oldProperty.propertyVersion++; + newVersion = oldProperty.propertyVersion; + if (property.propertyBlock != null && property.propertyBlock.Length > 0) + { + oldProperty.propertyBlock = property.propertyBlock; + } + + oldProperty.propertyString = property.propertyString; + CopyProp(oldProperty, out info.propertyInfos[index]); + } + else + { + // + // if property not found in local array, then setting a new property + // use version 1, unless valid value specified + // + if (property.propertyVersion == ulong.MaxValue || property.propertyVersion == 0) + { + property.propertyVersion = 1; + } + + newVersion = property.propertyVersion; + + // + // Create or resize the local info array as necessary and append this property + // + if (info.propertyInfos == null) + { + info.propertyInfos = new ProcessPropertyInfo[1]; + } + else + { + Array.Resize(ref info.propertyInfos, info.propertyInfos.Length + 1); + } + + info.propertyInfos[info.propertyInfos.Length - 1] = property; + } + + // + // If there was a vertex completed event, record the latest vertex status + // + if (StatusMessageContainsDryadError_VertexCompleted( property.propertyLabel)) + { + CopyProp(property, out latestVertexStatusReceived); + latestVertexStatusReceived.propertyVersion = newVersion; + } + + // + // Wake up anyone waiting for a property change by adding a new wait event for this property if needed + // + if (propertyWaitEvents.ContainsKey(property.propertyLabel) == false) + { + propertyWaitEvents.Add(property.propertyLabel, new Dictionary()); + } + + // + // Wake up anyone waiting for this version of the property + // + if (propertyWaitEvents[property.propertyLabel].ContainsKey(newVersion - 1)) + { + 
propertyWaitEvents[property.propertyLabel][newVersion - 1].Set(); + } + else + { + propertyWaitEvents[property.propertyLabel].Add(newVersion - 1, new ManualResetEvent(true)); + } + + // + // Wake up anyone waiting for any version of this property + // + if (newVersion > 1) + { + if (propertyWaitEvents[property.propertyLabel].ContainsKey(0)) + { + propertyWaitEvents[property.propertyLabel][0].Set(); + } + } + } + DryadLogger.LogMethodExit(); + } + + /// + /// Copy the information from one ProcessPropertyInfo object to another + /// + /// Source of info + /// destination of info + private void CopyProp(ProcessPropertyInfo propertySrc, out ProcessPropertyInfo propertyDst) + { + propertyDst = new ProcessPropertyInfo(); + propertyDst.propertyLabel = propertySrc.propertyLabel; + propertyDst.propertyVersion = propertySrc.propertyVersion; + propertyDst.propertyString = propertySrc.propertyString; + if (propertySrc.propertyBlock != null) + { + propertyDst.propertyBlock = new byte[ propertySrc.propertyBlock.Length ]; + Array.Copy(propertySrc.propertyBlock, propertyDst.propertyBlock, propertySrc.propertyBlock.Length); + } + } + + /// + /// Looks at current list of properties to get latest information + /// + /// property label + /// property info output + /// property info index where found if found + /// found property = true + private bool TryGetProperty(string getPropLabel, out ProcessPropertyInfo property, out int index) + { + index = 0; + property = null; + if (info.propertyInfos != null) + { + lock (syncRoot) + { + // + // Look through each known property for one sharing the same label. 
+ // + foreach (ProcessPropertyInfo p in info.propertyInfos) + { + if (String.Compare(p.propertyLabel, getPropLabel, StringComparison.OrdinalIgnoreCase) == 0) + { + // + // If found, set output parameter and return true + // + CopyProp(p, out property); + return true; + } + + index ++; + } + } + } + + return false; + } + + /// + /// Update vertex host process information + /// + private void UpdateMemoryStatistics() + { + try + { + systemProcess.Refresh(); + peakWorkingSet = systemProcess.PeakWorkingSet64; + } + catch + { + // Process has exited + } + } + + #endregion + + #region Properties + + public int DryadId + { + get { return this.dryadProcessId; } + } + + public int ProcessId + { + get { return this.systemProcess.Id; } + } + + public ProcessState State + { + get { return this.state; } + } + + public bool Succeeded + { + get { return ((this.exited) && !(this.cancelled || this.failed)); } + } + + #endregion + + #region Event handlers + + /// + /// Vertex host process exited event - marks process state and queues up exit process thread + /// + /// + /// + private void Process_Exited(object sender, EventArgs args) + { + DryadLogger.LogMethodEntry(DryadId); + + // Ensure the process exited code can only be executed once + lock (syncRoot) + { + if (exited) + { + DryadLogger.LogInformation("Process exit", "Process {0} already exited", DryadId); + DryadLogger.LogMethodExit(); + return; + } + exited = true; + } + + if (cancelled) + { + DryadLogger.LogInformation("Process exit", "Process {0} was cancelled", DryadId); + exitCode = unchecked((int)0x830A0003); // DrError_VertexReceivedTermination + } + else + { + exitCode = systemProcess.ExitCode; + DryadLogger.LogInformation("Process exit", "Process {0} exit code {1}", DryadId, exitCode); + if (exitCode == 0) + { + lock (syncRoot) + { + state = ProcessState.Completed; + } + } + else + { + lock (syncRoot) + { + state = ProcessState.Completed; + this.failed = true; + } + } + } + + // + // Ensure that the vertex complete 
event is sent to GM and that all pending properties are handled + // + ThreadPool.QueueUserWorkItem(new WaitCallback(ExitProcessThreadProc)); + + DryadLogger.LogMethodExit(); + } + + #endregion + + #region Thread functions + + /// + /// Asynchronously called on start command + /// + /// + void StartProcessThreadProc(Object obj) + { + ManualResetEvent serviceInitializedEvent = obj as ManualResetEvent; + bool started = false; + + try + { + // + // Wait for service initialization + // + serviceInitializedEvent.WaitOne(); + + if (ExecutionHelper.InitializeForProcessExecution(dryadProcessId, Environment.GetEnvironmentVariable("XC_RESOURCEFILES"))) + { + // + // Vertex working directory configured successfully, start the vertex host + // + environment.Add(Constants.vertexSvcLocalAddrEnvVar, localAddress); + + ProcessStartInfo startInfo = new ProcessStartInfo(); + startInfo.CreateNoWindow = true; + startInfo.UseShellExecute = false; + startInfo.WorkingDirectory = ProcessPathHelper.ProcessPath(dryadProcessId); + + //YARN Debugging + //var procEnvVarKeys = startInfo.EnvironmentVariables.Keys; + //foreach (string key in procEnvVarKeys) + //{ + // DryadLogger.LogInformation("StartProcess", "key: '{0}' value: '{1}'", key, startInfo.EnvironmentVariables[key]); + //} + + string[] args = commandLine.Split(' '); + string arg = ""; + for (int i = 1; i < args.Length; i++) + { + arg += args[i] + " "; + } + + // + // Use either FQ path or path relative to job path + // + if (Path.IsPathRooted(args[0])) + { + startInfo.FileName = args[0]; + } + else + { + startInfo.FileName = Path.Combine(ProcessPathHelper.JobPath, args[0]); + } + DryadLogger.LogInformation("StartProcess", "FileName: '{0}'", startInfo.FileName); + + // + // Add environment variable to vertex host process + // + startInfo.Arguments = arg; + foreach (DictionaryEntry entry in environment) + { + string key = entry.Key.ToString(); + + if (key == null || startInfo.EnvironmentVariables.ContainsKey(key)) + { + 
DryadLogger.LogInformation("StartProcess", "Attempting to add existing key '{0}' with value '{1}'", + entry.Key, entry.Value); + } + else + { + startInfo.EnvironmentVariables.Add(key, entry.Value.ToString()); + } + } + + lock (syncRoot) + { + // + // After taking lock, start the vertex host process and set up exited event handler + // + if (cancelled) + { + // If we've already been canceled, don't start the process + DryadLogger.LogInformation("Process start", "Not starting process {0} due to receipt of cancellation", DryadId); + return; + } + else + { + + systemProcess = new Process(); + systemProcess.StartInfo = startInfo; + systemProcess.EnableRaisingEvents = true; + systemProcess.Exited += new EventHandler(Process_Exited); + Console.WriteLine("Process start - Vertex host process starting"); + started = systemProcess.Start(); + Console.WriteLine("Process start - Vertex host process started"); + if (started) + { + DryadLogger.LogInformation("Process start", "Vertex host process started"); + state = ProcessState.Running; + } + else + { + DryadLogger.LogError(0, null, "Vertex host process failed to start"); + } + } + } + } + else + { + DryadLogger.LogError(0, null, "Initialization failed"); + } + } + catch (Exception e) + { + DryadLogger.LogError(0, e, "Error starting vertex"); + } + + if (started) + { + // + // Notify Graph Manager that process started if successful + // + bool success = ReplyDispatcher.FireStateChange(this.graphManagerReplyUri, this.dryadProcessId, ProcessState.Running); + if (!success) + { + // + // Graph manager doesn't know we started and we have no way to tell it, so it's + // best to just fail the vertex service task and let the job manager inform the graph manager + // + VertexService.Surrender(new Exception("Unable to communicate with graph manager.")); + } + } + else + { + // + // Otherwise, notify GM that process has failed + // + lock (syncRoot) + { + // If we've already been canceled, we don't need to change state or record the 
initialization failure + if (!cancelled) + { + state = ProcessState.Completed; + this.failed = true; + exitCode = unchecked((int)Constants.DrError_VertexInitialization); // DryadError_VertexInitialization + } + } + + if (failed) // This also means we weren't canceled + { + // Notify the Graph Manager that the process failed to start + Process_Exited(this, null); + } + } + + // + // Make sure process start event is set + // + processStartEvent.Set(); + } + + /// + /// Check if the message label marks vertex as completed + /// + /// Label to check + /// true if message contains DryadError_VertexCompleted + bool StatusMessageContainsDryadError_VertexCompleted(string statusMessageLabel) + { + // + // todo: This seems hacky - make sure it always works + // + if (statusMessageLabel.StartsWith(@"DVertexStatus-", StringComparison.OrdinalIgnoreCase) && + !statusMessageLabel.EndsWith(@"update", StringComparison.OrdinalIgnoreCase)) + { + return true; + } + + return false; + } + + /// + /// Called in new thread in setgetproperty service operation + /// + /// + void SetGetPropThreadProc(Object obj) + { + DryadLogger.LogMethodEntry(DryadId); + PropertyRequest r = obj as PropertyRequest; + + ProcessInfo infoLocal = new ProcessInfo(); + ulong[] propertyVersions = null; + string[] propertyLabels = null; + + // + // Make sure process is started before continuing + // + if (this.State < ProcessState.Running) + { + try + { + processStartEvent.WaitOne(); + } + catch (ObjectDisposedException ex) + { + // The process was cancelled and released before it started running, just return + if (exited) + { + DryadLogger.LogInformation("SetGetProp Thread", "Process {0} cancelled or exited before starting.", this.DryadId); + } + else + { + DryadLogger.LogError(0, ex); + } + DryadLogger.LogMethodExit(); + return; + } + } + + // + // Use status_pending if running, vertex initialization failure if process is failed and process exit code otherwise + // + infoLocal.processStatus = 0x103; // WinNT.h 
STATUS_PENDING + infoLocal.processState = state; + if (state == ProcessState.Running) + { + infoLocal.exitCode = 0x103; // WinNT.h STATUS_PENDING + } + else if (failed) + { + infoLocal.exitCode = Constants.DrError_VertexError; + } + else if (cancelled) + { + infoLocal.exitCode = Constants.DrError_VertexReceivedTermination; // DryadError_VertexReceivedTermination + } + else + { + infoLocal.exitCode = (uint)systemProcess.ExitCode; + } + + // + // Record specified properties and update versions - wakes up anyone waiting for property changes + // + SetProperties(r.infos, out propertyLabels, out propertyVersions); + + // + // Try to get property update + // + if (BlockOnProperty(r.blockOnLabel, r.blockOnVersion, r.maxBlockTime)) + { + // + // If property update was received, update the received property information + // If received property marks vertex completed, record that + // + if (r.getPropLabel != null && r.getPropLabel.Length > 0) + { + lock (syncRoot) + { + infoLocal.propertyInfos = new ProcessPropertyInfo[1]; + + int index; + if (TryGetProperty(r.getPropLabel, out infoLocal.propertyInfos[0], out index) == false) + { + DryadLogger.LogError(0, null, "Failed to get property for label {0}", r.getPropLabel); + } + + if (StatusMessageContainsDryadError_VertexCompleted(infoLocal.propertyInfos[0].propertyLabel)) + { + CopyProp(infoLocal.propertyInfos[0], out latestVertexStatusSent); + } + } + } + + // + // If request asks for statistics on vertex process, get them + // + if (r.ProcessStatistics) + { + if (GetStatistics(out infoLocal.processStatistics) == false) + { + DryadLogger.LogError(0, null, "Failed to get vertex statistics"); + } + } + } + + // + // Try to report property change, if unsuccessful, kill the running vertex host process + // + if (!ReplyDispatcher.SetGetPropsComplete(r.replyUri, systemProcess, dryadProcessId, infoLocal, propertyLabels, propertyVersions)) + { + try + { + systemProcess.Kill(); + } + catch (InvalidOperationException /* unused ioe */) + 
{ + // The process has already exited + // -or- + // There is no process associated with this Process object. + } + catch (Exception eInner) + { + // + // all other exceptions + // + DryadLogger.LogError(0, eInner, "Exception calling back to '{0}'", r.replyUri); + } + } + + // + // If a property was handled from the graph manager, decrement the waiter count + // + if (ReplyDispatcher.IsGraphMrgUri(r.replyUri)) + { + int n = Interlocked.Decrement(ref propertyWaiters); + DryadLogger.LogInformation("SetGetProp Thread", "Process {0} propertyWaiters = {1}", DryadId, n); + } + + lock (syncRoot) + { + // + // If vertex process has exited, and sending vertex completed event, we can stop worrying + // + if (!finalStatusMessageSent) + { + if (latestVertexStatusSent != null) + { + if (!String.IsNullOrEmpty(latestVertexStatusSent.propertyString)) + { + if (latestVertexStatusSent.propertyString.Contains(string.Format(@"(0x{0:x8})", Constants.DrError_VertexCompleted))) + { + finalStatusMessageSent = true; + } + } + } + } + } + DryadLogger.LogMethodExit(); + } + + /// + /// Make sure all pending properties get handled and vertex complete event is sent to GM + /// + /// + void ExitProcessThreadProc(Object obj) + { + DryadLogger.LogMethodEntry(); + + // + // Wait until all property waiters have been notified and the final + // status message sent, iff the process completed successfully + // + do + { + // + // Clear any thing intended for the vertex + // + SetAllPropertyWaiters(); + + lock (syncRoot) + { + // If nobody is waiting, AND + if (propertyWaiters == 0) + { + // Process did not complete successfully OR + // final status message has already been sent + if (!Succeeded || finalStatusMessageSent) + { + // Then we can send the Process Exit notification + break; + } + } + } + + Thread.Sleep(10); + + } while(true); + + ReplyDispatcher.ProcessExited(this.graphManagerReplyUri, this.dryadProcessId, this.exitCode); + + // + // This should never happen unless a property is requested 
after the vertex completed event is sent + // so it's not a big deal if it does because the GM knows that the vertex is done + // + if (propertyWaiters > 0) + { + DryadLogger.LogWarning("Process exit", "Leaving thread with {0} property waiter(s).", propertyWaiters); + } + + DryadLogger.LogMethodExit(); + } + + #endregion + } +} diff --git a/DryadVertex/service/VertexService.cs b/DryadVertex/service/VertexService.cs new file mode 100644 index 0000000..8da799a --- /dev/null +++ b/DryadVertex/service/VertexService.cs @@ -0,0 +1,507 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +//------------------------------------------------------------------------------ +// +// Implementation of the vertex service +// +//------------------------------------------------------------------------------ + +namespace Microsoft.Research.Dryad +{ + using System; + using System.Globalization; + using System.Collections; + using System.Collections.Generic; + using System.Collections.Specialized; + using System.Runtime.Serialization; + using System.ServiceModel; + using System.ServiceModel.Channels; + using System.Threading; + using System.Configuration; + using System.IO; + using System.Diagnostics; + using System.Management; + using System.Runtime.InteropServices; + using Microsoft.Research.Dryad; + + /// + /// Class that holds all information needed to make a property request + /// + internal class PropertyRequest + { + public ProcessPropertyInfo[] infos; + public string blockOnLabel; + public ulong blockOnVersion; + public long maxBlockTime; + public string getPropLabel; + public bool ProcessStatistics; + public string replyUri; + + /// + /// Constructor - fills in properties + /// + /// + /// + /// + /// + /// + /// + /// + public PropertyRequest(string uri, ProcessPropertyInfo[] infos, string blockOnLabel, ulong blockOnVersion, long maxBlockTime, string getPropLabel, bool ProcessStatistics) + { + this.infos = infos; + this.blockOnLabel = blockOnLabel; + this.blockOnVersion = blockOnVersion; + this.maxBlockTime = maxBlockTime; + this.getPropLabel = getPropLabel; + this.ProcessStatistics = ProcessStatistics; + this.replyUri = uri; + } + } + + /// + /// Implementation of the IDryadVertexService and IDryadVertexFileService + /// + [ServiceBehavior(ConcurrencyMode = ConcurrencyMode.Reentrant, InstanceContextMode = InstanceContextMode.Single, IncludeExceptionDetailInFaults = true)] + internal class VertexService : IDryadVertexService + { + #region members + + public static ManualResetEvent shutdownEvent = new ManualResetEvent(false); + internal static 
bool internalShutdown = false; + internal static Exception ShutdownReason { get; set; } + private ManualResetEvent initializedEvent = new ManualResetEvent(false); + + // TODO: add synchronization locks as necessary + private SynchronizedCollection vertexProcessTable; + + private StringDictionary vertexEndpointAddresses = new StringDictionary(); + + #endregion + + #region Public methods + + /// + /// Constructor - called when service first hosted + /// + public VertexService() + { + DryadLogger.LogMethodEntry(); + this.vertexProcessTable = new SynchronizedCollection(); + System.Threading.ThreadPool.QueueUserWorkItem(new WaitCallback(InitializationThreadProc)); + DryadLogger.LogMethodExit(); + } + + #endregion + + #region IDryadVertexService methods + + /// + /// Cancels the vertex process with the provided id + /// + /// vertex process id + void IDryadVertexService.CancelScheduleProcess(int processId) + { + VertexProcess vp = null; + DryadLogger.LogMethodEntry(processId); + + try + { + vp = FindByDryadId(processId); + if (vp != null) + { + vp.Cancel(false); + } + else + { + DryadLogger.LogWarning("Cancel Process", "Unknown process id {0}", processId); + } + } + catch (Exception e) + { + DryadLogger.LogWarning("Cancel Process", "Operation threw exception: {0}", e.ToString()); + } + + DryadLogger.LogMethodExit(); + } + + /// + /// Gets information about the vertex service + /// + /// + VertexStatus IDryadVertexService.CheckStatus() + { + DryadLogger.LogMethodEntry(); + VertexStatus status = new VertexStatus(); + status.serviceIsAlive = true; + + // + // Update information about disk usage + // + foreach (string disk in Environment.GetLogicalDrives()) + { + ulong freeDiskSpaceforUser; + ulong totalDiskSpace; + ulong freeDiskSpace; + + if (NativeMethods.GetDiskFreeSpaceEx(disk, out freeDiskSpaceforUser, out totalDiskSpace, out freeDiskSpace)) + { + status.freeDiskSpaces.Add(disk, freeDiskSpace); + } + else + { + // + // Report any errors as warnings, as this is a 
non-essential call + // + int errorCode = Marshal.GetLastWin32Error(); + Exception lastex = Marshal.GetExceptionForHR(errorCode); + if (lastex != null) + { + DryadLogger.LogWarning("Unable to get disk space information", "Disk: {0} Error: {1}", disk, lastex.Message); + } + else + { + DryadLogger.LogWarning("Unable to get disk space information", "Disk: {0} Error Code: {1}", disk, errorCode); + } + } + } + + // + // Update information about memory usage + // + NativeMethods.MEMORYSTATUSEX memStatus = new NativeMethods.MEMORYSTATUSEX(); + if (NativeMethods.GlobalMemoryStatusEx(memStatus)) + { + status.freePhysicalMemory = memStatus.ullAvailPhys; + status.freeVirtualMemory = memStatus.ullAvailVirtual; + } + else + { + // + // Report any errors as warnings, as this is a non-essential call + // + int errorCode = Marshal.GetLastWin32Error(); + Exception lastex = Marshal.GetExceptionForHR(errorCode); + if (lastex != null) + { + DryadLogger.LogWarning("Unable to get memory information", "Error: {0}", lastex.Message); + } + else + { + DryadLogger.LogWarning("Unable to get memory information", "Error Code: {0}", errorCode); + } + } + + // + // Get process info for each running vertex process + // + status.runningProcessCount = 0; + lock (vertexProcessTable.SyncRoot) + { + foreach (VertexProcess vp in this.vertexProcessTable) + { + VertexProcessInfo vpInfo = new VertexProcessInfo(); + vpInfo.DryadId = vp.DryadId; + vpInfo.commandLine = vp.commandLine; + vpInfo.State = vp.State; + + status.vps.Add(vpInfo); + + if (vp.State == ProcessState.Running) + { + status.runningProcessCount++; + } + } + } + + DryadLogger.LogMethodExit(status); + return status; + } + + /// + /// Initialize the endpoint addresses for each vertex host + /// + /// List of vertex host addresses + void IDryadVertexService.Initialize(StringDictionary vertexEndpointAddresses) + { + DryadLogger.LogMethodEntry(vertexEndpointAddresses.Count); + + try + { + this.vertexEndpointAddresses = vertexEndpointAddresses; + } 
+ catch (Exception e) + { + DryadLogger.LogWarning("Initialize", "Operation threw exception: {0}", e.ToString()); + } + + DryadLogger.LogMethodExit(); + } + + /// + /// Removes reference to a vertex process + /// + /// process id to forget + void IDryadVertexService.ReleaseProcess(int processId) + { + DryadLogger.LogMethodEntry(processId); + VertexProcess vp = null; + + try + { + vp = FindByDryadId(processId); + if (vp != null) + { + vertexProcessTable.Remove(vp); + vp.Dispose(); + } + else + { + DryadLogger.LogWarning("Release Process", "Unknown process id {0}", processId); + } + } + catch (Exception e) + { + DryadLogger.LogWarning("Release Process", "Operation threw exception: {0}", e.ToString()); + } + + DryadLogger.LogMethodExit(); + } + + /// + /// Schedule a vertex host process using the provided parameters + /// + /// callback URI + /// vertex process id + /// vertex host command line + /// vertex host environment variables + /// Success/Failure of starting vertex process thread + bool IDryadVertexService.ScheduleProcess(string replyUri, int processId, string commandLine, StringDictionary environment) + { + DryadLogger.LogMethodEntry(processId, commandLine); + bool startSuccess = false; + Console.WriteLine("Starting process id {0} with commandLIne: '{1}", processId, commandLine); + try + { + VertexProcess newProcess = null; + + lock (vertexProcessTable.SyncRoot) + { + foreach (VertexProcess vp in vertexProcessTable) + { + if (vp.DryadId == processId) + { + // This means a previous call to Schedule process partially succeeded: + // the call made it to the service but something went wrong with the response + // so the GM's xcompute machinery retried the call. We can just return success + // for this case rather than tearing down the process and creating a new one. + return true; + } + + if (vp.State <= ProcessState.Running) + { + // There should be no other processes running. 
+ // If there are, it means a previous communication error + // cause the GM to give up on this node for a while. + // Kill anything that's still hanging around. + vp.Cancel(true); + } + } + + newProcess = new VertexProcess( + replyUri, + processId, + commandLine, + environment, + OperationContext.Current.Channel.LocalAddress.Uri.ToString() + ); + this.vertexProcessTable.Add(newProcess); + } + + startSuccess = newProcess.Start(initializedEvent); + } + catch (Exception e) + { + DryadLogger.LogWarning("Schedule Process", "Operation threw exception: {0}", e.ToString()); + throw new FaultException(new VertexServiceError("ReleaseProcess", e.ToString())); + } + + DryadLogger.LogMethodExit(startSuccess); + return startSuccess; + } + + /// + /// Update properties + /// + /// callback URI + /// vertex process id + /// property information + /// property update label + /// property update version + /// maximum time to wait for update + /// property to get + /// vertex host process statistics + /// success/failure of property update + bool IDryadVertexService.SetGetProps(string replyEpr, int processId, ProcessPropertyInfo[] infos, string blockOnLabel, ulong blockOnVersion, long maxBlockTime, string getPropLabel, bool ProcessStatistics) + { + DryadLogger.LogMethodEntry(replyEpr, processId); + bool success = false; + + try + { + // Get the vertex process ID + VertexProcess vp = FindByDryadId(processId); + if (vp != null) + { + success = vp.SetGetProps(replyEpr, infos, blockOnLabel, blockOnVersion, maxBlockTime, getPropLabel, ProcessStatistics); + } + else + { + DryadLogger.LogError(0, null, "Failed to set / get process properties: Unknown process id {0}", processId); + } + } + catch (Exception e) + { + DryadLogger.LogWarning("Set Or Get Process Properties", "Operation threw exception: {0}", e.ToString()); + throw new FaultException(new VertexServiceError("SetGetProps", e.ToString())); + } + + DryadLogger.LogMethodExit(success); + return success; + } + + /// + /// Shut down the 
vertex service + /// + /// + void IDryadVertexService.Shutdown(uint ShutdownCode) + { + DryadLogger.LogMethodEntry(ShutdownCode); + + try + { + ReplyDispatcher.ShuttingDown = true; + VertexService.shutdownEvent.Set(); + } + catch (Exception e) + { + DryadLogger.LogWarning("Shutdown", "Operation threw exception: {0}", e.ToString()); + } + + DryadLogger.LogMethodExit(); + } + + #endregion + + #region Private methods + + /// + /// Get vertex process cooresponding to dryad id + /// + /// dryad id + /// vertex process + private VertexProcess FindByDryadId(int id) + { + lock (vertexProcessTable.SyncRoot) + { + foreach (VertexProcess p in vertexProcessTable) + { + if (p.DryadId == id) + { + return p; + } + } + } + return null; + } + + /// + /// Get vertex process cooresponding to process id + /// + /// process id + /// vertex process + private VertexProcess FindByProcessId(int id) + { + lock (vertexProcessTable.SyncRoot) + { + foreach (VertexProcess p in vertexProcessTable) + { + if (p.ProcessId == id) + { + return p; + } + } + } + return null; + } + + #endregion + + #region Internal methods + + /// + /// Fail the vertex service task + /// + internal static void Surrender(Exception ex) + { + DryadLogger.LogMethodEntry(); + ReplyDispatcher.ShuttingDown = true; + VertexService.internalShutdown = true; + VertexService.ShutdownReason = ex; + VertexService.shutdownEvent.Set(); + DryadLogger.LogMethodExit(); + } + #endregion + + #region Thread Functions + + /// + /// Initialization thread - initialize job working directory if needed. 
+ /// + /// + void InitializationThreadProc(Object state) + { + try + { + if (Environment.GetEnvironmentVariable(Constants.schedulerTypeEnvVar) == Constants.schedulerTypeLocal) + { + initializedEvent.Set(); + } + else if (ExecutionHelper.InitializeForJobExecution(Environment.GetEnvironmentVariable("XC_RESOURCEFILES"))) + { + DryadLogger.LogInformation("InitializationThreadProc", "InitializeForJobExecution was successful."); + initializedEvent.Set(); + } + else + { + Surrender(new Exception("Failed to initialize vertex service for job execution")); + } + } + catch (Exception ex) + { + Surrender(ex); + } + } + + #endregion + } + +} diff --git a/DryadVertex/service/app.config b/DryadVertex/service/app.config new file mode 100644 index 0000000..e365603 --- /dev/null +++ b/DryadVertex/service/app.config @@ -0,0 +1,3 @@ + + + diff --git a/DryadVertex/service/program.cs b/DryadVertex/service/program.cs new file mode 100644 index 0000000..0616517 --- /dev/null +++ b/DryadVertex/service/program.cs @@ -0,0 +1,217 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +//------------------------------------------------------------------------------ +// +// The main entry point for the application. 
+// +//------------------------------------------------------------------------------ +namespace Microsoft.Research.Dryad +{ + using System; + using System.Collections.Generic; + using System.Text; + using System.ServiceModel; + using System.ServiceModel.Description; + using System.Diagnostics; + using System.Net; + using System.Net.Security; + using System.Net.Sockets; + using System.Threading; + using Microsoft.Research.Dryad; + using System.IO; + using System.Security.Principal; + using System.Security.AccessControl; + + /// + /// Main entry point + /// + internal static class Program + { + /// + /// Number of times to retry operations before failing + /// + private static int numRetries = 2; + + /// + /// The main entry point for the application. + /// + private static int Main(string[] args) + { + // + // Try to create working directory. Fail vertex service if unable to do so. + // + bool createdJobDir = false; + int retryCount = 0; + do + { + try + { + ProcessPathHelper.CreateUserWorkingDirectory(); + + Directory.CreateDirectory(ProcessPathHelper.JobPath); + + createdJobDir = true; + } + catch (Exception ex) + { + Console.Error.WriteLine("Failed to create working directory, {0}. 
Error: {1}.", ProcessPathHelper.JobPath, ex.ToString()); + retryCount++; + } + } while (retryCount < numRetries && !createdJobDir); + + if (!createdJobDir) + { + Console.Error.WriteLine("Vertex service cannot proceed because working directory could not be created."); + return 1; + } + + // + // Get Task ID from environment + // + int taskId; + if (Int32.TryParse(Environment.GetEnvironmentVariable("CCP_TASKID"), out taskId) == false) + { + Console.Error.WriteLine("Program.Main: Failed to read CCP_TASKID from environment"); // single string: the (format, arg) overload printed only "Program.Main" + return 1; + } + + // + // Initialize tracing subsystem + // + string traceFile = Path.Combine(ProcessPathHelper.JobPath, String.Format("VertexServiceTrace_{0}.txt", taskId)); + DryadLogger.Start(traceFile); + + // + // Initialize scheduler helper of the correct type + // + ISchedulerHelper schedulerHelper; + try + { + schedulerHelper = SchedulerHelperFactory.GetInstance(); + } + catch (Exception ex) + { + DryadLogger.LogCritical(0, ex, "Failed to get scheduler helper"); + DryadLogger.Stop(); + Console.Error.WriteLine("Failed to contact HPC scheduler. See log for details."); + return 1; + } + + // + // Step 1 of the address configuration procedure: Create a URI to serve as the base address. + // + string strAddress = schedulerHelper.GetVertexServiceBaseAddress("localhost", taskId); + Uri baseAddress = new Uri(strAddress); + + // + // Step 2 of the hosting procedure: Create ServiceHost + // + ServiceHost selfHost = new ServiceHost(typeof(VertexService), baseAddress); + + try + { + // + // Get the service binding + // + NetTcpBinding binding = schedulerHelper.GetVertexServiceBinding(); + + // + // Step 3 of the hosting procedure: Add service endpoints. 
+ // + ServiceEndpoint vertexEndpoint = selfHost.AddServiceEndpoint(typeof(IDryadVertexService), binding, Constants.vertexServiceName); + DryadLogger.LogInformation("Initialize vertex service", "listening on address {0}", vertexEndpoint.Address.ToString()); + + // + // Step 4 of hosting procedure : Add a security manager + // TODO: Fix this for local scheduler and / or Azure scheduler when supported + // + selfHost.Authorization.ServiceAuthorizationManager = new DryadVertexServiceAuthorizationManager(); + + // Step 5 of the hosting procedure: Start (and then stop) the service. + selfHost.Open(); + + Console.WriteLine("Vertex Service up and waiting for commands"); + + // Wait for the shutdown event to be set. + VertexService.shutdownEvent.WaitOne(-1, true); + + // Check vertex service shutdown condition + if (VertexService.internalShutdown) + { + string errorMsg = string.Format("Vertex Service Task unable to continue after critical error in initialization or communication: {0}", VertexService.ShutdownReason.ToString()); + Console.WriteLine(errorMsg); + DryadLogger.LogCritical(0, new Exception(errorMsg)); + DryadLogger.Stop(); + try + { + selfHost.Abort(); + } + catch + { + } + + return 1; + } + + // Close the ServiceHostBase to shutdown the service. + selfHost.Close(); + } + catch (CommunicationException ce) + { + // + // Report any errors and fail task + // + DryadLogger.LogCritical(0, ce, "A communication exception occurred"); + DryadLogger.Stop(); + try + { + selfHost.Abort(); + } + catch + { + } + Console.Error.WriteLine("CommunicationException occured, aborting vertex service. See log for details."); + return 1; + } + catch (Exception ex) + { + // + // Report any errors and fail task + // + DryadLogger.LogCritical(0, ex, "An exception occurred"); + DryadLogger.Stop(); + try + { + selfHost.Abort(); + } + catch + { + } + Console.Error.WriteLine("An exception occured, aborting vertex service. 
See log for details."); + return 1; + } + + DryadLogger.LogInformation("Vertex Service", "Shut down cleanly"); + DryadLogger.Stop(); + return 0; + } + } +} diff --git a/DryadYarnBridge/DryadYarnBridge.cpp b/DryadYarnBridge/DryadYarnBridge.cpp new file mode 100644 index 0000000..80d22e1 --- /dev/null +++ b/DryadYarnBridge/DryadYarnBridge.cpp @@ -0,0 +1,52 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include "DryadYarnBridge.h" +#include "YarnAppMasterNative.h" +using namespace DryadYarn; + +void * DrCreateNativeAppMaster() +{ + AMNativeInstance *instance = new AMNativeInstance(); + if (instance->OpenInstance()) + { + return instance; + } + return NULL; +} + +void DrDestroyNativeAppMaster(void *ptr) +{ + AMNativeInstance *instance = (AMNativeInstance *)ptr; + delete instance; +} + + +char* DrGetExceptionMessage(void *ptr) +{ + AMNativeInstance *instance = (AMNativeInstance *)ptr; + return instance->GetExceptionMessage(); +} + +bool DrScheduleProcess(void *ptr, int vertexId, const char* name, const char* commandLine) +{ + AMNativeInstance *instance = (AMNativeInstance *)ptr; + return instance->ScheduleProcess(vertexId, name, commandLine); +} diff --git a/DryadYarnBridge/DryadYarnBridge.def b/DryadYarnBridge/DryadYarnBridge.def new file mode 100644 index 0000000..0e663f4 --- /dev/null +++ b/DryadYarnBridge/DryadYarnBridge.def @@ -0,0 +1,9 @@ +LIBRARY "DryadYarnBridge.dll" + +EXPORTS + DrCreateNativeAppMaster + DrDestroyNativeAppMaster + DrGetExceptionMessage + DrScheduleProcess + Java_com_microsoft_research_TestLib_SendVertexState + Java_com_microsoft_research_DryadAppMaster_SendVertexState diff --git a/DryadYarnBridge/DryadYarnBridge.h b/DryadYarnBridge/DryadYarnBridge.h new file mode 100644 index 0000000..dbf7cc7 --- /dev/null +++ b/DryadYarnBridge/DryadYarnBridge.h @@ -0,0 +1,32 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +extern "C" +{ + void * _stdcall DrCreateNativeAppMaster(); + void _stdcall DrDestroyNativeAppMaster(void *ptr); + char* _stdcall DrGetExceptionMessage(void *ptr); + bool _stdcall DrScheduleProcess(void *ptr, int vertexId, const char* name, const char* commandLine); + +} + + diff --git a/DryadYarnBridge/DryadYarnBridge.vcxproj b/DryadYarnBridge/DryadYarnBridge.vcxproj new file mode 100644 index 0000000..a3e255e --- /dev/null +++ b/DryadYarnBridge/DryadYarnBridge.vcxproj @@ -0,0 +1,198 @@ + + + + + Debug + Win32 + + + Debug + x64 + + + Release + Win32 + + + Release + x64 + + + + {09FB27C7-D1A5-4A59-B010-67D5886DD9A2} + DryadYarnBridge + ManagedCProj + + + + DynamicLibrary + true + + + DynamicLibrary + true + + + DynamicLibrary + true + Unicode + + + DynamicLibrary + Unicode + true + + + + + + + + + + + + + + + + + + + <_ProjectFileVersion>10.0.40219.1 + Debug\ + Debug\ + true + ..\bin\$(Configuration)\ + $(Platform)\$(Configuration)\ + true + Release\ + Release\ + true + $(Platform)\$(Configuration)\ + $(Platform)\$(Configuration)\ + false + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + + + + Disabled + c:\Apps\java\openjdk7\include\;C:\Apps\java\openjdk7\include\win32;%(AdditionalIncludeDirectories) + WIN32;_DEBUG;_WINDOWS;_USRDLL;DRYADYARNBRIDGEVS2008_EXPORTS;%(PreprocessorDefinitions) + MultiThreadedDebugDLL + + + Level3 + ProgramDatabase + + + true + Windows + main + MachineX64 + + + + + X64 + + + Disabled + $(JAVA_HOME)\include;$(JAVA_HOME)\include\win32;%(AdditionalIncludeDirectories) + WIN32;_DEBUG;_WINDOWS;_USRDLL;DRYADYARNBRIDGEVS2008_EXPORTS;%(PreprocessorDefinitions) + MultiThreadedDebugDLL + + + Level3 + ProgramDatabase + + + jvm.lib;%(AdditionalDependencies) + $(OutDir)$(TargetName)$(TargetExt) + $(JAVA_HOME)\lib;%(AdditionalLibraryDirectories) + DryadYarnBridge.def 
+ true + true + NotSet + + + false + MachineX64 + + + + + WIN32;NDEBUG;_WINDOWS;_USRDLL;DRYADYARNBRIDGEVS2008_EXPORTS;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + + + true + Windows + true + true + main + MachineX86 + + + + + X64 + + + WIN32;NDEBUG;_WINDOWS;_USRDLL;DRYADYARNBRIDGEVS2008_EXPORTS;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + $(JAVA_HOME)\include;$(JAVA_HOME)\include\win32;%(AdditionalIncludeDirectories) + + + true + Windows + true + true + + + MachineX64 + jvm.lib;%(AdditionalDependencies) + DryadYarnBridge.def + $(JAVA_HOME)\lib;%(AdditionalLibraryDirectories) + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/DryadYarnBridge/DryadYarnBridge.vcxproj.filters b/DryadYarnBridge/DryadYarnBridge.vcxproj.filters new file mode 100644 index 0000000..dd29f49 --- /dev/null +++ b/DryadYarnBridge/DryadYarnBridge.vcxproj.filters @@ -0,0 +1,56 @@ + + + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hpp;hxx;hm;inl;inc;xsd + + + {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} + rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + + + Source Files + + + \ No newline at end of file diff --git a/DryadYarnBridge/YarnAppMasterManaged.cpp b/DryadYarnBridge/YarnAppMasterManaged.cpp new file mode 100644 index 0000000..f76c2b9 --- /dev/null +++ b/DryadYarnBridge/YarnAppMasterManaged.cpp @@ -0,0 +1,128 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma managed +#include "YarnAppMasterNative.h" +#include "YarnAppMasterManaged.h" + +#include +#include +#include +using namespace System; +using namespace System::Text; +using namespace System::IO; +using namespace DryadYarn; + +namespace Microsoft { namespace Research { namespace Dryad { namespace YarnBridge +{ + AMInstance::AMInstance() + { + AMNativeInstance *instance = new AMNativeInstance(); + + if (instance->OpenInstance()) + { + m_instance = IntPtr(instance); + } + else + { + m_instance = IntPtr::Zero; + throw gcnew ApplicationException("Unable to initialize Yarn Native App Master Instance"); + } + } + + AMInstance::~AMInstance() + { + Close(); + } + + void AMInstance::Close() + { + if (m_instance != IntPtr::Zero) + { + AMNativeInstance *instance = (AMNativeInstance *) m_instance.ToPointer(); + delete instance; + m_instance = IntPtr::Zero; + } + } + + void AMInstance::Finish() + { + if (m_instance != IntPtr::Zero) + { + AMNativeInstance *instance = (AMNativeInstance *) m_instance.ToPointer(); + instance->Shutdown(); + } + } + + int AMInstance::GetHealthyNodeCount() + { + if (m_instance != IntPtr::Zero) + { + AMNativeInstance *instance = (AMNativeInstance *) m_instance.ToPointer(); + //NYI + return 1; + } + return -1; + } + + void AMInstance::ScheduleProcess(int vertexId, const char* name, const char* cmdLine) + { + AMNativeInstance *instance = (AMNativeInstance *) m_instance.ToPointer(); + bool result = instance->ScheduleProcess(vertexId, name, cmdLine); + + if (!result) + { + throw 
gcnew ApplicationException("Unable to schedule process"); + } + } + + void AMInstance::ScheduleProcess(int vertexId, String^ name, String^ cmdLine) + { + Console::WriteLine("Scheduling process: Vertex ID: {0}, Name: '{1}', Command Line: '{2}'", vertexId, name, cmdLine); + IntPtr namePtr = Marshal::StringToHGlobalAnsi(name); + const char* nameString = static_cast(namePtr.ToPointer()); + IntPtr cmdLinePtr = Marshal::StringToHGlobalAnsi(cmdLine); + const char* cmdLineString = static_cast(cmdLinePtr.ToPointer()); + + AMNativeInstance *instance = (AMNativeInstance *) m_instance.ToPointer(); + bool result = instance->ScheduleProcess(vertexId, nameString, cmdLineString); + + Marshal::FreeHGlobal(namePtr); + Marshal::FreeHGlobal(cmdLinePtr); + + + if (!result) + { + throw gcnew ApplicationException("Unable to schedule process"); + } + } + + void AMInstance::UpdateProcess(int vertexId, int state, String^ nodeName) + { + Console::Error->WriteLine("Calling GM Callback: Vertex Id: {0} State: {1} NodeName: '{2}'", vertexId, state, nodeName); + m_gmCallback(vertexId, state, nodeName); + } + + void AMInstance::RegisterGMCallback(UpdateProcessState^ callback) + { + m_gmCallback = callback; + } + +}}}} diff --git a/DryadYarnBridge/YarnAppMasterManaged.h b/DryadYarnBridge/YarnAppMasterManaged.h new file mode 100644 index 0000000..0f2a004 --- /dev/null +++ b/DryadYarnBridge/YarnAppMasterManaged.h @@ -0,0 +1,54 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma managed + +using namespace System; +using namespace System::Collections::Generic; +using namespace System::Runtime::InteropServices; + +namespace Microsoft { namespace Research { namespace Dryad { namespace YarnBridge +{ + + public delegate void UpdateProcessState(int vertexId, int state, String^ nodeName); + + public ref class AMInstance : public IDisposable + { + public: + AMInstance(); + ~AMInstance(); + + + void Close(); + void Finish(); + + int GetHealthyNodeCount(); + + void ScheduleProcess(int vertexId, const char* name, const char* cmdLine); + void ScheduleProcess(int vertexId, String^ name, String^ cmdLine); + + static void UpdateProcess(int vertexId, int state, String^ nodeName); + static void RegisterGMCallback(UpdateProcessState^ callback); + + private: + IntPtr m_instance; + static UpdateProcessState^ m_gmCallback; + }; +}}}} diff --git a/DryadYarnBridge/YarnAppMasterNative.cpp b/DryadYarnBridge/YarnAppMasterNative.cpp new file mode 100644 index 0000000..02c7403 --- /dev/null +++ b/DryadYarnBridge/YarnAppMasterNative.cpp @@ -0,0 +1,317 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma unmanaged +#include "YarnAppMasterNative.h" + +#include +#include +#include +#include + +static JavaVM* s_jvm = NULL; +//static FILE* s_logfile = NULL; + +static char* GetExceptionMessageLocal(JNIEnv* env, jclass cls, jobject obj) +{ + jfieldID fidMessage = env->GetFieldID( + cls, "exceptionMessage", "Ljava/lang/String;"); + + assert(fidMessage != NULL); + + jstring message = (jstring) env->GetObjectField(obj, fidMessage); + + char* msg = NULL; + + if (message == NULL) + { + msg = _strdup(""); + } + else + { + const char* msgCopy = (const char*)(env->GetStringUTFChars(message, NULL)); + msg = _strdup(msgCopy); + env->ReleaseStringUTFChars(message, msgCopy); + } + + env->DeleteLocalRef(message); + + return msg; +} + +static JNIEnv* AttachToJvm() +{ + JNIEnv* env; + int ret = s_jvm->AttachCurrentThread((void**) &env, NULL); + + assert(ret == JNI_OK); + + return env; +} + +namespace DryadYarn +{ +#define CHAR_BUFFER_SIZES 70000 + struct Env + { + JNIEnv* e; + }; + + class InstanceInternal + { + public: + jmethodID m_midSchProc; + jmethodID m_midShutdown; + jclass m_clsInstance; + jobject m_obj; + Instance* m_holder; + }; + + bool Initialize() + { + //::DebugBreak(); + if (s_jvm != NULL) + { + return true; + } + + jsize bufLen = 1; + jsize nVMs = -1; + int ret = JNI_GetCreatedJavaVMs(&s_jvm, bufLen, &nVMs); + if (ret < 0) + { + fprintf(stderr, "\nGetCreatedJavaVMs returned %d\n", ret); + return false; + } + + if (nVMs != 0) + { + fprintf(stderr, "\nProcess already contains %d Java VMs\n", nVMs); + return false; + } + + char classPath[_MAX_ENV]; + DWORD dRet = GetEnvironmentVariableA("JNI_CLASSPATH", classPath, _MAX_ENV); + if (dRet == 0) + { + fprintf(stderr, "Failed to get 'classpath' environment variable\n"); + return false; + } + + JavaVMInitArgs vm_args; + JNI_GetDefaultJavaVMInitArgs(&vm_args); + vm_args.version = JNI_VERSION_1_6; + + JavaVMOption options[1]; // increment when turning on verbose JNI + vm_args.nOptions = 1; + vm_args.options = 
options; + options[0].optionString = new char[_MAX_ENV]; + sprintf_s(options[0].optionString, _MAX_ENV, "-Djava.class.path=%s", classPath); + //fprintf(stderr, "JNI_CLASSPATH:[%s]\n", options[0].optionString); + //options[1].optionString = "-verbose:jni"; + /* + vm_args.nOptions = 1; + JavaVMOption options; + options.optionString = "-verbose:jni"; + vm_args.options = &options; + */ + vm_args.ignoreUnrecognized = 0; + + JNIEnv* env; + ret = JNI_CreateJavaVM(&s_jvm, (void**) &env, &vm_args); + + delete [] options[0].optionString; + + if (ret < 0) + { + s_jvm = NULL; + fprintf(stderr, "\nCreateJavaVM returned %d\n", ret); + fflush(stderr); + return false; + } + fflush(stderr); + return true; + } + + AMNativeInstance::AMNativeInstance() + { + /* + while (!::IsDebuggerPresent()) + { + printf("Waiting for debugger\n");fflush(stdout); + Sleep(1000); + } + ::DebugBreak(); + */ + m_inst = NULL; + m_env = NULL; + } + + AMNativeInstance::~AMNativeInstance() + { + if (m_inst != NULL && m_inst->m_obj != NULL) + { + m_env->e->DeleteGlobalRef(m_inst->m_obj); + } + if (m_inst != NULL) delete m_inst->m_holder; // m_inst is still NULL when OpenInstance failed or was never called + delete m_inst; + m_inst = NULL; + delete m_env; + m_env = NULL; + } + + bool AMNativeInstance::OpenInstance() + { + + //TODO Determine if we should detach the current thread from the jvm when exiting this call + if (Initialize()) + { + m_env = new Env; + m_env->e = AttachToJvm(); + } + else + { + return false; + } + + jclass clsDryadAppMaster = m_env->e->FindClass("com/microsoft/research/DryadAppMaster"); + if (clsDryadAppMaster == NULL) + { + jthrowable exc; + exc = m_env->e->ExceptionOccurred(); + if (exc) { + m_env->e->ExceptionDescribe(); + m_env->e->ExceptionClear(); + } + + fprintf(stderr, "Failed to find DryadAppMaster class\n"); + fprintf(stderr, "Destroying JVM\n"); + fflush(stderr); + s_jvm->DestroyJavaVM(); + return false; + } + + jmethodID midAMCons = m_env->e->GetMethodID(clsDryadAppMaster, "<init>", "()V"); // JNI names constructors "<init>" + assert(midAMCons != NULL); + + jobject localInstance = 
m_env->e->NewObject(clsDryadAppMaster, midAMCons); + + if (localInstance == NULL) + { + jthrowable exc; + exc = m_env->e->ExceptionOccurred(); + if (exc) { + m_env->e->ExceptionDescribe(); + m_env->e->ExceptionClear(); + } + + fprintf(stderr, "Failed to initialize DryadAppMaster\n"); + fprintf(stderr, "Destroying JVM\n"); + fflush(stderr); + s_jvm->DestroyJavaVM(); + return false; + } + + jmethodID midSchProc = m_env->e->GetMethodID(clsDryadAppMaster, "scheduleProcess", "(ILjava/lang/String;Ljava/lang/String;)V"); + if (midSchProc == NULL) + { + jthrowable exc; + exc = m_env->e->ExceptionOccurred(); + if (exc) { + m_env->e->ExceptionDescribe(); + m_env->e->ExceptionClear(); + } + + fprintf(stderr, "Failed to find DryadAppMaster.scheduleProcess method\n"); + fprintf(stderr, "Destroying JVM\n"); + fflush(stderr); + s_jvm->DestroyJavaVM(); + return false; + } + + jmethodID midShutdown = m_env->e->GetMethodID(clsDryadAppMaster, "shutdown", "()V"); + if (midShutdown == NULL) // test the shutdown lookup itself, not the earlier scheduleProcess id + { + jthrowable exc; + exc = m_env->e->ExceptionOccurred(); + if (exc) { + m_env->e->ExceptionDescribe(); + m_env->e->ExceptionClear(); + } + + fprintf(stderr, "Failed to find DryadAppMaster.shutdown method\n"); + fprintf(stderr, "Destroying JVM\n"); + fflush(stderr); + s_jvm->DestroyJavaVM(); + return false; + } + + m_inst = new InstanceInternal(); + + m_inst->m_clsInstance = clsDryadAppMaster; + m_inst->m_obj = m_env->e->NewGlobalRef(localInstance); + m_env->e->DeleteLocalRef(localInstance); + m_inst->m_midSchProc = midSchProc; + m_inst->m_midShutdown = midShutdown; + fprintf(stderr, "Created Instance\n"); + fflush(stderr); + return true; + } + + char* AMNativeInstance::GetExceptionMessage() + { + return GetExceptionMessageLocal(m_env->e, m_inst->m_clsInstance, m_inst->m_obj); + } + + bool AMNativeInstance::ScheduleProcess(int vertexId, const char* name, const char* commandLine) + { + fprintf(stderr, "Scheduling process %s\n", commandLine); + fflush(stderr); + JNIEnv* env = AttachToJvm(); + + 
jstring jName = env->NewStringUTF(name); + jstring jCmdLine = env->NewStringUTF(commandLine); + + env->CallVoidMethod(m_inst->m_obj, m_inst->m_midSchProc, vertexId, jName, jCmdLine); + + env->DeleteLocalRef(jName); + env->DeleteLocalRef(jCmdLine); + + // detach here? + + return true; + } + + bool AMNativeInstance::Shutdown() + { + fprintf(stderr, "Shutting down AMNativeInstance\n"); + fflush(stderr); + JNIEnv* env = AttachToJvm(); + + env->CallVoidMethod(m_inst->m_obj, m_inst->m_midShutdown); + + // detach here? + + return true; + } + + + +} \ No newline at end of file diff --git a/DryadYarnBridge/YarnAppMasterNative.h b/DryadYarnBridge/YarnAppMasterNative.h new file mode 100644 index 0000000..6e90864 --- /dev/null +++ b/DryadYarnBridge/YarnAppMasterNative.h @@ -0,0 +1,85 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once +#pragma unmanaged + +#include +#include +#include + +//--------------------------------------------------------------------------------------------------- + +namespace DryadYarn +{ + struct Instance + { + void* p; + }; + class InstanceInternal; + + struct Env; + bool Initialize(); + + class InstanceAccessor + { + public: + InstanceAccessor(Instance* instance); + ~InstanceAccessor(); + + void Dispose(); + + char* GetExceptionMessage(); + + bool ScheduleProcess(long vertexId, const char* name, const char* commandLine); + + private: + //void* operator new( size_t ); + //void* operator new[]( size_t ); + + Env* m_env; + InstanceInternal* m_inst; + }; + + bool OpenInstance(Instance** pInstance); + + class AMNativeInstance + { + public: + AMNativeInstance(); + ~AMNativeInstance(); + bool OpenInstance(); + + char* GetExceptionMessage(); + + bool ScheduleProcess(int vertexId, const char* name, const char* commandLine); + bool Shutdown(); + + private: + //void* operator new( size_t ); + //void* operator new[]( size_t ); + + Env* m_env; + InstanceInternal* m_inst; + }; + +}; +//--------------------------------------------------------------------------------------------------- + diff --git a/DryadYarnBridge/YarnDryadBridge.h b/DryadYarnBridge/YarnDryadBridge.h new file mode 100644 index 0000000..0cb3647 --- /dev/null +++ b/DryadYarnBridge/YarnDryadBridge.h @@ -0,0 +1,36 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once +#include + +extern "C" +{ + + //void Register(DryadYarnBridgeBase* bridge); + + void __stdcall SendVertexState(int vertexId, int state); + + JNIEXPORT void JNICALL Java_TestLib_SendVertexState(JNIEnv *env, jobject obj, jint vertexId, jint state); + + JNIEXPORT void JNICALL Java_com_microsoft_research_DryadAppMaster_SendVertexState(JNIEnv *env, jobject obj, int vertexId, jint state, jstring nodeName); + +} + diff --git a/DryadYarnBridge/dllmain.cpp b/DryadYarnBridge/dllmain.cpp new file mode 100644 index 0000000..4e1a0eb --- /dev/null +++ b/DryadYarnBridge/dllmain.cpp @@ -0,0 +1,42 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// dllmain.cpp : Defines the entry point for the DLL application. 
+#include "stdafx.h" + +#ifdef _MANAGED +#pragma unmanaged +#endif +BOOL APIENTRY DllMain( HMODULE /*hModule*/, + DWORD ul_reason_for_call, + LPVOID /*lpReserved*/ + ) +{ + switch (ul_reason_for_call) + { + case DLL_PROCESS_ATTACH: + case DLL_THREAD_ATTACH: + case DLL_THREAD_DETACH: + case DLL_PROCESS_DETACH: + break; + } + return TRUE; +} + diff --git a/DryadYarnBridge/stdafx.h b/DryadYarnBridge/stdafx.h new file mode 100644 index 0000000..97ecc08 --- /dev/null +++ b/DryadYarnBridge/stdafx.h @@ -0,0 +1,33 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// stdafx.h : include file for standard system include files, +// or project specific include files that are used frequently, but +// are changed infrequently +// + +#pragma once + +#define WIN32_LEAN_AND_MEAN // Exclude rarely-used stuff from Windows headers +// Windows Header Files: +#include +#include +#include + diff --git a/DryadYarnBridge/yarndryadbridge.cpp b/DryadYarnBridge/yarndryadbridge.cpp new file mode 100644 index 0000000..cb0954e --- /dev/null +++ b/DryadYarnBridge/yarndryadbridge.cpp @@ -0,0 +1,60 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma unmanaged +#include "stdafx.h" +#include "yarndryadbridge.h" +#include "YarnAppMasterManaged.h" + + +// #pragma managed +// push unmanaged state on to stack and set managed state +#pragma managed(push, on) + +void __stdcall SendVertexState(int vertexId, int state, const char *nodeName) +{ + String^ nodeNameString = gcnew String(nodeName); + //s_bridge->SendVertexState(vertexId, state); + //System::Console::Error->WriteLine("\nIn managed function."); + //System::Console::Error->WriteLine("YarnDryadBridge: Vertex Id: {0} State: {1} NodeName: '{2}'", vertexId, state, nodeNameString); + + Microsoft::Research::Dryad::YarnBridge::AMInstance::UpdateProcess(vertexId, state, nodeNameString); + //printf("Vertex Id: %I64d, State: %d\n", vertexId, state); +} +// #pragma unmanaged +#pragma managed(pop) + + + +void JNICALL Java_com_microsoft_research_TestLib_SendVertexState(JNIEnv *env, jobject obj, jint vertexId, jint state, jstring nodeName) +{ + const char* nodeNameCopy = (const char*)(env->GetStringUTFChars(nodeName, NULL)); + SendVertexState((int) vertexId, (int) state, nodeNameCopy); + env->ReleaseStringUTFChars(nodeName, nodeNameCopy); + //printf("Vertex Id: %d, State: %d\n", vertexId, state); +} + +void JNICALL Java_com_microsoft_research_DryadAppMaster_SendVertexState(JNIEnv *env, jobject obj, jint vertexId, jint state, jstring nodeName) +{ + const char* nodeNameCopy = (const char*)(env->GetStringUTFChars(nodeName, NULL)); + SendVertexState((int) vertexId, (int) state, nodeNameCopy); + 
env->ReleaseStringUTFChars(nodeName, nodeNameCopy); +} + diff --git a/GraphManager/GraphManager.vcxproj b/GraphManager/GraphManager.vcxproj new file mode 100644 index 0000000..52c423b --- /dev/null +++ b/GraphManager/GraphManager.vcxproj @@ -0,0 +1,296 @@ + + + + + Debug + Win32 + + + Debug + x64 + + + Release + Win32 + + + Release + x64 + + + + {8E30F4A4-603B-4799-A473-6EF5388661BA} + GraphManager-vs2008 + ManagedCProj + GraphManager + + + + DynamicLibrary + Pure + Unicode + + + DynamicLibrary + Unicode + Pure + + + DynamicLibrary + Pure + Unicode + + + DynamicLibrary + Unicode + Pure + v100 + + + + + + + + + + + + + + + + + + + <_ProjectFileVersion>10.0.40219.1 + Debug\ + Debug\ + true + ..\bin\$(Configuration)\ + $(Platform)\$(Configuration)\ + true + Release\ + Release\ + true + ..\bin\$(Configuration)\ + $(Platform)\$(Configuration)\ + false + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + Microsoft.Research.Dryad + Microsoft.Research.Dryad + + + + Disabled + shared;graph;gang;filesystem;vertex;kernel;jobmanager;reporting;stagemanager;%(AdditionalIncludeDirectories) + WIN32;_DEBUG;_WINDOWS;_USRDLL;GRAPHMANAGERVS2008_EXPORTS;%(PreprocessorDefinitions) + MultiThreadedDebugDLL + + + Level3 + ProgramDatabase + + + true + true + Windows + main + MachineX64 + + + + + X64 + + + Disabled + shared;graph;gang;filesystem;vertex;kernel;jobmanager;reporting;stagemanager;..\DryadYarnBridge;..\xcompute_native\inc;%(AdditionalIncludeDirectories) + WIN32;_DEBUG;_WINDOWS;_USRDLL;GRAPHMANAGERVS2008_EXPORTS;%(PreprocessorDefinitions) + MultiThreadedDebugDLL + + + Level3 + ProgramDatabase + + + dbghelp.lib;%(AdditionalDependencies) + $(OutDir)$(TargetName)$(TargetExt) + %(AdditionalLibraryDirectories) + true + true + Windows + + + MachineX64 + + + + + WIN32;NDEBUG;_WINDOWS;_USRDLL;GRAPHMANAGERVS2008_EXPORTS;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + 
shared;graph;gang;filesystem;vertex;kernel;jobmanager;reporting;stagemanager;%(AdditionalIncludeDirectories) + + + true + Windows + true + true + main + MachineX86 + + + + + X64 + + + WIN32;NDEBUG;_WINDOWS;_USRDLL;GRAPHMANAGERVS2008_EXPORTS;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + shared;graph;gang;filesystem;vertex;kernel;jobmanager;reporting;stagemanager;..\DryadYarnBridge;..\xcompute_native\inc;%(AdditionalIncludeDirectories) + + + true + Windows + true + true + + + MachineX64 + dbghelp.lib;%(AdditionalDependencies) + %(AdditionalLibraryDirectories) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + {09fb27c7-d1a5-4a59-b010-67d5886dd9a2} + + + {c0f4c1e3-1f9e-4c55-bd6a-0241d35425f5} + + + {e092e2b9-d3c9-4ce2-8201-bda442574c97} + true + true + false + true + true + + + + + + \ No newline at end of file diff --git a/GraphManager/GraphManager.vcxproj.filters b/GraphManager/GraphManager.vcxproj.filters new file mode 100644 index 0000000..7a08311 --- /dev/null +++ b/GraphManager/GraphManager.vcxproj.filters @@ -0,0 +1,348 @@ + + + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hpp;hxx;hm;inl;inc;xsd + + + {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} + rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {fce69969-9dd1-40e1-9412-adeed0df7f1f} + + + {cffed16e-57ba-4ed6-aa94-9162cf5bf4f7} + + + {97301111-46f1-4970-ad41-96480ec1b43a} + + + {816f0b19-ef1a-4bd6-ab1c-0afcbad4165e} + + + {fd413de6-b103-4e17-9c12-ea48f0d084f0} + + + {c632d74d-d13e-4a37-ac3d-9d77efe27d41} + + + {af36e225-6dac-4224-8df1-2ef519270bf8} + + + {652668bf-e40c-4bc0-9cc1-6edb1fb61c1f} + + + {3dc23595-9e50-445e-9aa6-6ee9c47d4021} + + + {dcb126e3-8f5a-4f89-a829-8458a6382e7a} + + + {d864c522-2a2d-4644-8f63-48fc6597ef08} + + + 
{8752a8db-3f3f-4e61-86be-a1bfff840378} + + + {9a6b764b-fcc7-4fce-a6e9-247528c33aea} + + + {78301c66-d158-43e0-9591-613f63f90516} + + + {1ecff01b-d91e-495c-8e26-64376dd0aef9} + + + {1c257925-942e-4d20-8c7d-8d555f7ecf29} + + + + + Header Files + + + Header Files + + + Header Files\shared + + + Header Files\shared + + + Header Files\reporting + + + Header Files\shared + + + Header Files\vertex + + + Header Files\kernel + + + Header Files\vertex + + + Header Files\shared + + + Header Files\stagemanager + + + Header Files\graph + + + Header Files\shared + + + Header Files\stagemanager + + + Header Files\stagemanager + + + Header Files\stagemanager + + + Header Files\stagemanager + + + Header Files\shared + + + Header Files\shared + + + Header Files\graph + + + Header Files\filesystem + + + Header Files\shared + + + Header Files\gang + + + Header Files\vertex + + + Header Files\graph + + + Header Files\graph + + + Header Files\kernel + + + Header Files\shared + + + Header Files\kernel + + + Header Files\gang + + + Header Files\gang + + + Header Files\shared + + + Header Files\vertex + + + Header Files\filesystem + + + Header Files\stagemanager + + + Header Files\kernel + + + Header Files\vertex + + + Header Files\vertex + + + Header Files\vertex + + + Header Files\vertex + + + Header Files\kernel + + + Header Files\gang + + + Header Files\gang + + + Header Files\shared + + + Header Files\reporting + + + Header Files\shared + + + Header Files\shared + + + Header Files\shared + + + Header Files\vertex + + + Header Files\stagemanager + + + Header Files\stagemanager + + + Header Files\shared + + + Header Files\shared + + + Header Files\shared + + + Header Files\kernel + + + Header Files\kernel + + + Header Files\filesystem + + + + + Source Files + + + Source Files\filesystem + + + Source Files\filesystem + + + Source Files\reporting + + + Source Files\vertex + + + Source Files\vertex + + + Source Files\kernel + + + Source Files\stagemanager + + + Source Files\stagemanager + 
+ + Source Files\stagemanager + + + Source Files\stagemanager + + + Source Files\stagemanager + + + Source Files\shared + + + Source Files\shared + + + Source Files\vertex + + + Source Files\graph + + + Source Files\graph + + + Source Files\shared + + + Source Files\kernel + + + Source Files\gang + + + Source Files\gang + + + Source Files\vertex + + + Source Files\stagemanager + + + Source Files\kernel + + + Source Files\shared + + + Source Files\shared + + + Source Files\stagemanager + + + Source Files\gang + + + Source Files\shared + + + Source Files\vertex + + + Source Files\vertex + + + Source Files\vertex + + + Source Files\kernel + + + Source Files\kernel + + + Source Files\kernel + + + Source Files\filesystem + + + \ No newline at end of file diff --git a/GraphManager/filesystem/DrFileSystems.h b/GraphManager/filesystem/DrFileSystems.h new file mode 100644 index 0000000..38e1164 --- /dev/null +++ b/GraphManager/filesystem/DrFileSystems.h @@ -0,0 +1,32 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include + +#include + + +#ifdef _MANAGED +//#include +#endif + +#include diff --git a/GraphManager/filesystem/DrHdfsClient.cpp b/GraphManager/filesystem/DrHdfsClient.cpp new file mode 100644 index 0000000..f921b13 --- /dev/null +++ b/GraphManager/filesystem/DrHdfsClient.cpp @@ -0,0 +1,468 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include + +#ifdef _MANAGED +using namespace System; +using namespace System::Collections::Generic; +using namespace System::Runtime::InteropServices; +#else +using namespace HdfsBridgeNative; +#include +#endif + +/* Returns 'name' from a stream URI of the form hpchdfs://server:port/name */ + +#ifdef _MANAGED +/* Returns 'host' from a UNC path of the form \\host\dir\file.ext */ +String ^HdfsStorageNodeFromReadPath(String ^readPath) +{ + String ^storageNode = String::Empty; + + if (readPath->StartsWith("\\\\")) + { + String ^temp = readPath->TrimStart('\\'); + int serverEnd = temp->IndexOf('\\'); + if (serverEnd > 0) + { + storageNode = temp->Substring(0, serverEnd); + } + } + + return storageNode; +} + + +HdfsInstance ^GetHdfsServiceInstance(String ^HdfsUri) +{ + return DrNew HdfsInstance(HdfsUri); +} + +HdfsInstance ^GetHdfsServiceInstance(DrString DrHdfsUri) +{ + return GetHdfsServiceInstance(DrNew String(DrHdfsUri.GetString())); +} +#else + +HdfsBridgeNative::Instance* 
GetHdfsServiceInstance(DrString DrHdfsUri) +{ + URL_COMPONENTSA UrlComponents = {0}; + UrlComponents.dwStructSize = sizeof(UrlComponents); + UrlComponents.dwHostNameLength = 1; + + BOOL fOK = InternetCrackUrlA(DrHdfsUri.GetChars(), DrHdfsUri.GetCharsLength(), 0, &UrlComponents); + if (!fOK) + { + return NULL; + } + + HdfsBridgeNative::Instance* instancePtr = NULL; + bool ret = OpenInstance(UrlComponents.lpszHostName, UrlComponents.nPort, &instancePtr); + if (ret) + { + return instancePtr; + } + else + { + return NULL; + } +} + +DrString FromInternalUri(DrString baseUri, DrString inputString) +{ + URL_COMPONENTSA UrlComponents = {0}; + UrlComponents.dwStructSize = sizeof(UrlComponents); + UrlComponents.dwSchemeLength = 1; + UrlComponents.dwHostNameLength = 1; + + BOOL fOK = InternetCrackUrlA(baseUri.GetChars(), baseUri.GetCharsLength(), 0, &UrlComponents); + if (!fOK) + { + DrLogA("Error getting stream path from HDFS URI."); + return DrNull; + } + DrString serviceUri; + serviceUri.AppendF("%s://%s:%d/", UrlComponents.lpszScheme, UrlComponents.lpszHostName, UrlComponents.nPort); + + if (inputString.Compare(serviceUri.GetChars(), serviceUri.GetCharsLength(), false) == 0) + { + return DrString(inputString.GetChars() + serviceUri.GetCharsLength());//inputString->Substring(m_serviceUri->Length); + } + else + { + return DrNull; + } +} + +DrString ToInternalUri(DrString serviceUri, DrString inputString) +{ + DrString resultString(serviceUri); + return resultString.AppendF("%s", inputString.GetChars()); +} + +#endif + +DrHdfsInputStream::DrHdfsInputStream() +{ + m_hdfsInstance = DrNull; +} + +HRESULT DrHdfsInputStream::Open(DrUniversePtr universe, DrNativeString streamUri) +{ + DrString uri = DrString(streamUri); + + return OpenInternal(universe, uri); +} + + +HRESULT DrHdfsInputStream::OpenInternal(DrUniversePtr universe, DrString streamUri) +{ + m_streamUri = streamUri; + HRESULT err = S_OK; + +#ifdef _MANAGED + + try + { +#endif + + + m_hdfsInstance = 
GetHdfsServiceInstance(streamUri); + +#ifdef _MANAGED + String ^StreamName = m_hdfsInstance->FromInternalUri(streamUri.GetString()); + HdfsFileInfo^ stream = m_hdfsInstance->GetFileInfo(StreamName, true); + m_fileNameArray = stream->fileNameArray; + UInt32 totalPartitionCount = static_cast(stream->blockArray->Length); + +#else + bool ret = HdfsBridgeNative::Initialize(); + if (!ret) + { + DrLogE("Error calling HdfsBridgeNative::Initialize()"); + return E_FAIL; + } + + if (m_hdfsInstance == NULL) + { + DrLogE("Error calling GetHdfsServiceInstance(streamUri)"); + return E_FAIL; + } + URL_COMPONENTSA UrlComponents = {0}; + UrlComponents.dwStructSize = sizeof(UrlComponents); + UrlComponents.dwUrlPathLength = 1; + UrlComponents.dwHostNameLength = 1; + + BOOL fOK = InternetCrackUrlA(streamUri.GetChars(), streamUri.GetCharsLength(), 0, &UrlComponents); + if (!fOK) + { + DrLogE("Error getting stream path from HDFS URI."); + return E_FAIL; + } + + m_hostname.Set(UrlComponents.lpszHostName); + m_portNum = UrlComponents.nPort; + + InstanceAccessor ia(m_hdfsInstance); + FileStat* fileStat = NULL; + ia.OpenFileStat(UrlComponents.lpszUrlPath, true, &fileStat); + UINT32 totalPartitionCount = 0; + HdfsBridgeNative::FileStatAccessor fs(fileStat); + totalPartitionCount = fs.GetNumberOfBlocks(); + + m_fileNameArray = (const char **)fs.GetFileNameArray(); +#endif + + /* Allocate these arrays even if they're size 0, to avoid + NullReferenceException later */ + m_affinity = DrNew DrAffinityArray(totalPartitionCount); + m_partOffsets = DrNew DrUINT64Array(totalPartitionCount); + m_partFileIds = DrNew DrUINT32Array(totalPartitionCount); + + for (UINT32 i=0; iblockArray[i]; +#else + HdfsBridgeNative::HdfsBlockLocInfo* partition = fs.GetBlockInfo(i); +#endif + m_affinity[i] = DrNew DrAffinity(); + m_affinity[i]->SetWeight(partition->Size); + m_partOffsets[i] = partition->Offset; + m_partFileIds[i] = partition->fileIndex; + +#ifdef _MANAGED + for (int j = 0; j < partition->Hosts->Length; 
++j) +#else + for (int j = 0; j < partition->numberOfHosts; ++j) +#endif + { + DrResourceRef location = universe->LookUpResource(partition->Hosts[j]); + if (location != DrNull) + { + m_affinity[i]->AddLocality(location); + } + } +#ifndef _MANAGED + delete partition; +#endif + } +#ifdef _MANAGED + } + catch (System::Exception ^e) + { + err = System::Runtime::InteropServices::Marshal::GetHRForException(e); + } + finally + { + // TODO: How do we clean this up? + //hdfsInstance->Dispose(); + } +#endif + + return err; +} + +DrString DrHdfsInputStream::GetStreamName() +{ + return m_streamUri; +} + +int DrHdfsInputStream::GetNumberOfPartitions() +{ + return m_affinity->Allocated(); +} + +DrAffinityRef DrHdfsInputStream::GetAffinity(int partitionIndex) +{ + return m_affinity[partitionIndex]; +} + +DrString DrHdfsInputStream::GetURIForRead(int partitionIndex, + DrResourcePtr /* unused runningResource*/) +{ + DrString uri; + //Put HDFS service host and port in the input partition URI + +#ifdef _MANAGED + String ^HdfsStreamUri = DrNew String(m_streamUri.GetString()); + Uri ^HdfsServiceUri = DrNew Uri(HdfsStreamUri); + String ^HdfsPartitionUri = + String::Format("hpchdfspt://{0}:{1}/{2}?{3}?{4}", + HdfsServiceUri->Host, + HdfsServiceUri->Port, + m_fileNameArray[m_partFileIds[partitionIndex]], + m_partOffsets[partitionIndex], + m_affinity[partitionIndex]->GetWeight()); + uri.Set(HdfsPartitionUri); +#else + uri.SetF("hpchdfspt://%s:%d/%s?%I64u?%I64u", m_hostname, m_portNum, + m_fileNameArray[m_partFileIds[partitionIndex]], m_partOffsets[partitionIndex], + m_affinity[partitionIndex]->GetWeight()); +#endif + + return uri; +} + + +DrHdfsOutputStream::DrHdfsOutputStream() +{ + m_hdfsInstance = DrNull; +} + +HRESULT DrHdfsOutputStream::Open(DrNativeString streamUri) +{ + m_baseUri = streamUri; + m_numParts = -1; + + +#ifdef _MANAGED + try + { + m_hdfsInstance = GetHdfsServiceInstance(m_baseUri); + } + catch (System::Exception ^e) + { + return 
System::Runtime::InteropServices::Marshal::GetHRForException(e); + } +#else + bool ret = HdfsBridgeNative::Initialize(); + if (!ret) + { + DrLogE("Error calling HdfsBridgeNative::Initialize()"); + return E_FAIL; + } + + m_hdfsInstance = GetHdfsServiceInstance(streamUri); + if (m_hdfsInstance == NULL) + { + DrLogE("Error calling GetHdfsServiceInstance(streamUri)"); + return E_FAIL; + } +#endif + + return S_OK; +} + +void DrHdfsOutputStream::SetNumberOfPartitions(int numberOfPartitions) +{ + // For now, assume that the number of partitions cannot change + DrAssert(m_numParts == -1); + DrAssert(m_hdfsInstance != DrNull); + + m_numParts = numberOfPartitions; +} + +DrString DrHdfsOutputStream::GetURIForWrite(int partitionIndex, + int /* id*/, + int version, + int /* outputPort*/, + DrResourcePtr /*runningResource*/, + DrMetaDataRef /*metaData */) +{ + DrAssert(m_hdfsInstance != DrNull); + DrString fileName; + fileName.Set(m_baseUri); + //String^ fileName = m_baseUri + "-tmp/part-" + partitionIndex.ToString("D8") + "." + version; + fileName.AppendF("-tmp/part-%8d.%d", partitionIndex, version); + return fileName; +} + +void DrHdfsOutputStream::DiscardUnusedPartition(int partitionIndex, + int id, + int version, + int outputPort, + DrResourcePtr runningResource) +{ + DrAssert(m_hdfsInstance != DrNull); + + /* delete the partition if it has been created */ + DrString uriString = GetURIForWrite( + partitionIndex, + id, + version, + outputPort, + runningResource, + DrNull); + +#ifdef _MANAGED + String^ path = m_hdfsInstance->FromInternalUri(uriString.GetString()); + bool deleted = m_hdfsInstance->DeleteFile(path, false); +#else + + DrString path = FromInternalUri(m_baseUri, uriString); + InstanceAccessor ia(m_hdfsInstance); + bool deleted = false; + ia.DeleteFileOrDir((char *) path.GetChars(), false, &deleted); +#endif + + DrLogI( + "HDFS deleting failed version %s: %s", + uriString.GetChars(), (deleted) ? 
"succeeded" : "failed" + ); +} + +HRESULT DrHdfsOutputStream::FinalizeSuccessfulPartitions(DrOutputPartitionArrayRef partitionArray) +{ + DrAssert(m_numParts == partitionArray->Allocated()); + DrAssert(m_hdfsInstance != DrNull); + +#ifdef _MANAGED + String^ srcUri = m_baseUri + "-tmp"; + String^ srcPath = m_hdfsInstance->FromInternalUri(srcUri); + HdfsFileInfo^ directoryInfo = m_hdfsInstance->GetFileInfo(srcPath, false); + + if (directoryInfo == DrNull) + { + DrString drSrc(srcPath); + DrLogE("Can't read %s finalizing HDFS output", + drSrc.GetChars()); + return E_FAIL; + } + + if (directoryInfo->fileNameArray->Length == m_numParts) + { + String^ dstPath = m_hdfsInstance->FromInternalUri(m_baseUri); + + bool renamed = m_hdfsInstance->RenameFile(dstPath, srcPath); + if (!renamed) + { + DrString drSrc(srcPath); + DrString drDst(dstPath); + DrLogE("Can't rename %s to %s finalizing HDFS output", + drSrc.GetChars(), drDst.GetChars()); + return E_FAIL; + } + } + else + { + DrString drSrc(srcPath); + DrLogE("Won't rename %s: should contain %d files, but has %d", + drSrc.GetChars(), m_numParts, directoryInfo->fileNameArray->Length); + return E_FAIL; + } +#else + DrString srcUri(m_baseUri); + srcUri.AppendF("-tmp"); + DrString srcPath = FromInternalUri(m_baseUri, srcUri); + + InstanceAccessor ia(m_hdfsInstance); + FileStat* fs; + bool ret = ia.OpenFileStat(srcPath.GetChars(), false, &fs); + if (!ret) + { + char* msg = ia.GetExceptionMessage(); + DrLogE(msg); + free(msg); + return E_FAIL; + } + FileStatAccessor directoryInfo(fs); + if (directoryInfo.GetNumberOfFiles() == m_numParts) + { + DrString dstPath = FromInternalUri(m_baseUri, m_baseUri); + + bool renamed = false; + ia.RenameFileOrDir((char *)dstPath.GetChars(), (char *)srcPath.GetChars(), &renamed); + if (!renamed) + { + DrString drSrc(srcPath); + DrString drDst(dstPath); + DrLogE("Can't rename %s to %s finalizing HDFS output", + drSrc.GetChars(), drDst.GetChars()); + return E_FAIL; + } + } + else + { + DrString 
drSrc(srcPath); + DrLogE("Won't rename %s: should contain %d files, but has %d", + drSrc.GetChars(), m_numParts, directoryInfo.GetNumberOfFiles()); + return E_FAIL; + } +#endif + + return S_OK; +} + +void DrHdfsOutputStream::ExtendLease(DrTimeInterval /*lease*/) +{ + /* nothing to do here */ +} diff --git a/GraphManager/filesystem/DrHdfsClient.h b/GraphManager/filesystem/DrHdfsClient.h new file mode 100644 index 0000000..97e3ebe --- /dev/null +++ b/GraphManager/filesystem/DrHdfsClient.h @@ -0,0 +1,98 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#ifdef _MANAGED +using namespace Microsoft::Research::Dryad::Hdfs; +#else +#include "HdfsBridgeNative.h" +#endif + +DRCLASS(DrHdfsInputStream) : public DrInputStream +{ +public: + DrHdfsInputStream(); + HRESULT Open(DrUniversePtr universe, DrNativeString streamUri); + HRESULT OpenInternal(DrUniversePtr universe, DrString streamUri); + + virtual DrString GetStreamName() DROVERRIDE; + virtual int GetNumberOfPartitions() DROVERRIDE; + virtual DrAffinityRef GetAffinity(int partitionIndex) DROVERRIDE; + virtual DrString GetURIForRead(int partitionIndex, DrResourcePtr runningResource) DROVERRIDE; + +private: + DrString m_streamUri; + DrAffinityArrayRef m_affinity; + DrUINT64ArrayRef m_partIds; + DrUINT64ArrayRef m_partOffsets; + DrUINT32ArrayRef m_partFileIds; +#ifdef _MANAGED + array^ m_fileNameArray; + HdfsInstance^ m_hdfsInstance; +#else + HdfsBridgeNative::Instance* m_hdfsInstance; + DrString m_hostname; + int m_portNum; + const char** m_fileNameArray; + +#endif +}; +DRREF(DrHdfsInputStream); + +DRCLASS(DrHdfsOutputStream) : public DrOutputStream +{ +public: + DrHdfsOutputStream(); + + HRESULT Open(DrNativeString streamUri); + + virtual void SetNumberOfPartitions(int numberOfPartitions) DROVERRIDE; + virtual DrString GetURIForWrite( + int partitionIndex, + int id, + int version, + int outputPort, + DrResourcePtr runningResource, + DrMetaDataRef metaData) DROVERRIDE; + + virtual void DiscardUnusedPartition( + int partitionIndex, + int id, + int version, + int outputPort, + DrResourcePtr runningResource) DROVERRIDE; + + virtual HRESULT FinalizeSuccessfulPartitions( + DrOutputPartitionArrayRef partitionArray) DROVERRIDE; + + virtual void ExtendLease(DrTimeInterval) DROVERRIDE; + +private: + int m_numParts; +#ifdef _MANAGED + System::String^ m_baseUri; + HdfsInstance^ m_hdfsInstance; +#else + HdfsBridgeNative::Instance* m_hdfsInstance; + DrString m_baseUri; +#endif +}; +DRREF(DrPartitionOutputStream); diff --git 
a/GraphManager/filesystem/DrPartitionFile.cpp b/GraphManager/filesystem/DrPartitionFile.cpp new file mode 100644 index 0000000..c57885b --- /dev/null +++ b/GraphManager/filesystem/DrPartitionFile.cpp @@ -0,0 +1,518 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include + +static DrString ReadLineFromFile(FILE* f) +{ + DrString line(""); + + char buf[1024]; + char* s; + bool found = false; + bool foundEndOfLine = false; + + while ((s = fgets(buf, sizeof(buf)-1, f)) != NULL) + { + found = true; + size_t sLen; + + buf[sizeof(buf)-1] = '\0'; + + sLen = ::strlen(buf); + + if (sLen > 0 && buf[sLen-1] == '\n') + { + --sLen; + buf[sLen] = '\0'; + foundEndOfLine = true; + } + if (sLen > 0 && buf[sLen-1] == '\r') + { + --sLen; + buf[sLen] = '\0'; + foundEndOfLine = true; + } + + line = line.AppendF("%s", buf); + if (foundEndOfLine) + { + break; + } + } + + if (found) + { + return line; + } + else + { + DrString nullString; + return nullString; + } +} + +static bool ParseReplicatedFromPartitionLine(int partitionNumber, + DrAffinityPtr affinity, + DrStringR remoteName, + DrPartitionInputStream::OverridePtr over, + bool mustOverride, + DrString line, + DrUniversePtr universe) +{ + DrString lineCopy = line; + + int sep = lineCopy.IndexOfChar(','); + if (sep == DrStr_InvalidIndex) + { + return false; + } + else + { + DrString partitionNumberString; + 
partitionNumberString.SetSubString(lineCopy.GetChars(), sep); + DrString shorter; + shorter.SetF("%s", lineCopy.GetChars() + sep + 1); + lineCopy = shorter; + + int parsedNumber; + int n = sscanf_s(partitionNumberString.GetChars(), "%d", &parsedNumber); + if (n != 1 || parsedNumber != partitionNumber) + { + DrLogW("Mismatched partition numbers in line %s: Expected %d got %d", + line.GetChars(), partitionNumber, parsedNumber); + return false; + } + } + + UINT64 parsedSize = 0; + + sep = lineCopy.IndexOfChar(','); + if (sep == DrStr_InvalidIndex) + { + DrLogW("Malformed line %s: no list of machines", line.GetChars()); + return false; + } + else + { + DrString partitionSizeString; + partitionSizeString.SetSubString(lineCopy.GetChars(), sep); + DrString shorter; + shorter.SetF("%s", lineCopy.GetChars() + sep + 1); + lineCopy = shorter; + + int n = sscanf_s(partitionSizeString.GetChars(), "%I64u", &parsedSize); + if (n != 1) + { + DrLogW("Malformed line %s: can't parse size", line.GetChars()); + return false; + } + + affinity->SetWeight(parsedSize); + } + + if (lineCopy.GetCharsLength() == 0) + { + DrLogW("Malformed line %s: no partition machines", line.GetChars()); + return false; + } + + int numberOfReplicas = 0; + while (lineCopy.GetCharsLength() > 0) + { + DrString thisMachineName; + + sep = lineCopy.IndexOfChar(','); + if (sep == DrStr_InvalidIndex) + { + thisMachineName = lineCopy; + lineCopy = DrString(""); + } + else + { + thisMachineName.SetSubString(lineCopy.GetChars(), sep); + DrString shorter; + shorter.SetF("%s", lineCopy.GetChars() + sep + 1); + lineCopy = shorter; + } + + sep = thisMachineName.IndexOfChar(':'); + if (sep == DrStr_InvalidIndex) + { + thisMachineName = thisMachineName.ToUpperCase(); + } + else + { + DrString overrideFile; + overrideFile.SetF("%s", thisMachineName.GetChars() + sep + 1); + DrString shorter; + shorter.SetSubString(thisMachineName.GetChars(), sep); + thisMachineName = shorter.ToUpperCase(); + + 
over->Add(thisMachineName.GetString(), overrideFile); + } + + DrResourceRef location = universe->LookUpResourceInternal(thisMachineName); + if (location == DrNull) + { + remoteName.Set(thisMachineName); + } + else + { + affinity->AddLocality(location); + } + + ++numberOfReplicas; + } + + if (mustOverride && over->GetSize() == 0) + { + DrLogW("Malformed partition file: All filenames must be overrides when path is empty"); + return false; + } + + return true; +} + +HRESULT DrPartitionInputStream::Open(DrUniversePtr universe, DrNativeString streamName) +{ + return OpenInternal(universe, DrString(streamName)); +} + +HRESULT DrPartitionInputStream::OpenInternal(DrUniversePtr universe, DrString streamName) +{ + HRESULT err = S_OK; + + FILE* f; + errno_t ferr = fopen_s(&f, streamName.GetChars(), "rb"); + if (ferr != 0) + { + err = HRESULT_FROM_WIN32(GetLastError()); + DrLogW("Failed to open input file %s error %s", streamName.GetChars(), DRERRORSTRING(err)); + return err; + } + + m_pathNameOnComputer = ReadLineFromFile(f); + DrString partitionSizeLine = ReadLineFromFile(f); + + if (partitionSizeLine.GetString() == DrNull) + { + err = DrError_EndOfStream; + DrLogW("Failed to read pathname and partition size from input file %s", streamName.GetChars()); + fclose(f); + return err; + } + + bool mustOverride = false; + if (m_pathNameOnComputer.GetCharsLength() == 0) + { + mustOverride = true; + } + + int numberOfPartitions; + int n = sscanf_s(partitionSizeLine.GetChars(), "%d", &numberOfPartitions); + if (n != 1) + { + DrLogW("Unable to read partition size from line '%s' Filename %s", + partitionSizeLine.GetChars(), streamName.GetChars()); + fclose(f); + return DrError_Unexpected; + } + + if (numberOfPartitions == 0) + { + DrLogI("Read empty partitioned file details PathName: '%s'", + m_pathNameOnComputer.GetChars()); + fclose(f); + return S_OK; + } + + m_affinity = DrNew DrAffinityArray(numberOfPartitions); + m_remoteName = DrNew DrStringArray(numberOfPartitions); + m_override 
= DrNew OverrideArray(numberOfPartitions); + + DrLogI("Reading partitioned file details PathName: '%s' NumberOfPartitions=%d", + m_pathNameOnComputer.GetChars(), numberOfPartitions); + + int i; + for (i=0; iAllocated(); +} + +DrAffinityRef DrPartitionInputStream::GetAffinity(int partitionIndex) +{ + return m_affinity[partitionIndex]; +} + +DrString DrPartitionInputStream::GetURIForRead(int partitionIndex, DrResourcePtr runningResource) +{ + OverridePtr over = m_override[partitionIndex]; + DrAffinityPtr affinity = m_affinity[partitionIndex]; + DrResourceListRef location = affinity->GetLocalityArray(); + + DrString computerName; + + if (location->Size() == 0) + { + computerName = m_remoteName[partitionIndex]; + } + else + { + DrResourcePtr resource = DrNull; + + int i; + for (i=0; iSize(); ++i) + { + if (location[i] == runningResource) + { + resource = location[i]; + break; + } + } + + if (resource == DrNull) + { + for (i=0; iSize(); ++i) + { + if (location[i]->GetParent() == runningResource->GetParent()) + { + resource = location[i]; + break; + } + } + + if (resource == DrNull) + { + resource = location[rand() % location->Size()]; + } + } + + computerName = resource->GetName(); + } + + DrString uri; + + DrString overrideString; + if (over->TryGetValue(computerName.GetString(), overrideString)) + { + uri.SetF("file://\\\\%s\\%s", computerName.GetChars(), overrideString.GetChars()); + } + else + { + uri.SetF("file://\\\\%s\\%s.%08x", + computerName.GetChars(), m_pathNameOnComputer.GetChars(), partitionIndex); + } + + return uri; +} + +HRESULT DrPartitionOutputStream::Open(DrNativeString streamName, DrNativeString pathBase) +{ + return OpenInternal(DrString(streamName), DrString(pathBase)); +} + +HRESULT DrPartitionOutputStream::OpenInternal(DrString streamName, DrString pathBase) +{ + FILE* f; + errno_t ferr = fopen_s(&f, streamName.GetChars(), "w"); + if (ferr != 0) + { + HRESULT err = HRESULT_FROM_WIN32(GetLastError()); + DrLogW("Failed to open output file %s error 
%s", streamName.GetChars(), DRERRORSTRING(err)); + return err; + } + fclose(f); + + m_streamName = streamName; + m_pathBase = pathBase; + + return S_OK; +} + +void DrPartitionOutputStream::SetNumberOfPartitions(int /* unused numberOfPartitions*/) +{ +} + +DrString DrPartitionOutputStream::GetURIForWrite(int partitionIndex, + int id, int version, int outputPort, + DrResourcePtr runningResource, + DrMetaDataRef metaData) +{ + DrString uri; + uri.SetF("file://\\\\%s\\%s.%08x---%d_%d_%d.tmp", + runningResource->GetName().GetChars(), + m_pathBase.GetChars(), partitionIndex, id, outputPort, version); + + DrMTagVoidRef tag = DrNew DrMTagVoid(DrProp_TryToCreateChannelPath); + metaData->Append(tag); + + return uri; +} + +void DrPartitionOutputStream::DiscardUnusedPartition(int partitionIndex, + int id, int version, int outputPort, + DrResourcePtr runningResource) +{ + DrMetaDataRef metaData = DrNew DrMetaData(); + DrString uri = GetURIForWrite(partitionIndex, id, version, outputPort, runningResource, metaData); + BOOL bRet = ::DeleteFileA(uri.GetChars() + 7); + if (!bRet) + { + HRESULT err = HRESULT_FROM_WIN32(GetLastError()); + DrAssert(err != S_OK); + + if (err == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND)) + { + DrLogI("Delete ignoring nonexistent URI %s", uri.GetChars()); + } + else + { + DrLogW("DeleteFile(%s), error %s", uri.GetChars() + 7, DRERRORSTRING(err)); + } + } + else + { + DrLogI("Deleted URI %s", uri.GetChars()); + } +} + +HRESULT DrPartitionOutputStream::RenameSuccessfulPartition(int partitionIndex, DrOutputPartition p) +{ + DrMetaDataRef metaData = DrNew DrMetaData(); + DrString uri = GetURIForWrite(partitionIndex, p.m_id, + p.m_version, p.m_outputPort, p.m_resource, + metaData); + + DrString finalName; + finalName.SetF("\\\\%s\\%s.%08x", + p.m_resource->GetName().GetChars(), + m_pathBase.GetChars(), partitionIndex); + + HRESULT err = S_OK; + + BOOL bRet = ::MoveFileA(uri.GetChars()+7, finalName.GetChars()); + if (!bRet) + { + err = 
HRESULT_FROM_WIN32(GetLastError()); + DrAssert(err != S_OK); + + DrLogW("MoveFile(%s, %s), error %s", uri.GetChars()+7, finalName.GetChars(), DRERRORSTRING(err)); + } + else + { + DrLogI("Renamed Native URI %s -> %s", uri.GetChars(), finalName.GetChars()); + } + + return err; +} + +HRESULT DrPartitionOutputStream::FinalizeSuccessfulPartitions(DrOutputPartitionArrayRef partitionArray) +{ + HRESULT err = S_OK; + + if (!SUCCEEDED(err)) + { + return err; + } + + FILE* f; + errno_t ferr = fopen_s(&f, m_streamName.GetChars(), "w"); + if (ferr != 0) + { + HRESULT err = HRESULT_FROM_WIN32(GetLastError()); + DrLogW("Failed to open output file %s error %s", m_streamName.GetChars(), DRERRORSTRING(err)); + return err; + } + + fprintf(f, "%s\n%d\n", m_pathBase.GetChars(), partitionArray->Allocated()); + int i; + for (i=0; iAllocated(); ++i) + { + DrOutputPartition p = partitionArray[i]; + + HRESULT thisErr = RenameSuccessfulPartition(i, p); + if (err == S_OK && !SUCCEEDED(thisErr)) + { + err = thisErr; + } + + fprintf(f, "%d,%I64u,%s\n", i, p.m_size, p.m_resource->GetName().GetChars()); + } + + fclose(f); + + return err; +} + +void DrPartitionOutputStream::ExtendLease(DrTimeInterval) +{ +} + diff --git a/GraphManager/filesystem/DrPartitionFile.h b/GraphManager/filesystem/DrPartitionFile.h new file mode 100644 index 0000000..dbcd554 --- /dev/null +++ b/GraphManager/filesystem/DrPartitionFile.h @@ -0,0 +1,73 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +DRCLASS(DrPartitionInputStream) : public DrInputStream +{ +public: + HRESULT Open(DrUniversePtr universe, DrNativeString streamName); + HRESULT OpenInternal(DrUniversePtr universe, DrString streamName); + + virtual DrString GetStreamName() DROVERRIDE; + virtual int GetNumberOfPartitions() DROVERRIDE; + virtual DrAffinityRef GetAffinity(int partitionIndex) DROVERRIDE; + virtual DrString GetURIForRead(int partitionIndex, DrResourcePtr runningResource) DROVERRIDE; + + typedef DrStringStringDictionary Override; + DRREF(Override); + + typedef DrArray OverrideArray; + DRAREF(OverrideArray,OverrideRef); + +private: + DrString m_streamName; + DrString m_pathNameOnComputer; + DrAffinityArrayRef m_affinity; + DrStringArrayRef m_remoteName; + OverrideArrayRef m_override; +}; +DRREF(DrPartitionInputStream); + + +DRCLASS(DrPartitionOutputStream) : public DrOutputStream +{ +public: + HRESULT Open(DrNativeString streamName, DrNativeString pathBase); + HRESULT OpenInternal(DrString streamName, DrString pathBase); + + virtual void SetNumberOfPartitions(int numberOfPartitions) DROVERRIDE; + virtual DrString GetURIForWrite(int partitionIndex, int id, int version, int outputPort, + DrResourcePtr runningResource, + DrMetaDataRef metaData) DROVERRIDE; + virtual void DiscardUnusedPartition(int partitionIndex, int id, int version, int outputPort, + DrResourcePtr runningResource) DROVERRIDE; + virtual HRESULT FinalizeSuccessfulPartitions(DrOutputPartitionArrayRef partitionArray) DROVERRIDE; + virtual void ExtendLease(DrTimeInterval) DROVERRIDE; + +private: + HRESULT RenameSuccessfulPartition(int partitionIndex, DrOutputPartition p); + + DrString m_streamName; + DrString m_pathBase; + + DrOutputPartitionArrayRef m_successfulPartitions; +}; +DRREF(DrPartitionOutputStream); diff --git a/GraphManager/gang/DrGangHeaders.h 
b/GraphManager/gang/DrGangHeaders.h new file mode 100644 index 0000000..f644a45 --- /dev/null +++ b/GraphManager/gang/DrGangHeaders.h @@ -0,0 +1,28 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include + +#include +#include +#include +#include diff --git a/GraphManager/gang/DrMetaData.cpp b/GraphManager/gang/DrMetaData.cpp new file mode 100644 index 0000000..dda4d4a --- /dev/null +++ b/GraphManager/gang/DrMetaData.cpp @@ -0,0 +1,150 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include + +struct DrPropertyType +{ + UINT16 m_tag; + UINT16 m_type; +}; + +static DrPropertyType s_propertyType[] = +{ + { DrProp_ChannelState, DrMTT_HRESULT }, + { DrProp_ChannelURI, DrMTT_String }, + { DrProp_ChannelBufferOffset, DrMTT_UInt64 }, + { DrProp_ChannelTotalLength, DrMTT_UInt64 }, + { DrProp_ChannelProcessedLength, DrMTT_UInt64 }, + { DrProp_StreamExpireTimeWhileOpen, DrMTT_TimeInterval }, + { DrProp_StreamExpireTimeWhileClosed, DrMTT_TimeInterval }, + { DrProp_VertexState, DrMTT_HRESULT }, + { DrProp_VertexErrorCode, DrMTT_HRESULT }, + { DrProp_VertexId, DrMTT_UInt32 }, + { DrProp_VertexVersion, DrMTT_UInt32 }, + { DrProp_VertexInputChannelCount, DrMTT_UInt32 }, + { DrProp_VertexOutputChannelCount, DrMTT_UInt32 }, + { DrProp_VertexCommand, DrMTT_VertexCommand }, + { DrProp_VertexArgumentCount, DrMTT_UInt32 }, + { DrProp_VertexArgument, DrMTT_String }, + { DrProp_VertexSerializedBlock, DrMTT_Blob }, + { DrProp_DebugBreak, DrMTT_Boolean }, + { DrProp_AssertFailure, DrMTT_String }, + { DrProp_CanShareWorkQueue, DrMTT_Boolean }, + { DrProp_VertexMaxOpenInputChannelCount, DrMTT_UInt32 }, + { DrProp_VertexMaxOpenOutputChannelCount, DrMTT_UInt32 }, + { DrProp_ErrorCode, DrMTT_HRESULT }, + { DrProp_ErrorString, DrMTT_String }, + { DrProp_ItemBufferStartOffset, DrMTT_UInt64 }, + { DrProp_ItemBufferEndOffset, DrMTT_UInt64 }, + { DrProp_BufferLength, DrMTT_UInt64 }, + { DrProp_ItemStreamStartOffset, DrMTT_UInt64 }, + { DrProp_ItemStreamEndOffset, DrMTT_UInt64 }, + { DrProp_ItemDataSequenceNumber, DrMTT_UInt64 }, + { DrProp_ItemDeliverySequenceNumber, DrMTT_UInt64 }, + { DrProp_InputPortCount, DrMTT_UInt32 }, + { DrProp_OutputPortCount, DrMTT_UInt32 }, + { DrProp_NumberOfVertices, DrMTT_UInt32 }, + { DrProp_SourceVertex, DrMTT_UInt32 }, + { DrProp_SourcePort, DrMTT_UInt32 }, + { DrProp_DestinationVertex, DrMTT_UInt32 }, + { DrProp_DestinationPort, DrMTT_UInt32 }, + { DrProp_NumberOfEdges, DrMTT_UInt32 }, + { DrProp_TryToCreateChannelPath, DrMTT_Void }, 
+ { DrProp_InitialChannelWriteSize, DrMTT_UInt64 }, + { 0xffff, 0xffff } +}; + +DrMetaData::DrMetaData() +{ + m_tagList = DrNew DrMTagList(); +} + +void DrMetaData::Append(DrMTagPtr tag) +{ + m_tagList->Add(tag); +} + +DrMTagListPtr DrMetaData::GetTags() +{ + return m_tagList; +} + +DrMTagPtr DrMetaData::LookUp(UINT16 enumId) +{ + int i; + for (i=0; iSize(); ++i) + { + if (m_tagList[i]->GetMTag() == enumId) + { + return m_tagList[i]; + } + } + + return DrNull; +} + +void DrMetaData::Serialize(DrPropertyWriterPtr writer) +{ + if (m_cachedSerialization == DrNull) + { + int i; + for (i=0; iSize(); ++i) + { + m_tagList[i]->Serialize(writer); + } + } + else + { + DRPIN(BYTE) src = &(m_cachedSerialization[0]); + writer->WriteBytes(src, m_cachedSerialization->Allocated()); + } +} + +void DrMetaData::CacheSerialization() +{ + DrPropertyWriterRef writer = DrNew DrPropertyWriter(); + Serialize(writer); + m_cachedSerialization = writer->GetBuffer(); +} + +HRESULT DrMetaData::ParseProperty(DrPropertyReaderPtr reader, UINT16 enumId, UINT32 dataLen) +{ + UINT16 type = DrMTT_Unknown; + int i; + for (i=0; s_propertyType[i].m_tag != 0xffff; ++i) + { + if (s_propertyType[i].m_tag == enumId) + { + type = s_propertyType[i].m_type; + break; + } + } + + DrMTagRef tag = DrMTag::MakeTyped(enumId, type); + + HRESULT status = tag->ParseProperty(reader, enumId, dataLen); + if (status == S_OK) + { + Append(tag); + } + + return status; +} \ No newline at end of file diff --git a/GraphManager/gang/DrMetaData.h b/GraphManager/gang/DrMetaData.h new file mode 100644 index 0000000..efcacd9 --- /dev/null +++ b/GraphManager/gang/DrMetaData.h @@ -0,0 +1,41 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + + +DRBASECLASS(DrMetaData), public DrPropertyParser +{ +public: + DrMetaData(); + + void Append(DrMTagPtr tag); + DrMTagPtr LookUp(UINT16 enumId); + DrMTagListPtr GetTags(); + + void Serialize(DrPropertyWriterPtr writer); + void CacheSerialization(); + virtual HRESULT ParseProperty(DrPropertyReaderPtr reader, UINT16 enumID, UINT32 dataLen); + +private: + DrMTagListRef m_tagList; + DrByteArrayRef m_cachedSerialization; +}; +DRREF(DrMetaData); \ No newline at end of file diff --git a/GraphManager/gang/DrMetaDataTag.cpp b/GraphManager/gang/DrMetaDataTag.cpp new file mode 100644 index 0000000..d708e10 --- /dev/null +++ b/GraphManager/gang/DrMetaDataTag.cpp @@ -0,0 +1,129 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include + +DrMTag::DrMTag(UINT16 tag, UINT16 type) +{ + m_tag = tag; + m_type = type; +} + +DrMTag::~DrMTag() +{ +} + +UINT16 DrMTag::GetMTag() +{ + return m_tag; +} + +UINT16 DrMTag::GetMType() +{ + return m_type; +} + +#define DRMTAGCASESTMT(_type) \ + case DrMTT_##_type: \ + return DrNew DrMTag##_type(tag); \ + break; \ + + +DrMTagRef DrMTag::MakeTyped(UINT16 tag, UINT16 type) +{ + switch (type) + { + DRMTAGCASESTMT(Void) + DRMTAGCASESTMT(Int16) + DRMTAGCASESTMT(Int32) + DRMTAGCASESTMT(Int64) + DRMTAGCASESTMT(UInt16) + DRMTAGCASESTMT(UInt32) + DRMTAGCASESTMT(UInt64) + DRMTAGCASESTMT(HRESULT) + DRMTAGCASESTMT(String) + + case DrMTT_Unknown: + default: + return DrNew DrMTagUnknown(tag, type); + } +}; + + +DrMTagUnknown::DrMTagUnknown(UINT16 tag, UINT16 originalType) : DrMTag(tag, DrMTT_Unknown) +{ + m_originalType = originalType; +} + +void DrMTagUnknown::SetData(DrByteArrayPtr data) +{ + m_data = data; +} + +DrByteArrayPtr DrMTagUnknown::GetData() +{ + return m_data; +} + +UINT16 DrMTagUnknown::GetOriginalType() +{ + return m_originalType; +} + +HRESULT DrMTagUnknown::ParseProperty(DrPropertyReaderPtr reader, UINT16 tag, UINT32 dataLen) +{ + DrAssert(tag == GetMTag()); + DrAssert(dataLen < 0x80000000); + + m_data = DrNew DrByteArray((int) dataLen); + HRESULT status; + { + DRPIN(BYTE) dst = &(m_data[0]); + status = reader->ReadNextProperty(tag, dataLen, dst); + } + if (status != S_OK) + { + m_data = DrNull; + } + return status; +} + +void DrMTagUnknown::Serialize(DrPropertyWriterPtr writer) +{ + DRPIN(BYTE) data = &(m_data[0]); + writer->WriteProperty(GetMTag(), m_data->Allocated(), data); +} + + +DrMTagVoid::DrMTagVoid(UINT16 tag) : DrMTag(tag, DrMTT_Void) +{ +} + +HRESULT DrMTagVoid::ParseProperty(DrPropertyReaderPtr reader, UINT16 tag, + UINT32 /* unused dataLen */) +{ + DrAssert(tag == GetMTag()); + return reader->ReadNextProperty(tag); +} + +void DrMTagVoid::Serialize(DrPropertyWriterPtr writer) +{ + writer->WriteProperty(GetMTag()); +} \ No newline 
at end of file diff --git a/GraphManager/gang/DrMetaDataTag.h b/GraphManager/gang/DrMetaDataTag.h new file mode 100644 index 0000000..8c473b1 --- /dev/null +++ b/GraphManager/gang/DrMetaDataTag.h @@ -0,0 +1,201 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +/* these are all the tag types used in the old Dryad: almost all of them will turn + up as DrMTagUnknown, but individual types can be added as desired */ +DRENUM(DrMTagType) +{ + DrMTT_Unknown, + DrMTT_Boolean, + DrMTT_Int16, + DrMTT_Int32, + DrMTT_Int64, + DrMTT_UInt16, + DrMTT_UInt32, + DrMTT_UInt64, + DrMTT_HexUInt16, + DrMTT_HexUInt32, + DrMTT_HexUInt64, + DrMTT_String, + DrMTT_Guid, + DrMTT_TimeStamp, + DrMTT_TimeInterval, + DrMTT_BeginTag, + DrMTT_EndTag, + DrMTT_HRESULT, + DrMTT_ExitCode, + DrMTT_Blob, + DrMTT_EnvironmentBlock, + DrMTT_PropertyList, + DrMTT_TagIdValue, + + DrMTT_AppendExtentOptions, + DrMTT_SyncOptions, + DrMTT_SyncDirectiveOptions, + DrMTT_ReadExtentOptions, + DrMTT_AppendStreamOptions, + DrMTT_EnumDirectoryOptions, + DrMTT_EnInfoBits, + DrMTT_UpdateExtentMetadataOptions, + DrMTT_StreamInfoBits, + DrMTT_ExtentInfoBits, + DrMTT_ExtentInstanceInfoBits, + + DrMTT_MetaData = 0x1000, + DrMTT_InputChannelDescription, + DrMTT_OutputChannelDescription, + DrMTT_VertexProcessStatus, + DrMTT_VertexStatus, + DrMTT_VertexCommandBlock, + + DrMTT_Void, + 
DrMTT_VertexCommand +}; + +DRDECLARECLASS(DrMTag); +DRREF(DrMTag); + +DRBASECLASS(DrMTag abstract), public DrPropertyParser +{ +public: + /* this returns a typed DrMTag. Almost all types just give DrMTagUnknown */ + static DrMTagRef MakeTyped(UINT16 tag, UINT16 type); + + UINT16 GetMTag(); + UINT16 GetMType(); + + virtual void Serialize(DrPropertyWriterPtr writer) = 0; + virtual HRESULT ParseProperty(DrPropertyReaderPtr reader, + UINT16 tag, UINT32 dataLen) = 0; + +protected: + DrMTag(UINT16 tag, UINT16 type); + virtual ~DrMTag(); + +private: + UINT16 m_tag; + UINT16 m_type; +}; + +DRCLASS(DrMTagUnknown) : public DrMTag +{ +public: + DrMTagUnknown(UINT16 tag, UINT16 originalType); + + void SetData(DrByteArrayPtr data); + DrByteArrayPtr GetData(); + UINT16 GetOriginalType(); + + virtual HRESULT ParseProperty(DrPropertyReaderPtr reader, + UINT16 tag, UINT32 dataLen) DROVERRIDE; + virtual void Serialize(DrPropertyWriterPtr writer) DROVERRIDE; + +private: + DrByteArrayRef m_data; + UINT16 m_originalType; +}; +DRREF(DrMTagUnknown); + +DRCLASS(DrMTagVoid) : public DrMTag +{ +public: + DrMTagVoid(UINT16 tag); + + virtual HRESULT ParseProperty(DrPropertyReaderPtr reader, + UINT16 tag, UINT32 dataLen) DROVERRIDE; + virtual void Serialize(DrPropertyWriterPtr writer) DROVERRIDE; +}; +DRREF(DrMTagVoid); + +template DRCLASS(DrMTagBase) : public DrMTag +{ +public: + DrMTagBase(UINT16 tag) : DrMTag(tag, _type) + { + } + + DrMTagBase(UINT16 tag, T value) : DrMTag(tag, _type) + { + m_value = value; + } + + void SetValue(T value) + { + m_value = value; + } + + T GetValue() + { + return m_value; + } + + virtual HRESULT ParseProperty(DrPropertyReaderPtr reader, + UINT16 tag, UINT32 /* unused dataLen*/) DROVERRIDE + { + DrAssert(tag == GetMTag()); + return reader->ReadNextProperty(tag, m_value); + } + + virtual void Serialize(DrPropertyWriterPtr writer) DROVERRIDE + { + writer->WriteProperty(GetMTag(), m_value); + } + +private: + T m_value; +}; + +typedef DrMTagBase DrMTagInt16; 
+DRTEMPLATE DrMTagBase; +DRREF(DrMTagInt16); + +typedef DrMTagBase DrMTagInt32; +DRTEMPLATE DrMTagBase; +DRREF(DrMTagInt32); + +typedef DrMTagBase DrMTagInt64; +DRTEMPLATE DrMTagBase; +DRREF(DrMTagInt64); + +typedef DrMTagBase DrMTagUInt16; +DRTEMPLATE DrMTagBase; +DRREF(DrMTagUInt16); + +typedef DrMTagBase DrMTagUInt32; +DRTEMPLATE DrMTagBase; +DRREF(DrMTagUInt32); + +typedef DrMTagBase DrMTagUInt64; +DRTEMPLATE DrMTagBase; +DRREF(DrMTagUInt64); + +typedef DrMTagBase DrMTagHRESULT; +DRTEMPLATE DrMTagBase; +DRREF(DrMTagHRESULT); + +typedef DrMTagBase DrMTagString; +DRTEMPLATE DrMTagBase; +DRREF(DrMTagString); + + +typedef DrArrayList DrMTagList; +DRAREF(DrMTagList,DrMTagRef); diff --git a/GraphManager/gang/DrProperties.h b/GraphManager/gang/DrProperties.h new file mode 100644 index 0000000..4801a6a --- /dev/null +++ b/GraphManager/gang/DrProperties.h @@ -0,0 +1,105 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +// property type flag +const UINT16 DrPropTypeMask = 0xc000; + +// DrPropType_Atom is a leaf property which is an element of a list or a +// set. There may be nested properties within the leaf. 
+const UINT16 DrPropType_Atom = 0x0000; + +// length type flag +const UINT16 DrPropLengthMask = 0x2000; + +// A property with DrPropLength_Short has a 1-byte length field +const UINT16 DrPropLength_Short = 0x0000; +// A property with DrPropLength_Long has a 4-byte length field +const UINT16 DrPropLength_Long = 0x2000; + +// mask for the remaining 13-bit namespace + +const UINT16 DrPropValueMask = 0x1fff; + +#define DRPROP_SHORTATOM(x_) ((x_) | DrPropType_Atom | DrPropLength_Short) +#define DRPROP_LONGATOM(x_) ((x_) | DrPropType_Atom | DrPropLength_Long) + +const UINT16 DrProp_BeginTag = DRPROP_SHORTATOM(0x1200); +const UINT16 DrProp_EndTag = DRPROP_SHORTATOM(0x1201); + +const UINT16 DrProp_ChannelState = DRPROP_SHORTATOM(0x4000); +const UINT16 DrProp_ChannelURI = DRPROP_LONGATOM(0x4003); +const UINT16 DrProp_ChannelBufferOffset = DRPROP_SHORTATOM(0x4004); +const UINT16 DrProp_ChannelTotalLength = DRPROP_SHORTATOM(0x4005); +const UINT16 DrProp_ChannelProcessedLength = DRPROP_SHORTATOM(0x4006); +const UINT16 DrProp_StreamExpireTimeWhileOpen = DRPROP_SHORTATOM(0x4007); +const UINT16 DrProp_StreamExpireTimeWhileClosed = DRPROP_SHORTATOM(0x4008); +const UINT16 DrProp_VertexState = DRPROP_SHORTATOM(0x4010); +const UINT16 DrProp_VertexErrorCode = DRPROP_SHORTATOM(0x4011); +const UINT16 DrProp_VertexId = DRPROP_SHORTATOM(0x4012); +const UINT16 DrProp_VertexVersion = DRPROP_SHORTATOM(0x4013); +const UINT16 DrProp_VertexInputChannelCount = DRPROP_SHORTATOM(0x4015); +const UINT16 DrProp_VertexOutputChannelCount = DRPROP_SHORTATOM(0x4016); +const UINT16 DrProp_VertexCommand = DRPROP_SHORTATOM(0x4017); +const UINT16 DrProp_VertexArgumentCount = DRPROP_SHORTATOM(0x4018); +const UINT16 DrProp_VertexArgument = DRPROP_LONGATOM(0x4019); +const UINT16 DrProp_VertexSerializedBlock = DRPROP_LONGATOM(0x401a); +const UINT16 DrProp_DebugBreak = DRPROP_SHORTATOM(0x401b); +const UINT16 DrProp_AssertFailure = DRPROP_LONGATOM(0x401c); +const UINT16 DrProp_CanShareWorkQueue = 
DRPROP_SHORTATOM(0x401d); +const UINT16 DrProp_VertexMaxOpenInputChannelCount = DRPROP_SHORTATOM(0x401e); +const UINT16 DrProp_VertexMaxOpenOutputChannelCount = DRPROP_SHORTATOM(0x401f); +const UINT16 DrProp_ErrorCode = DRPROP_SHORTATOM(0x4040); +const UINT16 DrProp_ErrorString = DRPROP_LONGATOM(0x4041); +const UINT16 DrProp_ItemBufferStartOffset = DRPROP_SHORTATOM(0x4042); +const UINT16 DrProp_ItemBufferEndOffset = DRPROP_SHORTATOM(0x4043); +const UINT16 DrProp_BufferLength = DRPROP_SHORTATOM(0x4044); +const UINT16 DrProp_ItemStreamStartOffset = DRPROP_SHORTATOM(0x4045); +const UINT16 DrProp_ItemStreamEndOffset = DRPROP_SHORTATOM(0x4046); +const UINT16 DrProp_ItemDataSequenceNumber = DRPROP_SHORTATOM(0x4047); +const UINT16 DrProp_ItemDeliverySequenceNumber = DRPROP_SHORTATOM(0x4048); +const UINT16 DrProp_InputPortCount = DRPROP_SHORTATOM(0x4060); +const UINT16 DrProp_OutputPortCount = DRPROP_SHORTATOM(0x4061); +const UINT16 DrProp_NumberOfVertices = DRPROP_SHORTATOM(0x4062); +const UINT16 DrProp_SourceVertex = DRPROP_SHORTATOM(0x4063); +const UINT16 DrProp_SourcePort = DRPROP_SHORTATOM(0x4064); +const UINT16 DrProp_DestinationVertex = DRPROP_SHORTATOM(0x4065); +const UINT16 DrProp_DestinationPort = DRPROP_SHORTATOM(0x4066); +const UINT16 DrProp_NumberOfEdges = DRPROP_SHORTATOM(0x4067); +const UINT16 DrProp_TryToCreateChannelPath = DRPROP_SHORTATOM(0x4068); +const UINT16 DrProp_InitialChannelWriteSize = DRPROP_SHORTATOM(0x4069); + + +const UINT16 DrTag_InputChannelDescription = 10000; +const UINT16 DrTag_OutputChannelDescription = 10001; +const UINT16 DrTag_VertexProcessStatus = 10002; +const UINT16 DrTag_VertexStatus = 10003; +const UINT16 DrTag_VertexCommand = 10004; +const UINT16 DrTag_ItemStart = 10005; +const UINT16 DrTag_ItemEnd = 10006; +const UINT16 DrTag_ChannelMetaData = 10007; +const UINT16 DrTag_VertexMetaData = 10008; +const UINT16 DrTag_ArgumentArray = 10009; +const UINT16 DrTag_VertexArray = 10010; +const UINT16 DrTag_VertexInfo = 10011; +const 
UINT16 DrTag_EdgeArray = 10012; +const UINT16 DrTag_EdgeInfo = 10013; +const UINT16 DrTag_GraphDescription = 10014; \ No newline at end of file diff --git a/GraphManager/gang/DrProperty.cpp b/GraphManager/gang/DrProperty.cpp new file mode 100644 index 0000000..1cc098f --- /dev/null +++ b/GraphManager/gang/DrProperty.cpp @@ -0,0 +1,556 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include + +DrPropertyReader::DrPropertyReader(DrByteArrayPtr byteArray) +{ + m_status = S_OK; + m_byteArray = byteArray; + m_readPtr = 0; +} + +HRESULT DrPropertyReader::SetStatus(HRESULT newStatus) +{ + if (m_status == S_OK) + { + m_status = newStatus; + } + + return m_status; +} + +HRESULT DrPropertyReader::PeekNextPropertyTag(/* out */ UINT16 *pEnumId, /* out */ UINT32 *pDataLen) +{ + if (PeekUInt16(pEnumId) != S_OK) + { + return m_status; + } + + BYTE tmp[sizeof(UINT16) + sizeof(UINT32)]; + + if (((*pEnumId) & DrPropLengthMask) == DrPropLength_Short) + { + if (PeekBytes(tmp, sizeof(UINT16) + sizeof(UINT8)) == S_OK) + { + *pDataLen = tmp[sizeof(UINT16)]; + } + } + else + { + if (PeekBytes(tmp, sizeof(UINT16) + sizeof(UINT32)) == S_OK) + { + memcpy(pDataLen, tmp+sizeof(UINT16), sizeof(UINT32)); + } + } + + return m_status; +} + +HRESULT DrPropertyReader::PeekNextAggregateTag(/* out */ UINT16 *pValue) +{ + UINT16 enumIdActual; + UINT32 dataLenActual; + UINT16 enumId = DrProp_BeginTag; + UINT32 dataLen = sizeof(UINT16); + + if (PeekNextPropertyTag(&enumIdActual, &dataLenActual) == S_OK) + { + if (enumIdActual != enumId || dataLenActual != dataLen) + { + DrLogW("Mismatched property peek %u,%u %u,%u", enumIdActual, enumId, dataLenActual, dataLen); + SetStatus(HRESULT_FROM_WIN32(ERROR_INVALID_PARAMETER)); + } + else + { + UINT32 hdrLen = sizeof(UINT16) + sizeof(UINT8); + + BYTE prop[sizeof(UINT16) + sizeof(UINT8) + sizeof(UINT16)]; + if (PeekBytes(prop, hdrLen + dataLen) == S_OK) + { + memcpy(pValue, prop + hdrLen, dataLen); + } + } + } + + return m_status; +} + +HRESULT DrPropertyReader::PeekBytes(/* out */ BYTE* pBytes, int length) +{ + if (m_status != S_OK) + { + return m_status; + } + else if (length + m_readPtr > m_byteArray->Allocated()) + { + return SetStatus(HRESULT_FROM_WIN32(ERROR_INSUFFICIENT_BUFFER)); + } + else + { + DRPIN(BYTE) srcPtr = &(m_byteArray[m_readPtr]); + memcpy(pBytes, srcPtr, length); + return S_OK; + } +} + +HRESULT 
DrPropertyReader::PeekUInt16(/* out */ UINT16 *pVal) +{ + // assumes little endian + return PeekBytes((BYTE *)(void *)pVal, sizeof(*pVal)); +} + +HRESULT DrPropertyReader::SkipBytes(int length) +{ + if (m_status != S_OK) + { + return m_status; + } + else if (length + m_readPtr > m_byteArray->Allocated()) + { + return SetStatus(HRESULT_FROM_WIN32(ERROR_INSUFFICIENT_BUFFER)); + } + else + { + m_readPtr += length; + return S_OK; + } +} + +HRESULT DrPropertyReader::ReadBytes(/* out */ BYTE* pBytes, int length) +{ + if (m_status != S_OK) + { + return m_status; + } + else if (length + m_readPtr > m_byteArray->Allocated()) + { + return SetStatus(HRESULT_FROM_WIN32(ERROR_INSUFFICIENT_BUFFER)); + } + else + { + DRPIN(BYTE) srcPtr = &(m_byteArray[m_readPtr]); + memcpy(pBytes, srcPtr, length); + m_readPtr += length; + return S_OK; + } +} + +HRESULT DrPropertyReader::ReadUInt8(/* out */ UINT8 *pVal) +{ + return ReadBytes((BYTE *)(void *)pVal, sizeof(*pVal)); +} + +HRESULT DrPropertyReader::ReadUInt16(/* out */ UINT16 *pVal) +{ + // assumes little endian + return ReadBytes((BYTE *)(void *)pVal, sizeof(*pVal)); +} + +HRESULT DrPropertyReader::ReadUInt32(/* out */ UINT32 *pVal) +{ + // assumes little endian + return ReadBytes((BYTE *)(void *)pVal, sizeof(*pVal)); +} + +HRESULT DrPropertyReader::ReadNextPropertyTag(/* out */ UINT16 *pEnumId, /* out */ UINT32 *pDataLen) +{ + if (ReadUInt16(pEnumId) != S_OK) + { + return m_status; + } + + if (((*pEnumId) & DrPropLengthMask) == DrPropLength_Short) + { + UINT8 lengthByte; + if (ReadUInt8(&lengthByte) == S_OK) + { + *pDataLen = lengthByte; + } + } + else + { + ReadUInt32(pDataLen); + } + + return m_status; +} + +HRESULT DrPropertyReader::ReadNextProperty(UINT16 enumId, UINT32 dataLen, void *pDest) +{ + UINT16 realEnumId; + UINT32 realDataLen; + + if (ReadNextPropertyTag(&realEnumId, &realDataLen) == S_OK) + { + if (realEnumId != enumId || realDataLen != dataLen) + { + DrLogW("Mismatched property read %u,%u %u,%u", realEnumId, enumId, 
realDataLen, dataLen); + SetStatus(HRESULT_FROM_WIN32(ERROR_INVALID_PARAMETER)); + } + else + { + ReadBytes((BYTE*) pDest, dataLen); + } + } + + return m_status; +} + +HRESULT DrPropertyReader::ReadNextProperty(UINT16 enumId) +{ + return ReadNextProperty(enumId, 0, NULL); +} + +#ifdef _MANAGED +#define MAKEDRPROPREADER(_type) \ + HRESULT DrPropertyReader::ReadNextProperty(UINT16 enumId, /* out */ _type %pValue) \ + { \ + _type tmp; \ + if (ReadNextProperty(enumId, sizeof(_type), &tmp) == S_OK) \ + { \ + pValue = tmp; \ + } \ + return m_status; \ + } +#else +#define MAKEDRPROPREADER(_type) \ + HRESULT DrPropertyReader::ReadNextProperty(UINT16 enumId, /* out */ _type &pValue) \ + { \ + _type tmp; \ + if (ReadNextProperty(enumId, sizeof(_type), &tmp) == S_OK) \ + { \ + pValue = tmp; \ + } \ + return m_status; \ + } +#endif + +MAKEDRPROPREADER(bool) +MAKEDRPROPREADER(INT8) +MAKEDRPROPREADER(INT16) +MAKEDRPROPREADER(INT32) +MAKEDRPROPREADER(INT64) +MAKEDRPROPREADER(UINT8) +MAKEDRPROPREADER(UINT16) +MAKEDRPROPREADER(UINT32) +MAKEDRPROPREADER(UINT64) +MAKEDRPROPREADER(HRESULT) +MAKEDRPROPREADER(float) +MAKEDRPROPREADER(double) +MAKEDRPROPREADER(GUID) + +HRESULT DrPropertyReader::ReadNextProperty(UINT16 enumId, /* out */ DrStringR pValue) +{ + UINT32 length; + UINT16 realEnumId; + + if (ReadNextPropertyTag(&realEnumId, &length) == S_OK) + { + if (realEnumId != enumId) + { + DrLogW("Mismatched string property read %u,%u", realEnumId, enumId); + SetStatus(HRESULT_FROM_WIN32(ERROR_INVALID_PARAMETER)); + } + else + { + if (length > 0) + { + DrByteArrayRef array = DrNew DrByteArray(length+1); + { + DRPIN(BYTE) dst = &(array[0]); + if (SUCCEEDED(ReadBytes(dst, length))) + { + /* ensure there's a terminator just in case */ + array[(int)length] = 0; + DrString s; + s.SetF("%s", (const char *) dst); + pValue = s; + return S_OK; + } + } + } + else + { + pValue = DrNull; + } + } + } + + return m_status; +} + +HRESULT DrPropertyReader::ReadAggregate(UINT16 desiredTagType, 
DrPropertyParserPtr parser) +{ + HRESULT err; + UINT16 beginTagType; + + if (ReadNextProperty(DrProp_BeginTag, sizeof(UINT16), &beginTagType) != S_OK) + { + return m_status; + } + + if (beginTagType != desiredTagType) + { + DrLogW("Mismatched aggregate read %u,%u", beginTagType, desiredTagType); + return SetStatus(HRESULT_FROM_WIN32(ERROR_INVALID_PARAMETER)); + } + + for (;;) + { + UINT16 propertyType; + UINT32 dataLen; + + if (PeekNextPropertyTag(&propertyType, &dataLen) != S_OK) + { + return m_status; + } + + // If we find an end tag, it must be for the begin tag we consumed + if (propertyType == DrProp_EndTag) + { + UINT16 endTagType; + + // Consume it + if (ReadNextProperty(DrProp_EndTag, endTagType) != S_OK) + { + return m_status; + } + + if (desiredTagType != endTagType) + { + DrLogW("Mismatched aggregate end read %u,%u", desiredTagType, endTagType); + return SetStatus(HRESULT_FROM_WIN32(ERROR_INVALID_PARAMETER)); + } + + // We're done + return S_OK; + } + else + { + // This could be a begin tag - it's up to the caller to call ReadAggregate() + // or SkipNextPropertyOrAggregate() + err = parser->ParseProperty(this, propertyType, dataLen); + if (err != S_OK) + { + return SetStatus(err); + } + } + } +} + +HRESULT DrPropertyReader::SkipNextProperty() +{ + UINT16 enumId; + UINT32 dataLen; + + if (ReadNextPropertyTag(&enumId, &dataLen) == S_OK) + { + SkipBytes(dataLen); + } + + return m_status; +} + +HRESULT DrPropertyReader::SkipNextPropertyOrAggregate() +{ + UINT32 dataLen; + UINT16 propertyType; + UINT16 beginTagType; + + if (PeekNextPropertyTag(&propertyType, &dataLen) != S_OK) + { + return m_status; + } + + // If it's not a begin tag, just skip the property and return + if (propertyType != DrProp_BeginTag) + { + return SkipNextProperty(); + } + + // Read the begin tag type + if (ReadNextProperty(DrProp_BeginTag, beginTagType) != S_OK) + { + return m_status; + } + + // Skip until corresponding end tag + // If another BeginTag is encountered, recurse as 
appropriate + for (;;) + { + if (PeekNextPropertyTag(&propertyType, &dataLen) != S_OK) + { + return m_status; + } + + if (propertyType == DrProp_BeginTag) + { + if (SkipNextPropertyOrAggregate() != S_OK) + { + return m_status; + } + } + else if (propertyType == DrProp_EndTag) + { + UINT16 endTagType; + + if (ReadNextProperty(DrProp_EndTag, endTagType) != S_OK) + { + return m_status; + } + + if (endTagType != beginTagType) + { + DrLogW("Mismatched aggregate matchup %u,%u", endTagType, beginTagType); + return SetStatus(HRESULT_FROM_WIN32(ERROR_INVALID_PARAMETER)); + } + + return S_OK; + } + else + { + if (SkipNextProperty() != S_OK) + { + return m_status; + } + } + } +} + + +DrPropertyWriter::DrPropertyWriter() +{ + m_buffer = DrNew DrByteArrayList(); +} + +DrByteArrayRef DrPropertyWriter::GetBuffer() +{ + DrByteArrayRef array = DrNew DrByteArray(m_buffer->Size()); + int i; + for (i=0; iSize(); ++i) + { + array[i] = m_buffer[i]; + } + return array; +} + +void DrPropertyWriter::WriteByte(BYTE b) +{ + m_buffer->Add(b); +} + +void DrPropertyWriter::WriteBytes(BYTE* pBytes, int length) +{ + int i; + for (i=0; i= 0) { + WriteBytes((BYTE *) string.GetChars(), length); + WriteByte((BYTE)0); + } +} + +void DrPropertyWriter::WriteProperty(UINT16 enumId, UINT32 dataLen, BYTE* pDest) +{ + WritePropertyTag(enumId, dataLen); + WriteBytes(pDest, dataLen); +} \ No newline at end of file diff --git a/GraphManager/gang/DrProperty.h b/GraphManager/gang/DrProperty.h new file mode 100644 index 0000000..3b7c620 --- /dev/null +++ b/GraphManager/gang/DrProperty.h @@ -0,0 +1,132 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +DRDECLARECLASS(DrPropertyReader); +DRREF(DrPropertyReader); + +DRINTERFACE(DrPropertyParser) +{ +public: + virtual HRESULT ParseProperty(DrPropertyReaderPtr reader, UINT16 enumId, UINT32 dataLength) DRABSTRACT; +}; +DRREF(DrPropertyParser); + +/* make managed/native variants of foo& or foo% for the types we need to marshal */ +DRRREF(bool); +DRRREF(INT8); +DRRREF(INT16); +DRRREF(INT32); +DRRREF(INT64); +DRRREF(UINT8); +DRRREF(UINT16); +DRRREF(UINT32); +DRRREF(UINT64); +DRRREF(HRESULT); +DRRREF(float); +DRRREF(double); +DRRREF(GUID); + +DRBASECLASS(DrPropertyReader) +{ +public: + DrPropertyReader(DrByteArrayPtr byteArray); + + HRESULT SetStatus(HRESULT newStatus); + + HRESULT PeekNextPropertyTag(/* out */ UINT16 *pEnumId, /* out */ UINT32 *pDataLen); + HRESULT PeekNextAggregateTag(/* out */ UINT16 *pValue); + + HRESULT ReadNextPropertyTag(/* out */ UINT16 *pEnumId, /* out */ UINT32 *pDataLen); + HRESULT ReadNextProperty(UINT16 enumId); + HRESULT ReadNextProperty(UINT16 enumId, /* out */ boolR pValue); + HRESULT ReadNextProperty(UINT16 enumId, /* out */ INT8R pValue); + HRESULT ReadNextProperty(UINT16 enumId, /* out */ INT16R pValue); + HRESULT ReadNextProperty(UINT16 enumId, /* out */ INT32R pValue); + HRESULT ReadNextProperty(UINT16 enumId, /* out */ INT64R pValue); + HRESULT ReadNextProperty(UINT16 enumId, /* out */ UINT8R pValue); + HRESULT ReadNextProperty(UINT16 enumId, /* out */ UINT16R pValue); + HRESULT ReadNextProperty(UINT16 enumId, /* out */ UINT32R pValue); + HRESULT 
ReadNextProperty(UINT16 enumId, /* out */ UINT64R pValue); + HRESULT ReadNextProperty(UINT16 enumId, /* out */ HRESULTR pValue); + HRESULT ReadNextProperty(UINT16 enumId, /* out */ floatR pValue); + HRESULT ReadNextProperty(UINT16 enumId, /* out */ doubleR pValue); + HRESULT ReadNextProperty(UINT16 enumId, /* out */ GUIDR pValue); + HRESULT ReadNextProperty(UINT16 enumId, /* out */ DrStringR pValue); + HRESULT ReadNextProperty(UINT16 enumId, UINT32 dataLen, /* out */ void *pDest); + HRESULT SkipNextProperty(); + + HRESULT ReadAggregate(UINT16 desiredTag, DrPropertyParserPtr parser); + HRESULT SkipNextPropertyOrAggregate(); + +private: + HRESULT PeekBytes(/* out */ BYTE *pBytes, int length); + HRESULT PeekUInt16(/* out */ UINT16 *pVal); + + HRESULT SkipBytes(int length); + + HRESULT ReadBytes(/* out */BYTE *pBytes, int length); + HRESULT ReadUInt8(/* out */ UINT8 *pVal); + HRESULT ReadUInt16(/* out */ UINT16 *pVal); + HRESULT ReadUInt32(/* out */ UINT32 *pVal); + + HRESULT m_status; + DrByteArrayRef m_byteArray; + int m_readPtr; +}; + +DRBASECLASS(DrPropertyWriter) +{ +public: + DrPropertyWriter(); + + DrByteArrayRef GetBuffer(); + + void WritePropertyTag(UINT16 enumId, UINT32 dataLen); + void WriteProperty(UINT16 enumId); + void WriteProperty(UINT16 enumId, bool value); + void WriteProperty(UINT16 enumId, INT8 value); + void WriteProperty(UINT16 enumId, INT16 value); + void WriteProperty(UINT16 enumId, INT32 value); + void WriteProperty(UINT16 enumId, INT64 value); + void WriteProperty(UINT16 enumId, UINT8 value); + void WriteProperty(UINT16 enumId, UINT16 value); + void WriteProperty(UINT16 enumId, UINT32 value); + void WriteProperty(UINT16 enumId, UINT64 value); + void WriteProperty(UINT16 enumId, float value); + void WriteProperty(UINT16 enumId, double value); + void WriteProperty(UINT16 enumId, GUID value); + void WriteProperty(UINT16 enumId, DrString value); + void WriteProperty(UINT16 enumId, UINT32 dataLen, BYTE *pDest); + + void WriteByte(BYTE b); + void 
WriteBytes(BYTE* pBytes, int length); + void WriteValue(UINT8 value); + void WriteValue(UINT16 value); + void WriteValue(UINT32 value); + +private: + void WritePropertyTagShort(UINT16 enumId, UINT8 dataLen); + void WritePropertyTagLong(UINT16 enumId, UINT32 dataLen); + + DrByteArrayListRef m_buffer; +}; +DRREF(DrPropertyWriter); diff --git a/GraphManager/graph/DrDefaultParameters.h b/GraphManager/graph/DrDefaultParameters.h new file mode 100644 index 0000000..1cf9620 --- /dev/null +++ b/GraphManager/graph/DrDefaultParameters.h @@ -0,0 +1,30 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +DRCLASS(DrDefaultParameters) +{ +public: + static DrGraphParametersRef Make(DrNativeString exeName, DrNativeString jobClass, bool enableSpeculativeDuplication); + + static DrProcessTemplateRef MakeProcessTemplate(DrNativeString exeName, DrNativeString jobClass); + static DrVertexTemplateRef MakeVertexTemplate(); +}; diff --git a/GraphManager/graph/DrFileSystem.cpp b/GraphManager/graph/DrFileSystem.cpp new file mode 100644 index 0000000..134f8b6 --- /dev/null +++ b/GraphManager/graph/DrFileSystem.cpp @@ -0,0 +1,206 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include + +DrInputStreamManager::DrInputStreamManager(DrInputStreamPtr stream, DrStageManagerPtr stage) +{ + m_stage = stage; + m_vertices = DrNew DrStorageVertexList(); + int numberOfPartitions = stream->GetNumberOfPartitions(); + + for (int i = 0; i < numberOfPartitions; ++i) + { + m_vertices->Add(DrNew DrStorageVertex(m_stage, i, stream)); + } + + SetName(m_stage->GetStageName()); +} + +void DrInputStreamManager::Discard() +{ + m_stage->Discard(); + m_stage = DrNull; + + int i; + for (i=0; iSize(); ++i) + { + m_vertices[i]->Discard(); + } + m_vertices = DrNull; +} + +DrStageManagerPtr DrInputStreamManager::GetStageManager() +{ + return m_stage; +} + +void DrInputStreamManager::SetName(DrString name) +{ + m_name = name; + + DrString nameBase; + if (m_name.GetString() == DrNull) + { + nameBase = "In"; + } + else + { + nameBase = m_name; + } + + int i; + for (i=0; iSize(); ++i) + { + DrString vertexName; + vertexName.SetF("%s[%d]", nameBase.GetChars(), i); + m_vertices[i]->SetName(vertexName); + } +} + +DrString DrInputStreamManager::GetName() +{ + return m_name; +} + +DrStorageVertexListPtr DrInputStreamManager::GetVertices() +{ + return m_vertices; +} + +DrOutputStreamManager::DrOutputStreamManager(DrOutputStreamPtr stream, DrStageManagerPtr stage) +{ + m_stream = stream; + m_stage = stage; + m_vertices = DrNew DrOutputVertexList(); + + SetName(stage->GetStageName()); +} + +void DrOutputStreamManager::Discard() +{ + m_stream = DrNull; + + m_stage->Discard(); + m_stage = DrNull; + + int i; + for 
(i=0; iSize(); ++i) + { + m_vertices[i]->Discard(); + } + m_vertices = DrNull; +} + +DrStageManagerPtr DrOutputStreamManager::GetStageManager() +{ + return m_stage; +} + +void DrOutputStreamManager::SetName(DrString name) +{ + m_name = name; + + DrString nameBase; + if (m_name.GetString() == DrNull) + { + nameBase = "Out"; + } + else + { + nameBase = m_name; + } + + int i; + for (i=0; iSize(); ++i) + { + DrString vertexName; + vertexName.SetF("%s[%d]", nameBase.GetChars(), i); + m_vertices[i]->SetName(vertexName); + } +} + +DrString DrOutputStreamManager::GetName() +{ + return m_name; +} + +void DrOutputStreamManager::SetNumberOfPartitions(int numberOfPartitions) +{ + m_vertices = DrNew DrOutputVertexList(numberOfPartitions); + + int i; + for (i=0; iAdd(v); + } + + SetName(m_name); + m_stream->SetNumberOfPartitions(numberOfPartitions); + +} + +DrOutputVertexListPtr DrOutputStreamManager::GetVertices() +{ + return m_vertices; +} + +void DrOutputStreamManager::AddDynamicSplitVertex(DrOutputVertexPtr newVertex) +{ + if (m_startedSplitting == false) + { + DrAssert(m_vertices->Size() == 1); + SetNumberOfPartitions(0); + m_startedSplitting = true; + } + + m_vertices->Add(newVertex); +} + +HRESULT DrOutputStreamManager::FinalizeSuccessfulPartitions() +{ + DrOutputPartitionArrayRef partitionArray = DrNew DrOutputPartitionArray(m_vertices->Size()); + + int i; + for (i=0; iSize(); ++i) + { + partitionArray[i] = m_vertices[i]->FinalizeVersions(); + } + + return m_stream->FinalizeSuccessfulPartitions(partitionArray); +} + +DrString DrOutputStreamManager::GetURIForWrite(int partitionIndex, int id, int version, int outputPort, + DrResourcePtr runningResource, DrMetaDataRef metaData) +{ + return m_stream->GetURIForWrite(partitionIndex, id, version, outputPort, runningResource, metaData); +} + +void DrOutputStreamManager::AbandonVersion(int partitionIndex, int id, int version, int outputPort, + DrResourcePtr runningResource) +{ + m_stream->DiscardUnusedPartition(partitionIndex, 
id, version, outputPort, runningResource); +} + +void DrOutputStreamManager::ExtendLease(DrTimeInterval lease) +{ + m_stream->ExtendLease(lease); +} \ No newline at end of file diff --git a/GraphManager/graph/DrFileSystem.h b/GraphManager/graph/DrFileSystem.h new file mode 100644 index 0000000..cd190b6 --- /dev/null +++ b/GraphManager/graph/DrFileSystem.h @@ -0,0 +1,114 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +DRBASECLASS(DrInputStream abstract), public DrIInputPartitionReader +{ +public: + virtual DrString GetStreamName() DRABSTRACT; + virtual int GetNumberOfPartitions() DRABSTRACT; + virtual DrAffinityRef GetAffinity(int partitionIndex) DRABSTRACT; + virtual DrString GetURIForRead(int partitionIndex, DrResourcePtr runningResource) DRABSTRACT; +}; +DRREF(DrInputStream); + +/* This class is needed for DrOutputStream. 
It is defined in DrVertex.h but repeated here for + convenience +DRVALUECLASS(DrOutputPartition) +{ +public: + int m_id; + int m_version; + int m_outputPort; + DrResourceRef m_resource; + UINT64 m_size; +}; +*/ + +DRBASECLASS(DrOutputStream abstract) +{ +public: + virtual void SetNumberOfPartitions(int numberOfPartitions) DRABSTRACT; + virtual DrString GetURIForWrite(int partitionIndex, int id, int version, int outputPort, + DrResourcePtr runningResource, + DrMetaDataRef metaData) DRABSTRACT; + virtual void DiscardUnusedPartition(int partitionIndex, int id, int version, int outputPort, + DrResourcePtr runningResource) DRABSTRACT; + virtual HRESULT FinalizeSuccessfulPartitions(DrOutputPartitionArrayRef partitionArray) DRABSTRACT; + virtual void ExtendLease(DrTimeInterval) DRABSTRACT; +}; +DRREF(DrOutputStream); + + +DRBASECLASS(DrInputStreamManager) +{ +public: + DrInputStreamManager(DrInputStreamPtr stream, DrStageManagerPtr stage); + + void Discard(); + + void SetName(DrString name); + DrString GetName(); + + DrStageManagerPtr GetStageManager(); + + DrStorageVertexListPtr GetVertices(); + +private: + DrString m_name; + DrStageManagerRef m_stage; + DrStorageVertexListRef m_vertices; +}; +DRREF(DrInputStreamManager); + + +DRBASECLASS(DrOutputStreamManager), public DrIOutputPartitionGenerator +{ +public: + DrOutputStreamManager(DrOutputStreamPtr stream, DrStageManagerPtr stage); + + void Discard(); + + void SetName(DrString name); + DrString GetName(); + + DrStageManagerPtr GetStageManager(); + + void SetNumberOfPartitions(int numberOfPartitions); + DrOutputVertexListPtr GetVertices(); + + /* the DrIOutputPartitionGenerator implementation */ + virtual void AddDynamicSplitVertex(DrOutputVertexPtr newVertex); + virtual HRESULT FinalizeSuccessfulPartitions(); + virtual DrString GetURIForWrite(int partitionIndex, int id, int version, int outputPort, + DrResourcePtr runningResource, DrMetaDataRef metaData); + virtual void AbandonVersion(int partitionIndex, int id, int 
version, int outputPort, + DrResourcePtr runningResource); + virtual void ExtendLease(DrTimeInterval); + +private: + DrString m_name; + DrOutputStreamRef m_stream; + DrStageManagerRef m_stage; + DrOutputVertexListRef m_vertices; + bool m_startedSplitting; +}; +DRREF(DrOutputStreamManager); diff --git a/GraphManager/graph/DrGraphExecutor.cpp b/GraphManager/graph/DrGraphExecutor.cpp new file mode 100644 index 0000000..5e4caf0 --- /dev/null +++ b/GraphManager/graph/DrGraphExecutor.cpp @@ -0,0 +1,91 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include + +DrGraphExecutor::DrGraphExecutor() +{ + m_event = ::CreateEvent(NULL, TRUE, FALSE, NULL); +} + +DrGraphExecutor::~DrGraphExecutor() +{ + ::CloseHandle(m_event); +} + +DrGraphPtr DrGraphExecutor::Initialize(DrGraphParametersPtr parameters) +{ + DrMessagePumpRef pump = DrNew DrMessagePump(8, 4); + pump->Start(); + + DrUniverseRef cluster = DrNew DrUniverse(); + + DrXComputeRef xc = DrXCompute::Create(); + if (SUCCEEDED( xc->Initialize(cluster, pump) )) + { + m_graph = DrNew DrGraph(xc, parameters); + } + + return m_graph; +} + +void DrGraphExecutor::Run() +{ + m_graph->AddListener(this); + { + DrAutoCriticalSection acs(m_graph); + m_graph->StartRunning(); + } +} + +void DrGraphExecutor::ReceiveMessage(DrErrorRef exitStatus) +{ + m_exitStatus = exitStatus; + ::SetEvent(m_event); +} + +DrErrorPtr DrGraphExecutor::Join() +{ + ::WaitForSingleObject(m_event, INFINITE); + + if (m_exitStatus && m_exitStatus->m_code != 0) + { + m_graph->GetXCompute()->CompleteProgress( m_exitStatus->m_explanation.GetChars()); + } + else + { + m_graph->GetXCompute()->CompleteProgress( "" ); + } + + DrMessagePumpRef pump = m_graph->GetXCompute()->GetMessagePump(); + + pump->Stop(); + + m_graph->CancelListener(this); + + DrUniverseRef cluster = m_graph->GetXCompute()->GetUniverse(); + cluster->Discard(); + + m_graph->Discard(); + + m_graph = DrNull; + + return m_exitStatus; +} diff --git a/GraphManager/graph/DrGraphExecutor.h b/GraphManager/graph/DrGraphExecutor.h new file mode 100644 index 0000000..cf2d6f5 --- /dev/null +++ b/GraphManager/graph/DrGraphExecutor.h @@ -0,0 +1,41 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +DRCLASS(DrGraphExecutor) : public DrCritSec, public DrErrorListener +{ +public: + DrGraphExecutor(); + ~DrGraphExecutor(); + + DrGraphPtr Initialize(DrGraphParametersPtr parameters); + + void Run(); + DrErrorPtr Join(); + + virtual void ReceiveMessage(DrErrorRef exitStatus); + +private: + DrGraphRef m_graph; + HANDLE m_event; + DrErrorRef m_exitStatus; +}; +DRREF(DrGraphExecutor); \ No newline at end of file diff --git a/GraphManager/graph/DrGraphHeaders.h b/GraphManager/graph/DrGraphHeaders.h new file mode 100644 index 0000000..cf12445 --- /dev/null +++ b/GraphManager/graph/DrGraphHeaders.h @@ -0,0 +1,27 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include + +#include +#include +#include diff --git a/GraphManager/graph/DrGraphParameters.cpp b/GraphManager/graph/DrGraphParameters.cpp new file mode 100644 index 0000000..a6fa5a7 --- /dev/null +++ b/GraphManager/graph/DrGraphParameters.cpp @@ -0,0 +1,76 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include + + +DrProcessTemplateRef DrDefaultParameters::MakeProcessTemplate(DrNativeString exeName, DrNativeString jobClass) +{ + DrProcessTemplateRef t = DrNew DrProcessTemplate(); + + t->SetCommandLineBase(DrString(exeName)); + t->SetProcessClass(DrString(jobClass)); + + t->SetFailedRetainAndLeaseGraceTime(DrTimeInterval_Minute * 3, DrTimeInterval_Minute * 2); + t->SetCompletedRetainAndLeaseGraceTime(DrTimeInterval_Minute * 3, DrTimeInterval_Minute * 2); + + t->SetTimeOutBetweenProcessEndAndVertexNotification(DrTimeInterval_Second * 30); + + return t; +} + +DrVertexTemplateRef DrDefaultParameters::MakeVertexTemplate() +{ + DrVertexTemplateRef t = DrNew DrVertexTemplate(); + + t->SetStatusBlockTime(DrTimeInterval_Second * 10); + + return t; +} + +DrGraphParametersRef DrDefaultParameters::Make(DrNativeString exeName, DrNativeString jobClass, bool enableSpeculativeDuplication) +{ + DrGraphParametersRef p = DrNew DrGraphParameters(); + + p->m_processAbortTimeOut = DrTimeInterval_Second * 30; + p->m_maxActiveFailureCount = 6; + 
+ p->m_duplicateEverythingThreshold = 10; + if(enableSpeculativeDuplication) + { + p->m_defaultOutlierThreshold = 10 * DrTimeInterval_Minute; + /* Wait until this fraction of vertices have completed before computing outlier time estimate */ + p->m_nonParametricThresholdFraction = 0.50; + } + else + { + // to disable speculative dupilcation, set minimum time to inifinite + // and require all vertices complete before calculating non-parametric threshold + p->m_defaultOutlierThreshold = DrTimeInterval_Infinite; + p->m_nonParametricThresholdFraction = 1.0; + } + + p->m_minOutlierThreshold = 10 * DrTimeInterval_Second; + + p->m_defaultProcessTemplate = MakeProcessTemplate(exeName, jobClass); + p->m_defaultVertexTemplate = MakeVertexTemplate(); + + return p; +} \ No newline at end of file diff --git a/GraphManager/jobmanager/DrHeaders.h b/GraphManager/jobmanager/DrHeaders.h new file mode 100644 index 0000000..a155f35 --- /dev/null +++ b/GraphManager/jobmanager/DrHeaders.h @@ -0,0 +1,31 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +//#define _CRTDBG_MAP_ALLOC +//#include +//#include + +#define WIN32_LEAN_AND_MEAN // Exclude rarely-used stuff from Windows headers +#include + +#include +#include \ No newline at end of file diff --git a/GraphManager/jobmanager/targetver.h b/GraphManager/jobmanager/targetver.h new file mode 100644 index 0000000..9aebc69 --- /dev/null +++ b/GraphManager/jobmanager/targetver.h @@ -0,0 +1,28 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +// Including SDKDDKVer.h defines the highest available Windows platform. + +// If you wish to build your application for a previous Windows platform, include WinSDKVer.h and +// set the _WIN32_WINNT macro to the platform you wish to support before including SDKDDKVer.h. + +#include diff --git a/GraphManager/jobmanager/version.cpp b/GraphManager/jobmanager/version.cpp new file mode 100644 index 0000000..9a08424 --- /dev/null +++ b/GraphManager/jobmanager/version.cpp @@ -0,0 +1,23 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#ifdef _MANAGED +[assembly:System::Runtime::InteropServices::ComVisible(false)]; +#endif diff --git a/GraphManager/kernel/DrCluster.cpp b/GraphManager/kernel/DrCluster.cpp new file mode 100644 index 0000000..1a7d1f2 --- /dev/null +++ b/GraphManager/kernel/DrCluster.cpp @@ -0,0 +1,338 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include "DrKernel.h" + +template DRDECLARECLASS(DrStringDictionary); + +DrResource::DrResource(DrResourceLevel level, DrString name, DrString locality, DrResourcePtr parent) +{ + m_level = level; + m_name = name; + m_locality = locality; + m_parent = parent; + m_children = DrNew DrResourceList(); + + if (m_parent != DrNull) + { + m_parent->GetChildren()->Add(this); + } +} + +void DrResource::Discard() +{ + m_parent = DrNull; + if (m_children != DrNull) + { + int i; + for (i=0; iSize(); ++i) + { + m_children[i]->Discard(); + } + } + m_children = DrNull; +} + +DrResourceLevel DrResource::GetLevel() +{ + return m_level; +} + +DrString DrResource::GetName() +{ + return m_name; +} + +DrString DrResource::GetLocality() +{ + return m_locality; +} + +DrResourcePtr DrResource::GetParent() +{ + return m_parent; +} + +DrResourceListRef DrResource::GetChildren() +{ + return m_children; +} + +bool DrResource::Contains(DrResourcePtr resource) +{ + if (resource->GetLevel() > m_level) + { + return false; + } + else if (resource->GetLevel() == m_level) + { + return (resource == this); + } + else + { + return Contains(resource->GetParent()); + } +} + + +DrUniverse::DrUniverse() +{ + m_resourceLock = DrNew DrCritSec(); + + m_resourceAtLevel = DrNew RAArray(DRL_Cluster+1); + int i; + for (i=DRL_Core; i<=DRL_Cluster; ++i) + { + m_resourceAtLevel[i] = DrNew DrResourceList(); + } + m_resource = DrNew DrResourceDictionary(); +} + +void DrUniverse::Discard() +{ + int i; + for (i=DRL_Core; i<=DRL_Cluster; ++i) + { + int j; + for (j=0; jSize(); ++j) + { + m_resourceAtLevel[i][j]->Discard(); + } + m_resourceAtLevel[i] = DrNull; + } + m_resourceAtLevel = DrNull; + m_resource = DrNull; +} + +DrCritSecPtr DrUniverse::GetResourceLock() +{ + return m_resourceLock; +} + +void DrUniverse::AddResource(DrResourcePtr resource) +{ + m_resourceAtLevel[resource->GetLevel()]->Add(resource); + m_resource->Add(resource->GetName().GetString(), resource); +} + +DrResourceListRef 
DrUniverse::GetResources(DrResourceLevel level) +{ + return m_resourceAtLevel[level]; +} + +DrResourcePtr DrUniverse::LookUpResource(DrNativeString name) +{ + return LookUpResourceInternal(DrString(name)); +} + +DrResourcePtr DrUniverse::LookUpResourceInternal(DrString name) +{ + DrResourceRef resource; + if (m_resource->TryGetValue(name.GetString(), resource)) + { + return resource; + } + else + { + return DrNull; + } +} + + +DrAffinity::DrAffinity() +{ + m_isHardConstraint = false; + m_weight = 0; + m_locality = DrNew DrResourceList(); +} + +void DrAffinity::SetHardConstraint(bool isHardConstraint) +{ + m_isHardConstraint = isHardConstraint; +} + +bool DrAffinity::GetHardConstraint() +{ + return m_isHardConstraint; +} + +void DrAffinity::SetWeight(UINT64 weight) +{ + m_weight = weight; +} + +UINT64 DrAffinity::GetWeight() +{ + return m_weight; +} + +void DrAffinity::AddLocality(DrResourcePtr locality) +{ + m_locality->Add(locality); +} + +DrResourceListRef DrAffinity::GetLocalityArray() +{ + return m_locality; +} + +DrAffinityRef DrAffinityIntersector::IntersectHardConstraints(DrAffinityPtr existingConstraints, + DrAffinityListRef newAffinities) +{ + DrAffinityRef constraints = existingConstraints; + + int i; + for (i=0; iSize(); ++i) + { + DrAffinityPtr a = newAffinities[i]; + if (a->GetHardConstraint()) + { + if (constraints == DrNull) + { + constraints = DrNew DrAffinity(); + constraints->SetHardConstraint(true); + + int j; + for (j=0; jGetLocalityArray()->Size(); ++j) + { + constraints->GetLocalityArray()->Add(a->GetLocalityArray()[j]); + } + } + else + { + int j = 0; + while (jGetLocalityArray()->Size()) + { + DrResourcePtr c = constraints->GetLocalityArray()[j]; + + int k; + for (k=0; kGetLocalityArray()->Size(); ++k) + { + DrResourcePtr r = a->GetLocalityArray()[k]; + if (c->Contains(r) || r->Contains(c)) + { + if (r->GetLevel() < c->GetLevel()) + { + /* narrow the constraint */ + constraints->GetLocalityArray()[j] = r; + } + break; + } + } + if (k == 
a->GetLocalityArray()->Size()) + { + /* we can't keep this constraint */ + constraints->GetLocalityArray()->RemoveAt(j); + } + else + { + ++j; + } + } + } + } + } + + return constraints; +} + + +DrAffinityMerger::DrAffinityMerger() +{ + m_dictionary = DrNew ResourceWeightDictionary(); +} + +void DrAffinityMerger::AccumulateWeights(DrAffinityListRef affinityList) +{ + int i; + for (i=0; iSize(); ++i) + { + AccumulateWeights(affinityList[i]); + } +} + +void DrAffinityMerger::AccumulateWeights(DrAffinityPtr affinity) +{ + DrAssert(affinity->GetHardConstraint() == false); + + DrResourceListRef list = affinity->GetLocalityArray(); + int i; + for (i=0; iSize(); ++i) + { + DrResourcePtr r = list[i]; + do + { + UINT64 weight; + if (m_dictionary->TryGetValue(r, weight)) + { + weight += affinity->GetWeight(); + m_dictionary->Replace(r, weight); + } + else + { + m_dictionary->Add(r, affinity->GetWeight()); + } + + /* also accumulate the coarser-level information */ + r = r->GetParent(); + } while (r != DrNull); + } +} + +DrAffinityListRef DrAffinityMerger::GetMergedAffinities(DrFloatArrayRef levelThreshold) +{ + DrUINT64ArrayRef levelTotal = DrNew DrUINT64Array(DRL_Cluster + 1); + + int level; + for (level = DRL_Core; level <= DRL_Cluster; ++level) + { + levelTotal[level] = 0; + } + + ResourceWeightDictionary::DrEnumerator eSum = m_dictionary->GetDrEnumerator(); + while (eSum.MoveNext()) + { + levelTotal[eSum.GetKey()->GetLevel()] += eSum.GetValue(); + } + + for (level = DRL_Core; level <= DRL_Cluster; ++level) + { + levelTotal[level] = (UINT64) ((float) levelTotal[level] * levelThreshold[level]); + } + + DrAffinityListRef l = DrNew DrAffinityList(); + + ResourceWeightDictionary::DrEnumerator eFilter = m_dictionary->GetDrEnumerator(); + while (eFilter.MoveNext()) + { + DrResourcePtr resource = eFilter.GetKey(); + UINT64 weight = eFilter.GetValue(); + level = resource->GetLevel(); + if (weight >= levelTotal[level]) + { + DrAffinityRef a = DrNew DrAffinity(); + 
a->SetWeight(weight); + a->AddLocality(resource); + l->Add(a); + } + } + + return l; +} diff --git a/GraphManager/kernel/DrCluster.h b/GraphManager/kernel/DrCluster.h new file mode 100644 index 0000000..c37ed0b --- /dev/null +++ b/GraphManager/kernel/DrCluster.h @@ -0,0 +1,148 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +DRENUM(DrResourceLevel) +{ + DRL_Core, + DRL_Socket, + DRL_Computer, + DRL_Rack, + DRL_Cluster +}; + +DRDECLARECLASS(DrResource); +DRREF(DrResource); + +typedef DrArrayList DrResourceList; +DRAREF(DrResourceList,DrResourceRef); + +DRBASECLASS(DrResource) +{ +public: + DrResource(DrResourceLevel level, DrString name, DrString locality, DrResourcePtr parent); + + void Discard(); + + DrResourceLevel GetLevel(); + DrString GetName(); + DrString GetLocality(); + DrResourcePtr GetParent(); + DrResourceListRef GetChildren(); + + bool Contains(DrResourcePtr resource); + +private: + DrResourceLevel m_level; + DrString m_name; + DrString m_locality; + DrResourcePtr m_parent; /* does not hold a reference to its parent */ + DrResourceListRef m_children; +}; + +typedef DrStringDictionary DrResourceDictionary; +DRREF(DrResourceDictionary); +/* the following exercises the template machinery to avoid a spurious compiler error */ +template DRDECLARECLASS(DrStringDictionary); + +typedef DrArray DrResourceArray; 
+DRAREF(DrResourceArray,DrResourceRef); + +DRBASECLASS(DrUniverse) +{ +public: + DrUniverse(); + + void Discard(); + + DrCritSecPtr GetResourceLock(); + + void AddResource(DrResourcePtr resource); + + DrResourceListRef GetResources(DrResourceLevel level); + DrResourcePtr LookUpResource(DrNativeString name); + DrResourcePtr LookUpResourceInternal(DrString name); + +private: + typedef DrArray RAArray; + DRAREF(RAArray,DrResourceListRef); + + DrCritSecRef m_resourceLock; + RAArrayRef m_resourceAtLevel; + DrResourceDictionaryRef m_resource; +}; +DRREF(DrUniverse); + +DRBASECLASS(DrAffinity) +{ +public: + DrAffinity(); + + void SetHardConstraint(bool isHardConstraint); + bool GetHardConstraint(); + + void SetWeight(UINT64 weight); + UINT64 GetWeight(); + + void AddLocality(DrResourcePtr locality); + DrResourceListRef GetLocalityArray(); + +private: + bool m_isHardConstraint; + UINT64 m_weight; + DrResourceListRef m_locality; +}; +DRREF(DrAffinity); + +typedef DrArray DrAffinityArray; +DRAREF(DrAffinityArray,DrAffinityRef); + +typedef DrArrayList DrAffinityList; +DRAREF(DrAffinityList,DrAffinityRef); + +typedef DrArrayList DrAffinityListList; +DRAREF(DrAffinityListList,DrAffinityListRef); + +class DrAffinityIntersector +{ +public: + static DrAffinityRef IntersectHardConstraints(DrAffinityPtr existingConstraints, + DrAffinityListRef newAffinities); +}; + +DRBASECLASS(DrAffinityMerger) +{ +public: + DrAffinityMerger(); + + void AccumulateWeights(DrAffinityPtr affinity); + void AccumulateWeights(DrAffinityListRef affinityList); + + DrAffinityListRef GetMergedAffinities(DrFloatArrayRef levelThreshold); + +private: + typedef DrDictionary ResourceWeightDictionary; + DRREF(ResourceWeightDictionary); + + ResourceWeightDictionaryRef m_dictionary; +}; +DRREF(DrAffinityMerger); + diff --git a/GraphManager/kernel/DrKernel.h b/GraphManager/kernel/DrKernel.h new file mode 100644 index 0000000..46d2496 --- /dev/null +++ b/GraphManager/kernel/DrKernel.h @@ -0,0 +1,29 @@ +/* +Copyright 
(c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include + +#include + +#include +#include +#include diff --git a/GraphManager/kernel/DrMessagePump.cpp b/GraphManager/kernel/DrMessagePump.cpp new file mode 100644 index 0000000..bf1d2b0 --- /dev/null +++ b/GraphManager/kernel/DrMessagePump.cpp @@ -0,0 +1,674 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include + +#define DRMESSAGEPUMP_CONTINUE (0) +#define DRMESSAGEPUMP_EXIT (1) + +DrMessageBase::~DrMessageBase() +{ +} + +DRCLASS(DrMessageDummy) : public DrMessageBase +{ +public: + virtual void Deliver() DROVERRIDE + { + DrLogA("Dummy message asked to deliver"); + } + + virtual DrCritSecPtr GetBaseLock() DROVERRIDE + { + DrLogA("Dummy message lock requested"); + return DrNull; + } +}; + + +DrOverlapped::~DrOverlapped() +{ +} + +HRESULT* DrOverlapped::GetOperationStatePtr() +{ + return &m_operationState; +} + + +DrMessagePump::DrMessagePump(int numWorkerThreads, + int numConcurrentThreads) +{ + m_state = MPS_Stopped; + m_numWorkerThreads = numWorkerThreads; + m_numConcurrentThreads = numConcurrentThreads; + m_completionPort = INVALID_HANDLE_VALUE; + m_threadHandle = DrNew ThreadArray(m_numWorkerThreads+1); +#ifndef _MANAGED + ThreadArrayR thArray = *m_threadHandle; + int i; + for (i=0; im_nextMessage = m_listDummy; + m_listDummy->m_prevMessage = m_listDummy; + m_listLength = 0; + + m_pendingMessages = DrNew MessageQueue(); + m_submittedOverlapped = DrNew OverlappedSet(); +} + +DrMessagePump::~DrMessagePump() +{ + DrAssert(m_completionPort == INVALID_HANDLE_VALUE); + /* free the circular references */ + m_listDummy->m_nextMessage = DrNull; + m_listDummy->m_prevMessage = DrNull; +} + +static DrDateTime GetSystemTimeStamp() +{ + union { + FILETIME ft; + DrDateTime ts; + }; + GetSystemTimeAsFileTime(&ft); + return ts; +} + +DrDateTime DrMessagePump::GetCurrentTimeStamp() +{ + return GetSystemTimeStamp(); +} + +void DrMessagePump::AddToListTail(DrMessageBasePtr message) +{ + DrAssert(message->m_nextMessage == DrNull); + DrAssert(message->m_prevMessage == DrNull); + + message->m_nextMessage = m_listDummy; + message->m_prevMessage = m_listDummy->m_prevMessage; + message->m_prevMessage->m_nextMessage = message; + message->m_nextMessage->m_prevMessage = message; + + ++m_listLength; +} + +void DrMessagePump::RemoveFromList(DrMessageBasePtr message) +{ + 
DrAssert(message != m_listDummy); + DrAssert(m_listLength > 0); + + --m_listLength; + + message->m_nextMessage->m_prevMessage = message->m_prevMessage; + message->m_prevMessage->m_nextMessage = message->m_nextMessage; + + message->m_nextMessage = DrNull; + message->m_prevMessage = DrNull; +} + +bool DrMessagePump::ListEmpty() +{ + if (m_listDummy->m_nextMessage == m_listDummy) + { + DrAssert(m_listDummy->m_prevMessage == m_listDummy); + DrAssert(m_listLength == 0); + return true; + } + else + { + DrAssert(m_listDummy->m_prevMessage != m_listDummy); + DrAssert(m_listLength > 0); + return false; + } +} + +void DrMessagePump::TimerThread() +{ + DrLogI("starting timer thread"); + + bool finished = false; + + do + { + Sleep(1000); + + { + DrAutoCriticalSection acs(this); + + DrDateTime currentTime = GetCurrentTimeStamp(); + + if (m_state == MPS_Running) + { + MessageQueue::Iter iter = m_pendingMessages->Begin(); + while (iter != m_pendingMessages->End() && iter->first <= currentTime) + { + EnQueueInternal(iter->second); + iter = m_pendingMessages->Erase(iter); + } + } + else + { + finished = true; + } + } + } while (!finished); + + DrLogI("exiting timer thread"); +} + +void DrMessagePump::ThreadMain(int threadId) +{ + DrAssert(m_completionPort != INVALID_HANDLE_VALUE); + + DrLogI("starting thread %d", threadId); + + bool finished = false; + + do + { + DWORD numBytes; + ULONG_PTR completionKey; + LPOVERLAPPED overlapped; + + BOOL retval = ::GetQueuedCompletionStatus(m_completionPort, + &numBytes, + &completionKey, + &overlapped, + INFINITE); + + bool mustDecrementCount = false; + if (retval != 0) + { + if (overlapped == NULL) + { + /* This is a queue wakeup event */ + mustDecrementCount = true; + finished = (numBytes == DRMESSAGEPUMP_EXIT); + + if (finished) + { + DrLogI("received shutdown event"); + } + else + { + //DrLogI("Received queued wakeup"); + } + } + else + { + /* This is an async completion event from xcompute */ + DrAssert(numBytes == 0); + 
DrAssert(completionKey == NULL); + DrOverlapped* messageWrapper = (DrOverlapped *) overlapped; + + { + DrAutoCriticalSection acs(this); +#ifdef _MANAGED + System::IntPtr messagePtr(messageWrapper); + bool removed = m_submittedOverlapped->Remove(messagePtr); + DrAssert(removed); +#else + bool removed = m_submittedOverlapped->Remove(messageWrapper); + DrAssert(removed); +#endif + } + + messageWrapper->Process(); + delete messageWrapper; + } + } + else + { + DWORD errCode = GetLastError(); + DrLogA("error code", "%d", errCode); + } + + bool foundMessage = false; + do + { + DrMessageBaseRef message = DrNull; + + { + DrAutoCriticalSection acs(this); + + message = m_listDummy->m_nextMessage; + + if (finished) + { + /* Received a shutdown message - verify that the message queue is now empty */ + DrAssert(message == m_listDummy); + } + + foundMessage = false; + while (!foundMessage && message != m_listDummy) + { + int i; + /* Check whether another thread is holding the same + lock that this message wants to acquire. If so, skip + this message so we don't block waiting to acquire + the lock. 
*/ + for (i=0; iGetBaseLock()) + { + message = message->m_nextMessage; + break; + } + } + + if (i == m_numWorkerThreads) + { + /* Found a message - no other thread is holding the lock */ + foundMessage = true; + RemoveFromList(message); + m_currentListener[threadId] = message->GetBaseLock(); + } + } + + if (!foundMessage && mustDecrementCount) + { + mustDecrementCount = false; + DrAssert(m_numQueuedWakeUps > 0); + --m_numQueuedWakeUps; + } + } + + if (foundMessage) + { + /* this acquires the lock and sends the message */ + message->Deliver(); + + { + DrAutoCriticalSection acs(this); + + if (m_state == MPS_Stopping) + { + /* If the message pump is stopping - verify that there + are no more messages left on the queue */ + DrAssert(m_listLength == 0); + } + + m_currentListener[threadId] = DrNull; + + /* If we didn't receive a shutdown message, check whether + there are any more messages in the queue, and + wake up any free threads to help process them */ + if (!finished) + { + int numberOfSpareMessages = m_listLength; + int numberOfFreeThreads = 0; + int i; + for (i=0; i 0; ++i) + { + if (m_currentListener[i] == DrNull) + { + ++numberOfFreeThreads; + --numberOfSpareMessages; + } + } + + /* we are free by construction: if anyone else is, wake them up */ + for (i=m_numQueuedWakeUps; iStart((int) i); + } + + m_threadHandle[i] = DrNew System::Threading::Thread( + DrNew System::Threading::ThreadStart(this, &DrMessagePump::TimerFunc)); + m_threadHandle[i]->Start(); +} + +void DrMessagePump::WaitForThreads() +{ + int i; + for (i=0; iJoin(); + } +} + +#else + +#include + +struct threadBlock +{ + DrMessagePumpRef m_pump; + int m_threadId; +}; + +unsigned __stdcall DrMessagePump::TimerFunc(void* arg) +{ + threadBlock* tb = (threadBlock *) arg; + tb->m_pump->TimerThread(); + delete tb; + return 0; +} + +unsigned __stdcall DrMessagePump::ThreadFunc(void* arg) +{ + threadBlock* tb = (threadBlock *) arg; + tb->m_pump->ThreadMain(tb->m_threadId); + delete tb; + return 0; +} + +void 
DrMessagePump::StartThreads() +{ + ThreadArrayR thArray = *m_threadHandle; + threadBlock* tb; + unsigned threadAddr; + int i; + for (i=0; im_pump = this; + tb->m_threadId = i; + thArray[i] = + (HANDLE) ::_beginthreadex(NULL, + 0, + DrMessagePump::ThreadFunc, + tb, + 0, + &threadAddr); + DrAssert(thArray[i] != 0); + } + + tb = new threadBlock; + tb->m_pump = this; + tb->m_threadId = -1; + thArray[i] = + (HANDLE) ::_beginthreadex(NULL, + 0, + DrMessagePump::TimerFunc, + tb, + 0, + &threadAddr); + DrAssert(thArray[i] != 0); +} + +void DrMessagePump::WaitForThreads() +{ + ThreadArrayR thArray = *m_threadHandle; + DWORD waitRet = ::WaitForMultipleObjects(m_numWorkerThreads + 1, + &(thArray[0]), + TRUE, + INFINITE); + DrAssert(waitRet < (WAIT_OBJECT_0 + m_numWorkerThreads + 1)); + + { + DrAutoCriticalSection acs(this); + BOOL bRetval; + + int i; + for (i=0; im_nextMessage); + } + + m_numQueuedWakeUps += m_numWorkerThreads; + } + + int i; + for (i=0; iGetDrEnumerator(); + while (e.MoveNext()) + { + DrOverlapped* element; +#ifdef _MANAGED + element = (DrOverlapped*) e.GetElement().ToPointer(); +#else + element = e.GetElement(); +#endif + element->Discard(); + } + m_submittedOverlapped = DrNew OverlappedSet(); + + DrAssert(ListEmpty()); + + m_pendingMessages = DrNew MessageQueue(); + + m_completionPort = INVALID_HANDLE_VALUE; + m_state = MPS_Stopped; + } + + DrLogI("exiting"); +} + +void DrMessagePump::EnQueueInternal(DrMessageBasePtr message) +{ + AddToListTail(message); + + if (m_numQueuedWakeUps < m_numWorkerThreads) + { + ++m_numQueuedWakeUps; + BOOL retval = ::PostQueuedCompletionStatus(m_completionPort, + DRMESSAGEPUMP_CONTINUE, + NULL, + NULL); + if (retval == 0) + { + DWORD errCode = GetLastError(); + DrLogA("post completion status", "error code: %d", errCode); + } + } +} + +bool DrMessagePump::EnQueue(DrMessageBasePtr message) +{ + { + DrAutoCriticalSection acs(this); + + if (m_state == MPS_Stopping) + { + DrLogI("rejecting stopping item"); + return false; + } + 
else + { + DrAssert(m_state == MPS_Running); + } + + EnQueueInternal(message); + } + + return true; +} + +bool DrMessagePump::EnQueueDelayed(DrTimeInterval delay, DrMessageBasePtr message) +{ + DrAssert(delay > 0); + + { + DrAutoCriticalSection acs(this); + + DrDateTime currentTime = GetCurrentTimeStamp(); + + if (m_state == MPS_Stopping) + { + DrLogI("rejecting stopping item"); + return false; + } + else + { + DrAssert(m_state == MPS_Running); + } + + m_pendingMessages->Insert(currentTime + delay, message); + } + + return true; +} + +void DrMessagePump::NotifySubmissionToCompletionPort(DrOverlapped* overlapped) +{ + + DrAutoCriticalSection acs(this); + +#ifdef _MANAGED + System::IntPtr ptr(overlapped); + m_submittedOverlapped->Add(ptr); +#else + m_submittedOverlapped->Add(overlapped); +#endif +} \ No newline at end of file diff --git a/GraphManager/kernel/DrMessagePump.h b/GraphManager/kernel/DrMessagePump.h new file mode 100644 index 0000000..332b1e2 --- /dev/null +++ b/GraphManager/kernel/DrMessagePump.h @@ -0,0 +1,317 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +class WorkQueue; + +DRDECLARECLASS(DrMessageBase); +DRREF(DrMessageBase); + +DRBASECLASS(DrMessageBase abstract) +{ +public: + virtual ~DrMessageBase(); + virtual void Deliver() = 0; + virtual DrCritSecPtr GetBaseLock() = 0; + + DrMessageBaseRef m_nextMessage; + DrMessageBaseRef m_prevMessage; +}; + +template DRINTERFACE(DrListener) +{ +public: + virtual void ReceiveMessage(Notification message) DRABSTRACT; +}; + +template DRCLASS(DrMessage) : public DrMessageBase +{ + typedef DrListener Listener; + DRIREF(Listener); + +public: + DrMessage(ListenerPtr listener, Notification notification) + { + DrAssert(listener != DrNull); + DrICritSecPtr lCritSec = dynamic_cast(listener); + if (lCritSec == DrNull) + { + DrLogA("All listeners must inherit from DrICritSec"); + } + m_listener = lCritSec; + m_payload = notification; + } + + DrMessage(DrLockBox listener, Notification notification) + { + /* we are going to use this forbidden interface, because nobody actually gets + the listener out of here again except the pump, and that only delivers + it which takes place under the lock */ + m_listener = listener.DoNotUse(); + m_payload = notification; + } + + virtual ~DrMessage() + { + } + + virtual DrCritSecPtr GetBaseLock() DROVERRIDE + { + return m_listener->GetBaseLock(); + } + + virtual void Deliver() DROVERRIDE + { + DrAutoCriticalSection acs(m_listener); + + ListenerPtr listener = dynamic_cast((DrICritSecPtr) m_listener); + DrAssert(listener != DrNull); + + listener->ReceiveMessage(m_payload); + } + + void SetPayload(Notification payload) + { + m_payload = payload; + } + + Notification GetPayload() + { + return m_payload; + } + +private: + DrICritSecIRef m_listener; + Notification m_payload; +}; + +class DrOverlapped : public OVERLAPPED +{ +public: + virtual ~DrOverlapped(); + + HRESULT* GetOperationStatePtr(); + + virtual void Process() = 0; + virtual void Discard() = 0; + +private: + HRESULT m_operationState; +}; + +DRENUM(MessagePumpState) +{ + 
MPS_Stopped, + MPS_Running, + MPS_Stopping +}; + +/* stop spurious compiler warning by exercising template machinery */ +template ref class DrMultiMap; + +DRCLASS(DrMessagePump) : public DrCritSec +{ + public: + DrMessagePump(int numWorkerThreads, + int numConcurrentThreads); + ~DrMessagePump(); + + void Start(); + void Stop(); + + DrDateTime GetCurrentTimeStamp(); + + bool EnQueue(DrMessageBasePtr request); + bool EnQueueDelayed(DrTimeInterval delay, DrMessageBasePtr request); + + HANDLE GetCompletionPort(); + void NotifySubmissionToCompletionPort(DrOverlapped* overlapped); + +private: + typedef DrArray CSArray; + DRAREF(CSArray,DrCritSecRef); + typedef DrMultiMap MessageQueue; + DRREF(MessageQueue); + +#ifdef _MANAGED + typedef DrSet OverlappedSet; + DRREF(OverlappedSet); + + + typedef DrArray ThreadArray; + DRAREF(ThreadArray,System::Threading::Thread^); + void TimerFunc(); + void ThreadFunc(Object^ parameter); + +#else + typedef DrSet OverlappedSet; + DRREF(OverlappedSet); + + /* stop spurious compiler warning by exercising template machinery */ + template class DrMultiMap; + typedef DrArray ThreadArray; + DRAREF(ThreadArray,HANDLE); + static unsigned __stdcall TimerFunc(void* parameter); + static unsigned __stdcall ThreadFunc(void* parameter); + +#endif + + void AddToListTail(DrMessageBasePtr message); + void RemoveFromList(DrMessageBasePtr message); + bool ListEmpty(); + + void StartThreads(); + void WaitForThreads(); + + void TimerThread(); + void ThreadMain(int threadId); + + void EnQueueInternal(DrMessageBasePtr message); + + DrMessageBaseRef m_listDummy; /* list of messages */ + int m_listLength; + MessageQueueRef m_pendingMessages; /* list of delayed messages */ + OverlappedSetRef m_submittedOverlapped; /* set of overlapped objects in the completion port */ + + MessagePumpState m_state; + + int m_numWorkerThreads; + int m_numConcurrentThreads; + HANDLE m_completionPort; + ThreadArrayRef m_threadHandle; + CSArrayRef m_currentListener; + int 
m_numQueuedWakeUps; +}; +DRREF(DrMessagePump); + +template DRCLASS(DrNotifier) : public DrSharedCritSec +{ + typedef DrMessage Message; + DRREF(Message); + typedef DrListener Listener; + DRIREF(Listener); + typedef DrArrayList LArray; + DRAREF(LArray,ListenerIRef); + +public: + DrNotifier(DrMessagePumpPtr pump) : DrSharedCritSec(DrNew DrCritSec()) + { + m_pump = pump; + m_listener = DrNew LArray(); + } + + DrNotifier(DrMessagePumpPtr pump, DrICritSecPtr cs) : DrSharedCritSec(cs) + { + m_pump = pump; + m_listener = DrNew LArray(); + } + + bool AddListener(ListenerPtr listener) + { + int nListeners = m_listener->Size(); + int i; + for (i=0; iAdd(listener); + + return true; + } + + bool CancelListener(ListenerPtr listener) + { + return m_listener->Remove(listener); + } + +protected: + void DeliverNotification(Notification notification) + { + int nListeners = m_listener->Size(); + int i; + for (i=0; iEnQueue(message); + } + } + + void DeliverDelayedNotification(DrTimeInterval delay, Notification notification) + { + int nListeners = m_listener->Size(); + int i; + for (i=0; iEnQueueDelayed(delay, message); + } + } + + void DeliverMessage(DrMessageBasePtr message) + { + m_pump->EnQueue(message); + } + + void DeliverDelayedMessage(DrTimeInterval delay, DrMessageBasePtr message) + { + m_pump->EnQueueDelayed(delay, message); + } + +private: + DrMessagePumpRef m_pump; + LArrayRef m_listener; +}; + + + +typedef DrListener DrErrorListener; +DRIREF(DrErrorListener); + +typedef DrMessage DrErrorMessage; +DRREF(DrErrorMessage); + +typedef DrNotifier DrErrorNotifier; +DRREF(DrErrorNotifier); + +/* Output stream lease extension */ +typedef bool DrLeaseExtender; + +typedef DrListener DrLeaseListener; +DRIREF(DrLeaseListener); + +typedef DrMessage DrLeaseMessage; +DRREF(DrLeaseMessage); + +/* Graph shutdown */ +typedef HRESULT DrExitStatus; + +typedef DrListener DrShutdownListener; +DRIREF(DrShutdownListener); + +typedef DrMessage DrShutdownMessage; +DRREF(DrShutdownMessage); diff 
--git a/GraphManager/kernel/DrProcess.cpp b/GraphManager/kernel/DrProcess.cpp new file mode 100644 index 0000000..1f25280 --- /dev/null +++ b/GraphManager/kernel/DrProcess.cpp @@ -0,0 +1,507 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include + +DrProcessHandle::~DrProcessHandle() +{ +} + +void DrProcessHandle::SetAssignedNode(DrResourcePtr node) +{ + m_node = node; +} + +DrResourcePtr DrProcessHandle::GetAssignedNode() +{ + return m_node; +} + + +DrProcessTemplate::DrProcessTemplate() +{ + m_failedRetainTime = 0; + m_failedLeaseGraceTime = 0; + m_completedRetainTime = 0; + m_completedLeaseGraceTime = 0; + m_maxMemory = MAX_UINT64; + m_timeOutBetweenProcessEndAndVertexNotification = 0; + m_affinityLevelThresholds = DrNew DrFloatArray(DRL_Cluster + 1); + int i; + for (i=0; i<=DRL_Cluster; ++i) + { + /* default threshold is keep all affinities */ + m_affinityLevelThresholds[i] = 0.0f; + } + m_listenerList = DrNew DrProcessListenerIRefList(); +} + +void DrProcessTemplate::SetCommandLineBase(DrString commandLine) +{ + m_commandLineBase = commandLine; +} + +DrString DrProcessTemplate::GetCommandLineBase() +{ + return m_commandLineBase; +} + +DrProcessListenerIRefListPtr DrProcessTemplate::GetListenerList() +{ + return m_listenerList; +} + +void DrProcessTemplate::SetProcessClass(DrString processClass) +{ + m_processClass = processClass; +} + 
+DrString DrProcessTemplate::GetProcessClass() +{ + return m_processClass; +} + +void DrProcessTemplate::SetFailedRetainAndLeaseGraceTime(DrTimeInterval retainTime, + DrTimeInterval leaseGraceTime) +{ + DrAssert(retainTime > leaseGraceTime); + m_failedRetainTime = retainTime; + m_failedLeaseGraceTime = leaseGraceTime; +} + +DrTimeInterval DrProcessTemplate::GetFailedRetainTime() +{ + return m_failedRetainTime; +} + +DrTimeInterval DrProcessTemplate::GetFailedLeaseWaitTime() +{ + return m_failedRetainTime - m_failedLeaseGraceTime; +} + +void DrProcessTemplate::SetCompletedRetainAndLeaseGraceTime(DrTimeInterval retainTime, + DrTimeInterval leaseGraceTime) +{ + DrAssert(retainTime > leaseGraceTime); + m_completedRetainTime = retainTime; + m_completedLeaseGraceTime = leaseGraceTime; +} + +DrTimeInterval DrProcessTemplate::GetCompletedRetainTime() +{ + return m_completedRetainTime; +} + +DrTimeInterval DrProcessTemplate::GetCompletedLeaseWaitTime() +{ + return m_completedRetainTime - m_completedLeaseGraceTime; +} + +void DrProcessTemplate::SetMaxMemory(UINT64 maxMemory) +{ + m_maxMemory = maxMemory; +} + +UINT64 DrProcessTemplate::GetMaxMemory() +{ + return m_maxMemory; +} + +void DrProcessTemplate::SetTimeOutBetweenProcessEndAndVertexNotification(DrTimeInterval timeOut) +{ + m_timeOutBetweenProcessEndAndVertexNotification = timeOut; +} + +DrTimeInterval DrProcessTemplate::GetTimeOutBetweenProcessEndAndVertexNotification() +{ + return m_timeOutBetweenProcessEndAndVertexNotification; +} + +DrFloatArrayPtr DrProcessTemplate::GetAffinityLevelThresholds() +{ + return m_affinityLevelThresholds; +} + + +DrProcessStateRecord::DrProcessStateRecord() +{ + m_state = DPS_NotStarted; +} + +DrProcessStateRecordRef DrProcessStateRecord::Clone() +{ + DrProcessStateRecordRef r = DrNew DrProcessStateRecord(); + + r->m_state = m_state; + r->m_process = m_process; + r->m_exitCode = m_exitCode; + r->m_status = m_status; + + return r; +} + + +DrProcessStats::DrProcessStats() +{ + m_exitCode 
= STILL_ACTIVE; + m_pid = 0; + m_createdTime = DrDateTime_Never; + m_beginExecutionTime = DrDateTime_Never; + m_terminatedTime = DrDateTime_Never; + m_userTime = 0; + m_kernelTime = 0; + m_pageFaults = 0; + m_peakVMUsage = 0; + m_peakMemUsage = 0; + m_memUsageSeconds = 0; + m_totalIO = 0; +} + +bool DrProcessStats::Different(DrProcessStatsPtr other) +{ + return + m_exitCode != other->m_exitCode || + m_pid != other->m_pid || + m_createdTime != other->m_createdTime || + m_beginExecutionTime != other->m_beginExecutionTime || + m_terminatedTime != other->m_terminatedTime || + m_userTime != other->m_userTime || + m_kernelTime != other->m_kernelTime || + m_pageFaults != other->m_pageFaults || + m_peakVMUsage != other->m_peakMemUsage || + m_peakMemUsage != other->m_peakMemUsage || + m_memUsageSeconds != other->m_memUsageSeconds || + m_totalIO != other->m_totalIO; +} + +DrProcessStatsRef DrProcessStats::Clone() +{ + DrProcessStatsRef ps = DrNew DrProcessStats(); + + ps->m_exitCode = m_exitCode; + ps->m_pid = m_pid; + ps->m_createdTime = m_createdTime; + ps->m_beginExecutionTime = m_beginExecutionTime; + ps->m_terminatedTime = m_terminatedTime; + ps->m_userTime = m_userTime; + ps->m_kernelTime = m_kernelTime; + ps->m_pageFaults = m_pageFaults; + ps->m_peakVMUsage = m_peakVMUsage; + ps->m_peakMemUsage = m_peakMemUsage; + ps->m_memUsageSeconds = m_memUsageSeconds; + ps->m_totalIO = m_totalIO; + + return ps; +} + + +DrPropertyStatus::DrPropertyStatus(DrProcessBasicState state, UINT32 exitCode, DrErrorPtr error) +{ + m_processState = state; + m_exitCode = exitCode; + m_status = error; + m_statusVersion = 0; +} + + +DrProcess::DrProcess(DrXComputePtr xc, DrString name, DrString commandLine, + DrProcessTemplatePtr processTemplate) + : DrNotifier(xc->GetMessagePump()) +{ + m_xc = xc; + m_name = name; + m_commandLine = commandLine; + m_template = processTemplate; + + m_affinity = DrNew DrAffinityList(); + + m_info = DrNew DrProcessInfo(); + m_info->m_process = DrNull; /* don't 
create a circular reference */ + m_info->m_state = DrNew DrProcessStateRecord(); + m_info->m_statistics = DrNew DrProcessStats(); + + m_info->m_jmProcessCreatedTime = m_xc->GetCurrentTimeStamp(); + m_info->m_jmProcessScheduledTime = DrDateTime_Never; + + m_hasEverRequestedProperty = false; + + DrProcessListenerIRefListRef listeners = processTemplate->GetListenerList(); + int i; + for (i=0; iSize(); ++i) + { + AddListener(listeners[i]); + } +} + +void DrProcess::SetAffinityList(DrAffinityListPtr list) +{ + m_affinity = list; +} + +DrAffinityListPtr DrProcess::GetAffinityList() +{ + return m_affinity; +} + +DrString DrProcess::GetName() +{ + return m_name; +} + +DrProcessInfoPtr DrProcess::GetInfo() +{ + return m_info; +} + +void DrProcess::CloneAndDeliverNotification(bool delay) +{ + DrProcessInfoRef info = DrNew DrProcessInfo(); + + info->m_process = this; + info->m_state = m_info->m_state->Clone(); + info->m_statistics = m_info->m_statistics->Clone(); + info->m_jmProcessCreatedTime = m_info->m_jmProcessCreatedTime; + info->m_jmProcessScheduledTime = m_info->m_jmProcessScheduledTime; + + if (delay) + { + DeliverDelayedNotification(m_template->GetTimeOutBetweenProcessEndAndVertexNotification(), info); + } + else + { + DeliverNotification(info); + } +} + +void DrProcess::Schedule() +{ + DrAssert(m_info->m_state->m_state == DPS_NotStarted); + + m_info->m_state->m_state = DPS_Initializing; + m_info->m_jmProcessScheduledTime = m_xc->GetCurrentTimeStamp(); + + m_xc->ScheduleProcess(m_affinity, m_name, m_commandLine, m_template, this); +} + +void DrProcess::RequestProperty(UINT64 lastSeenVersion, DrString propertyName, + DrTimeInterval maxBlockTime, DrPropertyListenerPtr listener) +{ + DrAssert(m_info->m_state->m_state > DPS_Scheduling); + + if (m_info->m_state->m_state < DPS_Zombie) + { + DrAssert(m_info->m_state->m_process != DrNull); + /* once the higher level has ever requested a property we are going to assume it continues to do so. 
+ this will delay raw process end messages to the listeners to try to order them after the end + messages that return with property fetches */ + m_hasEverRequestedProperty = true; + + m_xc->GetProcessProperty(m_info->m_state->m_process, lastSeenVersion, propertyName, + maxBlockTime, this, listener); + } + else + { + /* there's no process to get the property from any more, so send a dummy dead version */ + DrString reason = "Process already in zombie state"; + DrErrorRef error = DrNew DrError(DrError_Unexpected, "DrProcess", reason); + DrPropertyStatusRef status = DrNew DrPropertyStatus(DPBS_Failed, 1, error); + DrPropertyMessageRef message = DrNew DrPropertyMessage(listener, status); + DeliverMessage(message); + } +} + +void DrProcess::SendCommand(UINT64 newVersion, DrString propertyName, + DrString propertyDescription, DrByteArrayPtr propertyBlock) +{ + DrAssert(m_info->m_state->m_state > DPS_Scheduling); + + if (m_info->m_state->m_state < DPS_Failed) + { + DrAssert(m_info->m_state->m_process != DrNull); + m_xc->SetProcessCommand(m_info->m_state->m_process, newVersion, propertyName, + propertyDescription, propertyBlock, this); + } +} + +void DrProcess::Terminate() +{ + if (m_info->m_state->m_process == DrNull) + { + /* this should be an extremely rare race: the process has been requested, + but XCompute has not yet delivered the message with the process handle. 
+ We will delay another 10 seconds and then try to terminate again: we + can't just give up, because otherwise the process, when it does get + through the XCompute machinery, would be orphaned and would sit there + consuming a cluster resource: we actually do want to call CancelScheduleProcess + on it eventually */ + + DrLogI("Process %s has not yet received its handle: rescheduling termination", m_name.GetChars()); + DrPStateMessageRef message = DrNew DrPStateMessage(this, DPS_Failed); + DeliverDelayedMessage(DrTimeInterval_Second * 10, message); + return; + } + + DrAssert(m_info->m_state->m_state > DPS_NotStarted); + + DrAssert(m_info->m_state->m_process != DrNull); + + /* This method is also called by completed processes in order to close + the process handle. Don't bother cancelling already completed processes. */ + if (m_info->m_state->m_state <= DPS_Running) + { + m_xc->CancelScheduleProcess(m_info->m_state->m_process); + } + m_info->m_state->m_process->CloseHandle(); +} + +/* this is called whenever a state change message returns */ +void DrProcess::ReceiveMessage(DrProcessStateRecordRef message) +{ + DrAssert(m_info->m_state->m_state > DPS_NotStarted); + + DrProcessState oldState = m_info->m_state->m_state; + + if (oldState == DPS_Zombie) + { + DrLogI("Process %s ignoring message while in zombie state", m_name.GetChars()); + /* this process is being discarded so don't do anything */ + return; + } + + if (message->m_state < oldState) + { + /* this shouldn't happen */ + DrLogW("Process %s ignoring message because state is retreating old %d new %d", + m_name.GetChars(), oldState, message->m_state); + return; + } + + if (m_info->m_state->m_state == DPS_Initializing) + { + DrAssert(m_info->m_state->m_process == DrNull); + } + else + { + DrAssert(m_info->m_state->m_process == message->m_process); + } + + m_info->m_state = message; + + DrString errorText = DrError::ToShortText(message->m_status); + DrLogI("Process %s in state %d message with state %d status %s", + 
m_name.GetChars(), oldState, message->m_state, errorText.GetChars()); + + /* make sure the listeners see an orderly state machine transition. The rule is that we go in an orderly + sequence from Initializing=>Completed, except we can jump to Failed or Zombie at any time. From Failed + we can only move to Zombie, and from Zombie nowhere */ + bool delayForProperty; + DrAssert(oldState > DPS_NotStarted); + while (m_info->m_state->m_state < DPS_Failed && + (int) (m_info->m_state->m_state) > ((int) oldState + 1)) + { + DrProcessState savedState = m_info->m_state->m_state; + + /* insert the phantom intermediate state to the reported state sequence. If this update reports + a termination and the listeners have started their property-fetch state machine then we will + delay delivery a little to give them a chance to fetch the property reporting termination in + more detail */ + m_info->m_state->m_state = (DrProcessState) ((int) oldState + 1); + delayForProperty = (m_info->m_state->m_state > DPS_Running) && m_hasEverRequestedProperty; + CloneAndDeliverNotification(delayForProperty); + + /* then fix things up for the next iteration of this loop */ + oldState = m_info->m_state->m_state; + m_info->m_state->m_state = savedState; + } + + /* tell everyone that is listening that something has happened. If this update reports a termination + and the listeners have started their property-fetch state machine then we will delay delivery a + little to give them a chance to fetch the property reporting termination in more detail */ + delayForProperty = (m_info->m_state->m_state > DPS_Running) && m_hasEverRequestedProperty; + CloneAndDeliverNotification(delayForProperty); + + if (m_info->m_state->m_state < DPS_Starting || m_info->m_state->m_state > DPS_Running) + { + /* we are still waiting to get assigned somewhere or the process has finished: do nothing. + In the future we may need to think about leases for working directories. 
*/ + return; + } + + if (m_info->m_state->m_state < DPS_Running) + { + m_xc->WaitUntilStart(m_info->m_state->m_process, this); + return; + } + else + { + m_xc->WaitUntilCompleted(m_info->m_state->m_process, this); + } +} + +void DrProcess::ReceiveMessage(DrProcessPropertyStatusRef message) +{ + /* forward the property info to the listener that asked for it */ + DeliverMessage(message->m_message); + + if (message->m_statistics != DrNull && + (m_info->m_statistics == DrNull || + m_info->m_statistics->Different(message->m_statistics))) + { + /* we've got updated statistics about the process: tell the world. Since this is + guaranteed to be delivered to the graph after the message, we don't need to + delay (the cohort and vertex share a lock, so their deliveries are serialized) */ + m_info->m_statistics = message->m_statistics; + CloneAndDeliverNotification(false); + } +} + +void DrProcess::ReceiveMessage(DrErrorRef message) +{ + if (message != DrNull) + { + if (SUCCEEDED(message->m_code)) + { + DrString errorText = DrError::ToShortText(message); + DrLogW("Command received non-error status %s", errorText.GetChars()); + } + else + { + /* if we failed to send a message then we assume the process is in an unknown state + and tell our listeners to abandon it */ + m_info->m_state->m_state = DPS_Zombie; + DrString reason; + reason.SetF("XCompute command send failed with error %s", DRERRORSTRING(message->m_code)); + m_info->m_state->m_status = DrNew DrError(DrError_XComputeError, "DrProcess", reason); + m_info->m_state->m_status->AddProvenance(message); + CloneAndDeliverNotification(false); + } + } +} + +void DrProcess::ReceiveMessage(DrProcessState message) +{ + if (message == DPS_Failed) + { + /* This is typically the result of a vertex completing with it's cohort process still in the + running state, which is perfectly normal. In that case, the cohort sends us a delayed message + to terminate, so that cluster resources used by the process can be cleaned up. 
*/ + Terminate(); + } +} \ No newline at end of file diff --git a/GraphManager/kernel/DrProcess.h b/GraphManager/kernel/DrProcess.h new file mode 100644 index 0000000..b2e3d21 --- /dev/null +++ b/GraphManager/kernel/DrProcess.h @@ -0,0 +1,274 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +DRDECLARECLASS(DrXCompute); +DRREF(DrXCompute); + +// DrProcessState is also in Java class DryadAppMaster +DRENUM(DrProcessState) +{ + DPS_NotStarted, + DPS_Initializing, + DPS_Scheduling, + DPS_Starting, + DPS_Running, + DPS_Completed, + DPS_Failed, + DPS_Zombie +}; + +DRENUM(DrProcessBasicState) +{ + DPBS_NotStarted, + DPBS_Running, + DPBS_Completed, + DPBS_Failed +}; + +DRBASECLASS(DrProcessHandle abstract) +{ +public: + virtual ~DrProcessHandle(); + + virtual void CloseHandle() DRABSTRACT; + virtual DrString GetHandleIdAsString() DRABSTRACT; + virtual DrProcessState GetState(HRESULT& reason) DRABSTRACT; + virtual DrString GetFileURIBase() DRABSTRACT; + + void SetAssignedNode(DrResourcePtr); + DrResourcePtr GetAssignedNode(); + +private: + DrResourceRef m_node; +}; +DRREF(DrProcessHandle); + +DRDECLARECLASS(DrProcessInfo); +DRREF(DrProcessInfo); + +typedef DrListener DrProcessListener; +DRIREF(DrProcessListener); +DRMAKEARRAYLIST(DrProcessListenerIRef); + +DRBASECLASS(DrProcessTemplate) +{ +public: + DrProcessTemplate(); + + void 
SetCommandLineBase(DrString commandLine); + DrString GetCommandLineBase(); + + void SetProcessClass(DrString processClass); + DrString GetProcessClass(); + + DrProcessListenerIRefListPtr GetListenerList(); + + void SetFailedRetainAndLeaseGraceTime(DrTimeInterval time, + DrTimeInterval leaseGraceTime); + DrTimeInterval GetFailedRetainTime(); + DrTimeInterval GetFailedLeaseWaitTime(); + void SetCompletedRetainAndLeaseGraceTime(DrTimeInterval time, + DrTimeInterval leaseGraceTime); + DrTimeInterval GetCompletedRetainTime(); + DrTimeInterval GetCompletedLeaseWaitTime(); + + void SetMaxMemory(UINT64 maxMemory); + UINT64 GetMaxMemory(); + + void SetTimeOutBetweenProcessEndAndVertexNotification(DrTimeInterval timeOut); + DrTimeInterval GetTimeOutBetweenProcessEndAndVertexNotification(); + + DrFloatArrayPtr GetAffinityLevelThresholds(); + +private: + DrString m_commandLineBase; + DrString m_processClass; + + DrProcessListenerIRefListRef m_listenerList; + + DrTimeInterval m_failedRetainTime; + DrTimeInterval m_failedLeaseGraceTime; + DrTimeInterval m_completedRetainTime; + DrTimeInterval m_completedLeaseGraceTime; + + UINT64 m_maxMemory; + + DrTimeInterval m_timeOutBetweenProcessEndAndVertexNotification; + + DrFloatArrayRef m_affinityLevelThresholds; +}; +DRREF(DrProcessTemplate); + +DRDECLARECLASS(DrProcessStateRecord); +DRREF(DrProcessStateRecord); +DRBASECLASS(DrProcessStateRecord) +{ +public: + DrProcessStateRecord(); + DrProcessStateRecordRef Clone(); + + DrProcessHandleRef m_process; + DrProcessState m_state; + UINT32 m_exitCode; + DrErrorRef m_status; +}; + +DRDECLARECLASS(DrProcessStats); +DRREF(DrProcessStats); +DRBASECLASS(DrProcessStats) +{ +public: + DrProcessStats(); + bool Different(DrProcessStatsPtr other); + DrProcessStatsRef Clone(); + + DWORD m_exitCode; + UINT32 m_pid; + DrDateTime m_createdTime; + DrDateTime m_beginExecutionTime; + DrDateTime m_terminatedTime; + + DrTimeInterval m_userTime; + DrTimeInterval m_kernelTime; + INT32 m_pageFaults; + UINT64 
m_peakVMUsage; + UINT64 m_peakMemUsage; + UINT64 m_memUsageSeconds; + UINT64 m_totalIO; +}; + +DRDECLARECLASS(DrProcess); +DRREF(DrProcess); + +DRBASECLASS(DrProcessInfo) +{ +public: + DrLockBox m_process; + DrProcessStateRecordRef m_state; + DrProcessStatsRef m_statistics; + + DrDateTime m_jmProcessCreatedTime; + DrDateTime m_jmProcessScheduledTime; +}; +DRREF(DrProcessInfo); + +typedef DrListener DrPSRListener; +DRIREF(DrPSRListener); + +typedef DrMessage DrPSRMessage; +DRREF(DrPSRMessage); + +typedef DrMessage DrProcessMessage; +DRREF(DrProcessMessage); + +typedef DrNotifier DrProcessNotifier; + +DRBASECLASS(DrPropertyStatus) +{ +public: + DrPropertyStatus(DrProcessBasicState state, UINT32 exitCode, DrErrorPtr error); + + DrProcessBasicState m_processState; + UINT32 m_exitCode; + DrErrorRef m_status; + DrLockBox m_process; + UINT64 m_statusVersion; + DrByteArrayRef m_statusBlock; +}; +DRREF(DrPropertyStatus); + +typedef DrListener DrPropertyListener; +DRIREF(DrPropertyListener); + +typedef DrMessage DrPropertyMessage; +DRREF(DrPropertyMessage); + +typedef DrNotifier DrPropertyNotifier; + +DRBASECLASS(DrProcessPropertyStatus) +{ +public: + DrProcessHandleRef m_process; + DrProcessStatsRef m_statistics; + DrPropertyMessageRef m_message; +}; +DRREF(DrProcessPropertyStatus); + +typedef DrListener DrPPSListener; +DRIREF(DrPPSListener); + +typedef DrMessage DrPPSMessage; +DRREF(DrPPSMessage); + +typedef DrListener DrPStateListener; +DRIREF(DrPStateListener); + +typedef DrMessage DrPStateMessage; +DRREF(DrPStateMessage); + +DRCLASS(DrProcess) + : public DrProcessNotifier, public DrPSRListener, public DrPPSListener, public DrErrorListener, + public DrPStateListener +{ +public: + DrProcess(DrXComputePtr xc, DrString name, DrString commandLine, + DrProcessTemplatePtr processTemplate); + + void SetAffinityList(DrAffinityListPtr list); + DrAffinityListPtr GetAffinityList(); + DrProcessInfoPtr GetInfo(); + DrString GetName(); + + void Schedule(); + void 
RequestProperty(UINT64 lastSeenVersion, DrString propertyName, DrTimeInterval maxBlockTime, + DrPropertyListenerPtr listener); + void SendCommand(UINT64 version, DrString propertyName, + DrString propertyDescription, DrByteArrayPtr propertyBlock); + void Terminate(); + + /* DrPSRListener implementation */ + virtual void ReceiveMessage(DrProcessStateRecordRef message); + + /* DrPPSListener implementation */ + virtual void ReceiveMessage(DrProcessPropertyStatusRef message); + + /* DrErrorListener implementation, used for the result of sending a command */ + virtual void ReceiveMessage(DrErrorRef message); + + /* DrPStateListener implementation, used to send a delayed request for termination */ + virtual void ReceiveMessage(DrProcessState message); + +private: + void CloneAndDeliverNotification(bool delay); + + DrXComputeRef m_xc; + DrString m_name; + DrString m_commandLine; + DrProcessTemplateRef m_template; + DrAffinityListRef m_affinity; + + bool m_hasEverRequestedProperty; + DrProcessInfoRef m_info; +}; + +typedef DrSet DrProcessSet; +DRREF(DrProcessSet); diff --git a/GraphManager/kernel/DrXCompute.h b/GraphManager/kernel/DrXCompute.h new file mode 100644 index 0000000..fa9a518 --- /dev/null +++ b/GraphManager/kernel/DrXCompute.h @@ -0,0 +1,89 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +DRDECLARECLASS(DrXCompute); +DRREF(DrXCompute); + +/* DrXCompute abstracts away all the internal types used by XCompute so none of the rest of the system +actually includes XCompute.h or XComputeTypes.h. DrXComputeInternal.h defines the concrete implementation +of DrXCompute that includes the gory details */ + +class DrXComputeOverlapped : public DrOverlapped +{ +public: + DrXComputeOverlapped(DrXComputePtr parent, DrMessageBasePtr message); + virtual ~DrXComputeOverlapped(); + + DrXComputeRef ExtractParent(); + DrMessageBaseRef ExtractMessage(); + void Discard(); + +private: + DrRefHolder m_parent; + DrRefHolder m_message; +}; + +DRCLASS(DrXCompute abstract) : public DrCritSec +{ +public: + /* this returns an object of the concrete type */ + static DrXComputeRef Create(); + + virtual ~DrXCompute(); + + virtual HRESULT Initialize(DrUniversePtr universe, DrMessagePumpPtr pump) = 0; + virtual void Shutdown() = 0; + virtual DrUniversePtr GetUniverse() = 0; + virtual DrMessagePumpPtr GetMessagePump() = 0; + virtual DrDateTime GetCurrentTimeStamp() = 0; + + virtual void ScheduleProcess(DrAffinityListRef affinities, + DrString name, DrString commandLine, + DrProcessTemplatePtr processTemplate, + DrPSRListenerPtr listener) = 0; + virtual void CancelScheduleProcess(DrProcessHandlePtr process) = 0; + + virtual void WaitUntilStart(DrProcessHandlePtr process, DrPSRListenerPtr listener) = 0; + virtual void WaitUntilCompleted(DrProcessHandlePtr process, DrPSRListenerPtr listener) = 0; + + virtual void GetProcessProperty(DrProcessHandlePtr process, + UINT64 lastSeenVersion, DrString propertyName, + DrTimeInterval maxBlockTime, + DrPPSListenerPtr processListener, DrPropertyListenerPtr propertyListener) = 0; + + virtual void SetProcessCommand(DrProcessHandlePtr p, + UINT64 newVersion, DrString propertyName, + DrString propertyDescription, + DrByteArrayRef propertyBlock, + DrErrorListenerPtr listener) = 0; + + virtual void 
TerminateProcess(DrProcessHandlePtr p, + DrErrorListenerPtr listener) = 0; + + virtual void ResetProgress(UINT32 totalSteps, bool update) = 0; + virtual void IncrementTotalSteps(bool update) = 0; + virtual void DecrementTotalSteps(bool update) = 0; + virtual void IncrementProgress(PCSTR message) = 0; + virtual void CompleteProgress(PCSTR message) = 0; + +}; +DRREF(DrXCompute); \ No newline at end of file diff --git a/GraphManager/kernel/drxcompute.cpp b/GraphManager/kernel/drxcompute.cpp new file mode 100644 index 0000000..ab695df --- /dev/null +++ b/GraphManager/kernel/drxcompute.cpp @@ -0,0 +1,1107 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include +#include + +DrXComputeResource::DrXComputeResource(DrResourceLevel level, + DrString name, DrString locality, + DrResourcePtr parent, XCPROCESSNODEID node) +: DrResource(level, name, locality, parent) +{ + m_node = node; +} + +XCPROCESSNODEID DrXComputeResource::GetNode() +{ + return m_node; +} + + +static const UINT8* GuidWriteWord(char *pDst, const UINT8 *pSrc, int uBytes) +{ + static const char Hex[] = "0123456789ABCDEF"; + int c; + + pDst += uBytes * 2; + *pDst = '-'; + do + { + c = *pSrc++; + pDst -= 2; + pDst[1] = Hex[c & 15]; + pDst[0] = Hex[c >> 4]; + } + while (--uBytes); + + return pSrc; +} + +static void GuidWrite(char *pDst, const GUID *pGuid) +{ + const UINT8 *p; + + p = (const UINT8 *) pGuid; + + p = GuidWriteWord(pDst, p, 4); + pDst += 9; + + p = GuidWriteWord(pDst, p, 2); + pDst += 5; + + p = GuidWriteWord(pDst, p, 2); + pDst += 5; + + GuidWriteWord(pDst, p, 1); + GuidWriteWord(pDst + 2, p + 1, 1); + pDst += 5; + p += 2; + + GuidWriteWord(pDst + 0*2, p + 0, 1); + GuidWriteWord(pDst + 1*2, p + 1, 1); + GuidWriteWord(pDst + 2*2, p + 2, 1); + GuidWriteWord(pDst + 3*2, p + 3, 1); + GuidWriteWord(pDst + 4*2, p + 4, 1); + GuidWriteWord(pDst + 5*2, p + 5, 1); + + pDst += 6*2; + *pDst = 0; +} + +void DrXComputeProcessHandle::CloseHandle() +{ + if (m_handle != DrNull) + { + DrLogI("Calling close handle"); + XcCloseProcessHandle(m_handle); + DrLogI("Closed handle"); + m_handle = DrNull; + } +} + +DrString DrXComputeProcessHandle::GetHandleIdAsString() +{ + DrString s; + + GUID id; + XCERROR err = XcGetProcessId(m_handle, &id); + + if (SUCCEEDED(err)) + { + char guidString[37]; + GuidWrite(guidString, &id); + s.SetF("%s", guidString); + } + else + { + s.Set("not yet assigned"); + } + + return s; +} + +static DrProcessState TranslateXcState(XCPROCESSSTATE xcState, HRESULT errorReason) +{ + if (xcState < XCPROCESSSTATE_ASSIGNEDTONODE) + { + if (xcState == XCPROCESSSTATE_SCHEDULINGFAILED) + { + DrAssert(!SUCCEEDED(errorReason)); + return 
DPS_Failed; + } + else + { + return DPS_Scheduling; + } + } + else if (xcState < XCPROCESSSTATE_RUNNING) + { + return DPS_Starting; + } + else if (xcState < XCPROCESSSTATE_COMPLETED) + { + return DPS_Running; + } + else if (xcState == XCPROCESSSTATE_COMPLETED) + { + if (SUCCEEDED(errorReason)) + { + return DPS_Completed; + } + else + { + return DPS_Failed; + } + } + else /* xcState > XCPROCESSSTATE_COMPLETED */ + { + return DPS_Zombie; + } +} + +DrProcessState DrXComputeProcessHandle::GetState(HRESULT& errorReason) +{ + XCPROCESSSTATE xcState; + XCERROR schedulingError; + XCERROR err = XcGetProcessState(m_handle, &xcState, &schedulingError); + DrAssert(SUCCEEDED(err)); + + if (xcState == XCPROCESSSTATE_COMPLETED || xcState == XCPROCESSSTATE_SCHEDULINGFAILED) + { + errorReason = schedulingError; + } + else + { + errorReason = S_OK; + } + + return TranslateXcState(xcState, errorReason); +} + +DrString DrXComputeProcessHandle::GetFileURIBase() +{ + char* path = NULL; + + XCERROR err = XcGetProcessUri(m_handle, "", &path); + DrAssert(SUCCEEDED(err)); + + DrString uriBase; + uriBase.SetF("%s", path); + + XcFreeMemory(path); + + return uriBase; +} + +DrXComputeOverlapped::DrXComputeOverlapped(DrXComputePtr parent, DrMessageBasePtr message) +{ + m_parent.Store(parent); + if (message != DrNull) + { + m_message.Store(message); + } +} + + +DrXComputeOverlapped::~DrXComputeOverlapped() +{ +} + +DrXComputeRef DrXComputeOverlapped::ExtractParent() +{ + return m_parent.Extract(); +} + +DrMessageBaseRef DrXComputeOverlapped::ExtractMessage() +{ + return m_message.Extract(); +} + +void DrXComputeOverlapped::Discard() +{ + DrXComputeRef x = m_parent.Extract(); + x = DrNull; + DrMessageBaseRef m = m_message.Extract(); + m = DrNull; +} + + +DrXComputeWaitForStateChangeOverlapped::DrXComputeWaitForStateChangeOverlapped(DrXComputePtr parent, + DrPSRMessagePtr message) + : DrXComputeOverlapped(parent, message) +{ +} + +void DrXComputeWaitForStateChangeOverlapped::Process() +{ + HRESULT 
status = *GetOperationStatePtr(); + + DrXComputeRef p = ExtractParent(); + DrXComputeInternalRef parent = dynamic_cast((DrXComputePtr) p); + DrAssert(parent != DrNull); + + DrMessageBaseRef m = ExtractMessage(); + DrPSRMessageRef message = dynamic_cast((DrMessageBasePtr) m); + DrAssert(message != DrNull); + + parent->ProcessStateChange(status, message); +} + + +DrXComputeCancelScheduleProcessOverlapped::DrXComputeCancelScheduleProcessOverlapped(DrXComputePtr parent) + : DrXComputeOverlapped(parent, DrNull) +{ +} + +/* we don't do anything with the result, since somebody else is already waiting for the state change */ +void DrXComputeCancelScheduleProcessOverlapped::Process() +{ + HRESULT status = *GetOperationStatePtr(); + + DrLogI("Cancel schedule process returned with status %s", DRERRORSTRING(status)); +} + + +DrXComputeGetSetOverlapped::DrXComputeGetSetOverlapped(DrXComputePtr parent, DrMessageBasePtr message) + : DrXComputeOverlapped(parent, message) +{ + m_results = NULL; +} + +DrXComputeGetSetOverlapped::~DrXComputeGetSetOverlapped() +{ + if (m_results != NULL) + { + PXC_PROCESS_INFO pProcessInfo = m_results->pProcessInfo; + if (pProcessInfo) + { + for (unsigned int i=0; iNumberofProcessProperties; i++) + { + PXC_PROCESSPROPERTY_INFO pprop = pProcessInfo->ppProperties[i]; + XcFreeMemory(pprop->pPropertyLabel); + XcFreeMemory(pprop->pPropertyString); + XcFreeMemory(pprop->pPropertyBlock); + XcFreeMemory(pprop); + } + XcFreeMemory(pProcessInfo->ppProperties); + XcFreeMemory(pProcessInfo->pProcessStatistics); + } + + XcFreeMemory(m_results->pProcessInfo); + XcFreeMemory(m_results->pPropertyVersions); + XcFreeMemory(m_results); + } +} + +PXC_SETANDGETPROCESSINFO_REQRESULTS* DrXComputeGetSetOverlapped::GetResultsPointer() +{ + return &m_results; +} + +DrXComputeGetPropertyOverlapped::DrXComputeGetPropertyOverlapped(DrXComputePtr parent, DrString propertyName, + DrPPSMessagePtr message) + : DrXComputeGetSetOverlapped(parent, message) +{ + DrHeapStringRef hs = 
DrNew DrHeapString(); + hs->m_payload = propertyName; + m_propertyName.Store(hs); +} + +void DrXComputeGetPropertyOverlapped::Process() +{ + HRESULT status = *GetOperationStatePtr(); + + DrXComputeRef p = ExtractParent(); + DrXComputeInternalRef parent = dynamic_cast((DrXComputePtr) p); + DrAssert(parent != DrNull); + + DrMessageBaseRef m = ExtractMessage(); + DrPPSMessageRef message = dynamic_cast((DrMessageBasePtr) m); + DrAssert(message != DrNull); + + DrHeapStringRef hs = m_propertyName.Extract(); + DrString propertyName = hs->m_payload; + + parent->ProcessPropertyFetch(status, propertyName, m_results, message); +} + +DrXComputeSetCommandOverlapped::DrXComputeSetCommandOverlapped(DrXComputePtr parent, DrErrorMessagePtr message) + : DrXComputeGetSetOverlapped(parent, message) +{ +} + +void DrXComputeSetCommandOverlapped::Process() +{ + HRESULT status = *GetOperationStatePtr(); + + DrXComputeRef p = ExtractParent(); + DrXComputeInternalRef parent = dynamic_cast((DrXComputePtr) p); + DrAssert(parent != DrNull); + + DrMessageBaseRef m = ExtractMessage(); + DrErrorMessageRef message = dynamic_cast((DrMessageBasePtr) m); + DrAssert(message != DrNull); + + parent->ProcessCommandResult(status, message); +} + + +DrXCompute::~DrXCompute() +{ +} + +DrXComputeRef DrXCompute::Create() +{ + //return DrNew DrXComputeYarn(); + return DrNew DrXComputeInternal(); +} + + +DrXComputeInternal::DrXComputeInternal() +{ + m_session = INVALID_XCSESSIONHANDLE; + m_localProcess = INVALID_XCPROCESSHANDLE; +} + +DrXComputeInternal::~DrXComputeInternal() +{ + Shutdown(); +} + +DrUniversePtr DrXComputeInternal::GetUniverse() +{ + return m_universe; +} + +DrMessagePumpPtr DrXComputeInternal::GetMessagePump() +{ + return m_messagePump; +} + +DrDateTime DrXComputeInternal::GetCurrentTimeStamp() +{ + return m_messagePump->GetCurrentTimeStamp(); +} + + +HRESULT DrXComputeInternal::Initialize(DrUniversePtr universe, DrMessagePumpPtr pump) +{ + m_universe = universe; + m_messagePump = pump; + + 
DrLogI("Initializing XCompute"); + + XCERROR err; + + err = XcInitialize(NULL, "Dryad"); + if (!SUCCEEDED(err)) + { + DrLogE("Failed to initialize XCompute", "error: %s", DRERRORSTRING(err)); + return err; + } + + { + DRPIN(XCSESSIONHANDLE) sessionPtr = &m_session; + err = XcOpenSession(NULL, sessionPtr, NULL); + } + if (!SUCCEEDED(err)) + { + DrLogE("Failed to open XCompute session", "error: %s", DRERRORSTRING(err)); + return err; + } + + { + DRPIN(XCPROCESSHANDLE) processPtr = &m_localProcess; + err = XcOpenCurrentProcessHandle(m_session, processPtr); + } + if (!SUCCEEDED(err)) + { + DrLogE("Failed to open local process handle", "error: %s", DRERRORSTRING(err)); + return err; + } + + return FetchListOfComputers(); +} + +void DrXComputeInternal::Shutdown() +{ + if (m_session!= INVALID_XCSESSIONHANDLE) + { + XcCloseSession(m_session); + m_session = INVALID_XCSESSIONHANDLE; + } + + if (m_localProcess != INVALID_XCPROCESSHANDLE) + { + XcCloseProcessHandle(m_localProcess); + m_localProcess = INVALID_XCPROCESSHANDLE; + } +} + +void DrXComputeInternal::AddNodeToUniverse(XCPROCESSNODEID nodeId) +{ + PCSTR nodeName; + XCERROR err = XcProcessNodeNameFromId(m_session, + nodeId, + &nodeName); + DrAssert(SUCCEEDED(err)); + + DrString name; + name.SetF("%s", nodeName); + XcFreeMemory(nodeName); + + PCSTR nodeLS; + err = XcGetNetworkLocalityPathOfProcessNode(m_session, + nodeId, + NULL, + &nodeLS); + DrAssert(SUCCEEDED(err)); + + DrString nodeLocality; + nodeLocality.SetF("%s", nodeLS); + XcFreeMemory(nodeLS); + + PCSTR podLS; + err = XcGetNetworkLocalityPathOfProcessNode(m_session, + nodeId, + XCLOCALITYPARAM_POD, + &podLS); + DrAssert(SUCCEEDED(err)); + + DrString podName; + podName.SetF("POD-%s", podLS); + DrString podLocality; + podLocality.SetF("%s", podLS); + XcFreeMemory(podLS); + + DrResourceRef pod = m_universe->LookUpResourceInternal(podName); + if (pod == DrNull) + { + pod = DrNew DrXComputeResource(DRL_Rack, podName, podLocality, DrNull, INVALID_XCPROCESSNODEID); + 
m_universe->AddResource(pod); + DrLogI("Found pod %s", podName.GetChars()); + } + + DrResourceRef node = m_universe->LookUpResourceInternal(name); + DrAssert(node == DrNull); + node = DrNew DrXComputeResource(DRL_Computer, name, nodeLocality, pod, nodeId); + + m_universe->AddResource(node); + DrLogI("Found computer %s in pod %s", name.GetChars(), podName.GetChars()); +} + +HRESULT DrXComputeInternal::FetchListOfComputers() +{ + UINT32 numberOfNodes; + PXCPROCESSNODEID nodeArray = NULL; + + XCERROR err = XcEnumerateProcessNodes(m_session, + &numberOfNodes, + &nodeArray, + NULL); + + if (!SUCCEEDED(err)) + { + DrLogE("Failed to enumerate process nodes error: %s", DRERRORSTRING(err)); + + if (nodeArray != NULL) + { + XcFreeMemory(nodeArray); + } + + return err; + } + + if (numberOfNodes == 0) + { + DrLogI("No process nodes returned"); + + if (nodeArray != NULL) + { + XcFreeMemory(nodeArray); + } + + return HRESULT_FROM_WIN32(ERROR_INVALID_DATA); + } + + DrLogI("Found %d process nodes", numberOfNodes); + + { + DrAutoCriticalSection acs(m_universe->GetResourceLock()); + + UINT32 i; + for (i=0; iSize()); + + PXC_AFFINITY affinityArray = new XC_AFFINITY[affinities->Size()]; + int i; + for (i=0; iSize(); ++i) + { + DrAffinityPtr a = affinities[i]; + + PCSTR* pathArray = new PCSTR[a->GetLocalityArray()->Size()]; + int j; + for (j=0; jGetLocalityArray()->Size(); ++j) + { + pathArray[j] = a->GetLocalityArray()[j]->GetLocality().GetChars(); + DrLogI("Added affinity path %s", pathArray[j]); + } + + XC_AFFINITY& affinity = affinityArray[i]; + memset(&affinity, 0, sizeof(affinity)); + affinity.Size = sizeof(affinity); + affinity.NumberOfNetworkLocalityPaths = a->GetLocalityArray()->Size(); + affinity.pNetworkLocalityPaths = pathArray; + affinity.Weight = a->GetWeight(); + + DrLogI("Added affinity with weight %I64u", affinity.Weight); + } + + XC_LOCALITY_DESCRIPTOR locality; + memset(&locality, 0, sizeof(locality)); + locality.Size = sizeof(locality); + locality.NumberOfAffinities 
= affinities->Size(); + locality.pAffinities = affinityArray; + + XC_CREATEPROCESS_DESCRIPTOR createProcess; + memset(&createProcess, 0, sizeof(createProcess)); + createProcess.Size = sizeof(createProcess); + createProcess.pCommandLine = commandLine.GetChars(); + createProcess.pProcessClass = processTemplate->GetProcessClass().GetChars(); + createProcess.pProcessFriendlyName = name.GetChars(); + createProcess.pAppProcessConstraints = NULL; + createProcess.NumberOfResourceFileDescriptors = 0; + createProcess.pResourceFileDescriptors = NULL; + + XC_SCHEDULEPROCESS_DESCRIPTOR scheduleDescriptor; + memset(&scheduleDescriptor, 0, sizeof(scheduleDescriptor)); + scheduleDescriptor.Size = sizeof(scheduleDescriptor); + scheduleDescriptor.pLocalityDescriptor = &locality; + scheduleDescriptor.pCreateProcessDescriptor = &createProcess; + + DrLogI("Starting schedule process for %s.%s", + processTemplate->GetProcessClass().GetChars(), name.GetChars()); + + DrProcessState state; + DrString reason; + XCERROR err; + + DrXComputeProcessHandleRef process = DrNew DrXComputeProcessHandle(); + + { + DRPIN(XCPROCESSHANDLE) processHandlePtr = &process->m_handle; + err = XcCreateNewProcessHandle(m_session, NULL, processHandlePtr); + } + + if (SUCCEEDED(err)) + { + err = XcScheduleProcess(process->m_handle, &scheduleDescriptor); + if (SUCCEEDED(err)) + { + DrLogI("Schedule process succeeded for %s.%s", + processTemplate->GetProcessClass().GetChars(), name.GetChars()); + state = DPS_Scheduling; + reason = "Schedule process in progress"; + } + else + { + DrLogW("Schedule process failed immediately for %s.%s error %s", + processTemplate->GetProcessClass().GetChars(), name.GetChars(), DRERRORSTRING(err)); + state = DPS_Zombie; + reason.SetF("Schedule process failed immediately error %s", DRERRORSTRING(err)); + } + } + else + { + DrLogW("Create process failed for %s.%s error %s", + processTemplate->GetProcessClass().GetChars(), name.GetChars(), DRERRORSTRING(err)); + state = DPS_Zombie; + 
reason.SetF("Create process handle failed error %s", DRERRORSTRING(err)); + } + + for (i=0; i < affinities->Size(); ++i) + { + XC_AFFINITY& affinity = affinityArray[i]; + delete [] affinity.pNetworkLocalityPaths; + } + delete [] affinityArray; + + /* send the listener notification of the change of state */ + DrProcessStateRecordRef notification = DrNew DrProcessStateRecord(); + notification->m_state = state; + notification->m_exitCode = STILL_ACTIVE; + notification->m_process = process; + if (err != S_OK) + { + /* this can be true in principle if SUCCEEDED(err) but only S_OK corresponds to a null + error object */ + notification->m_status = DrNew DrError(err, "XCompute", reason); + } + + DrPSRMessageRef message = DrNew DrPSRMessage(listener, notification); + m_messagePump->EnQueue(message); + + if (SUCCEEDED(err)) + { + WaitForStateChange(process, XCPROCESSSTATE_ASSIGNEDTONODE, listener); + } +} + +void DrXComputeInternal::WaitUntilStart(DrProcessHandlePtr p, DrPSRListenerPtr listener) +{ + DrXComputeProcessHandlePtr process = dynamic_cast<DrXComputeProcessHandlePtr>(p); + DrAssert(process != DrNull); + + WaitForStateChange(process, XCPROCESSSTATE_RUNNING, listener); +} + +void DrXComputeInternal::WaitUntilCompleted(DrProcessHandlePtr p, DrPSRListenerPtr listener) +{ + DrXComputeProcessHandlePtr process = dynamic_cast<DrXComputeProcessHandlePtr>(p); + DrAssert(process != DrNull); + + WaitForStateChange(process, XCPROCESSSTATE_COMPLETED, listener); +} + +void DrXComputeInternal::WaitForStateChange(DrXComputeProcessHandlePtr process, XCPROCESSSTATE targetState, + DrPSRListenerPtr listener) +{ + /* make a new message now to hold on to the reference to listener */ + DrProcessStateRecordRef notification = DrNew DrProcessStateRecord(); + notification->m_exitCode = STILL_ACTIVE; + notification->m_process = process; + DrPSRMessageRef message = DrNew DrPSRMessage(listener, notification); + + DrXComputeOverlapped* overlapped = new DrXComputeWaitForStateChangeOverlapped(this, message); + + XC_ASYNC_INFO asyncInfo; + memset(&asyncInfo, 0, 
sizeof(asyncInfo)); + asyncInfo.Size = sizeof(asyncInfo); + asyncInfo.pOperationState = overlapped->GetOperationStatePtr(); + asyncInfo.IOCP = m_messagePump->GetCompletionPort(); + asyncInfo.pOverlapped = overlapped; + + DrLogI("Waiting for state change to %x", targetState); + + m_messagePump->NotifySubmissionToCompletionPort(overlapped); + + XCERROR err = XcWaitForStateChange(process->m_handle, targetState, + XCTIMEINTERVAL_INFINITE, &asyncInfo); + + if (err != HRESULT_FROM_WIN32(ERROR_IO_PENDING)) + { + DrLogA("Wait for state change failed synchronously error %s", DRERRORSTRING(err)); + } +} + +void DrXComputeInternal::ProcessStateChange(HRESULT status, DrPSRMessagePtr message) +{ + DrProcessStateRecordPtr notification = message->GetPayload(); + DrString reason; + + if (!SUCCEEDED(status)) + { + /* we really don't want to add cascading retries all through the stack, + so we are going to take the bold stance that if XCompute says there + was an error here, it has diligently retried in as sensible a way as + we would have, and really, things are not going to improve. So the + process is now dead to us. 
*/ + + /* We use the Failed state here instead of the Zombie state because we + are relying on XCompute returning success here, but setting the process + state to some zombie value, if the process has been garbage collected */ + notification->m_state = DPS_Failed; + reason.SetF("Wait for state change RPC failed with error %s", DRERRORSTRING(status)); + notification->m_status = DrNew DrError(status, "XCompute", reason); + notification->m_exitCode = 1; + } + else + { + DrProcessHandlePtr process = notification->m_process; + DrAssert(process != DrNull); + HRESULT schedulingStatus; + notification->m_state = process->GetState(schedulingStatus); + + DrLogI("Processing state change internal state %d status %s", + notification->m_state, DRERRORSTRING(schedulingStatus)); + + if (!SUCCEEDED(schedulingStatus)) + { + reason.SetF("Process in failed state with error %s", DRERRORSTRING(schedulingStatus)); + notification->m_status = DrNew DrError(schedulingStatus, "XCompute", reason); + notification->m_exitCode = 1; + } + else if (schedulingStatus != S_OK) + { + reason.SetF("Process in non-error state %s", DRERRORSTRING(schedulingStatus)); + notification->m_status = DrNew DrError(schedulingStatus, "XCompute", reason); + } + + if (process->GetAssignedNode() == DrNull) + { + DrXComputeProcessHandlePtr xcProcess = dynamic_cast<DrXComputeProcessHandlePtr>(process); + DrAssert(xcProcess != DrNull); + + DrLogI("Looking up assigned node"); + + XCPROCESSNODEID nodeId; + XCERROR err = XcGetProcessNodeId(xcProcess->m_handle, &nodeId); + + DrLogI("Found assigned node id err %s", DRERRORSTRING(err)); + + if (SUCCEEDED(err)) + { + DrAutoCriticalSection acs(m_universe->GetResourceLock()); + + PCSTR nodeName; + XCERROR err = XcProcessNodeNameFromId(m_session, + nodeId, + &nodeName); + DrAssert(SUCCEEDED(err)); + + DrString name; + name.SetF("%s", nodeName); + XcFreeMemory(nodeName); + + DrLogI("Found node name %s", name.GetChars()); + + DrResourceRef resource = m_universe->LookUpResourceInternal(name); + if (resource == 
DrNull) + { + /* the scheduler may have added new resources we didn't know about originally, so be + ready for that */ + AddNodeToUniverse(nodeId); + resource = m_universe->LookUpResourceInternal(name); + DrAssert(resource != DrNull); + } + + process->SetAssignedNode(resource); + } + else if (notification->m_state >= DPS_Starting && notification->m_state != DPS_Zombie) + { + /* HACK. XCompute did not fill in the location, so we mustn't call in subsequently. + Setting the state to zombie will ensure everyone abandons it */ + DrLogW("Process didn't get assigned a node: setting to zombie"); + notification->m_state = DPS_Zombie; + reason = "Process didn't get assigned a node"; + notification->m_status = DrNew DrError(DrError_XComputeError, "XCompute", reason); + } + } + } + + m_messagePump->EnQueue(message); +} + +void DrXComputeInternal::CancelScheduleProcess(DrProcessHandlePtr p) +{ + DrXComputeProcessHandlePtr process = dynamic_cast<DrXComputeProcessHandlePtr>(p); + DrAssert(process != DrNull); + + DrXComputeOverlapped* overlapped = new DrXComputeCancelScheduleProcessOverlapped(this); + + XC_ASYNC_INFO asyncInfo; + memset(&asyncInfo, 0, sizeof(asyncInfo)); + asyncInfo.Size = sizeof(asyncInfo); + asyncInfo.pOperationState = overlapped->GetOperationStatePtr(); + asyncInfo.IOCP = m_messagePump->GetCompletionPort(); + asyncInfo.pOverlapped = overlapped; + + DrLogI("Sending cancellation for scheduled process"); + + m_messagePump->NotifySubmissionToCompletionPort(overlapped); + + XCERROR err = + XcCancelScheduleProcess(process->m_handle, &asyncInfo); + + if (err != HRESULT_FROM_WIN32(ERROR_IO_PENDING)) + { + DrLogA("Cancel schedule process failed synchronously error %s", DRERRORSTRING(err)); + } +} + +void DrXComputeInternal::GetProcessProperty(DrProcessHandlePtr p, + UINT64 lastSeenVersion, DrString propertyName, + DrTimeInterval maxBlockTime, + DrPPSListenerPtr processListener, + DrPropertyListenerPtr propertyListener) +{ + DrXComputeProcessHandlePtr process = dynamic_cast<DrXComputeProcessHandlePtr>(p); + DrAssert(process 
!= DrNull); + + DrLogI("Requesting property %s lastSeen %I64u maxBlock %lf", propertyName.GetChars(), + lastSeenVersion, (double) maxBlockTime / (double) DrTimeInterval_Second); + + DrProcessPropertyStatusRef notification = DrNew DrProcessPropertyStatus(); + notification->m_process = process; + + DrPropertyStatusRef propertyNotification = DrNew DrPropertyStatus(DPBS_Running, STILL_ACTIVE, DrNull); + DrPropertyMessageRef propertyMessage = DrNew DrPropertyMessage(propertyListener, propertyNotification); + notification->m_message = propertyMessage; + + DrPPSMessageRef message = DrNew DrPPSMessage(processListener, notification); + + DrXComputeGetPropertyOverlapped* overlapped = new DrXComputeGetPropertyOverlapped(this, propertyName, message); + + XC_SETANDGETPROCESSINFO_REQINPUT requestInputs; + memset(&requestInputs, 0, sizeof(requestInputs)); + requestInputs.Size = sizeof(requestInputs); + requestInputs.pAppProcessConstraints = NULL; + requestInputs.NumberOfProcessPropertiesToSet = 0; + requestInputs.ppPropertiesToSet = NULL; + requestInputs.pBlockOnPropertyLabel = propertyName.GetChars(); + requestInputs.BlockOnPropertyversionLastSeen = lastSeenVersion; + requestInputs.MaxBlockTime = maxBlockTime; + requestInputs.pPropertyFetchTemplate = propertyName.GetChars(); + requestInputs.ProcessInfoFetchOptions |= + XCPROCESSINFOOPTION_STATICINFO | + XCPROCESSINFOOPTION_TIMINGINFO | + XCPROCESSINFOOPTION_PROCESSSTAT; + + XC_ASYNC_INFO asyncInfo; + memset(&asyncInfo, 0, sizeof(asyncInfo)); + asyncInfo.Size = sizeof(asyncInfo); + asyncInfo.pOperationState = overlapped->GetOperationStatePtr(); + asyncInfo.IOCP = m_messagePump->GetCompletionPort(); + asyncInfo.pOverlapped = overlapped; + + m_messagePump->NotifySubmissionToCompletionPort(overlapped); + + XCERROR err = XcSetAndGetProcessInfo(process->m_handle, &requestInputs, + overlapped->GetResultsPointer(), &asyncInfo); + if (err != HRESULT_FROM_WIN32(ERROR_IO_PENDING)) + { + DrLogA("XcSetAndGetProcessInfo failed with error %s", 
DRERRORSTRING(err)); + } +} + +void DrXComputeInternal::ProcessPropertyFetch(HRESULT status, DrString propertyName, + PXC_SETANDGETPROCESSINFO_REQRESULTS xcStatus, + DrPPSMessagePtr message) +{ + /* Payload for process */ + DrProcessPropertyStatusPtr ppProcessStatus = message->GetPayload(); + + /* Payload for vertex record */ + DrPropertyStatusPtr pVertexRecordStatus = ppProcessStatus->m_message->GetPayload(); + + DrLogI("Process property fetch %s status %s", propertyName.GetChars(), DRERRORSTRING(status)); + + if (!SUCCEEDED(status)) + { + /* we really don't want to add cascading retries all through the stack, + so we are going to take the bold stance that if XCompute says there + was an error here, it has diligently retried in as sensible a way as + we would have, and really, things are not going to improve. So the + process is now dead to us. */ + + pVertexRecordStatus->m_processState = DPBS_Failed; + pVertexRecordStatus->m_exitCode = 1; + DrString reason = "Property fetch failed"; + pVertexRecordStatus->m_status = DrNew DrError(status, "XCompute", reason); + } + else + { + PXC_PROCESS_INFO xcInfo = xcStatus->pProcessInfo; + if (xcInfo == NULL) + { + /* we'll say this is Failed rather than Zombie which attempts to keep the process + directory alive instead of just letting its lease expire */ + pVertexRecordStatus->m_processState = DPBS_Failed; + pVertexRecordStatus->m_exitCode = 1; + DrString reason = "Property fetch had no info"; + pVertexRecordStatus->m_status = DrNew DrError(DrError_Unexpected, "XCompute", reason); + } + else + { + DrProcessStatsRef stats = DrNew DrProcessStats(); + ppProcessStatus->m_statistics = stats; + + stats->m_exitCode = xcInfo->ExitCode; + stats->m_pid = xcInfo->Win32Pid; + + pVertexRecordStatus->m_exitCode = xcInfo->ExitCode; + /* TODO this seems very broken on the XCompute side: fake it up for now */ + DrProcessState state = TranslateXcState(xcInfo->ProcessState, xcInfo->ProcessStatus); + if (pVertexRecordStatus->m_exitCode == 
STILL_ACTIVE) + { + DrLogI("Property fetch came back with process exitcode STILL_ACTIVE state %d", state); + pVertexRecordStatus->m_processState = DPBS_Running; + } + else + { + DrLogI("Property fetch came back with process exitcode %u state %d status %s", + xcInfo->ExitCode, state, DRERRORSTRING(xcInfo->ProcessStatus)); + if (pVertexRecordStatus->m_exitCode == 0) + { + pVertexRecordStatus->m_processState = DPBS_Completed; + } + else + { + pVertexRecordStatus->m_processState = DPBS_Failed; + DrString reason = "Property fetch said process has failed"; + pVertexRecordStatus->m_status = DrNew DrError(xcInfo->ProcessStatus, "DrXCompute", reason); + } + } + + if ((xcInfo->Flags & XCPROCESSINFOOPTION_TIMINGINFO) == XCPROCESSINFOOPTION_TIMINGINFO) + { + stats->m_createdTime = xcInfo->CreatedTime; + stats->m_beginExecutionTime = xcInfo->BeginExecutionTime; + stats->m_terminatedTime = xcInfo->TerminatedTime; + } + + if ((xcInfo->Flags & XCPROCESSINFOOPTION_PROCESSSTAT) == XCPROCESSINFOOPTION_PROCESSSTAT) + { + PXC_PROCESS_STATISTICS xcStat = xcInfo->pProcessStatistics; + DrAssert(xcStat != NULL); + stats->m_userTime = xcStat->ProcessUserTime; + stats->m_kernelTime = xcStat->ProcessKernelTime; + stats->m_pageFaults = xcStat->PageFaults; + stats->m_peakVMUsage = xcStat->PeakVMUsage; + stats->m_peakMemUsage = xcStat->PeakMemUsage; + stats->m_memUsageSeconds = xcStat->MemUsageSeconds; + stats->m_totalIO = xcStat->TotalIo; + } + + if (xcInfo->NumberofProcessProperties > 0) + { + PXC_PROCESSPROPERTY_INFO prop = xcInfo->ppProperties[0]; + if (strcmp(prop->pPropertyLabel, propertyName.GetChars()) == 0) + { + DrLogI("Received new process property %s version %I64u", + propertyName.GetChars(), prop->PropertyVersion); + pVertexRecordStatus->m_statusVersion = prop->PropertyVersion; + if (prop->PropertyBlockSize > 0 && prop->PropertyBlockSize < 0x80000000) + { + pVertexRecordStatus->m_statusBlock = DrNew DrByteArray((int) prop->PropertyBlockSize); + DRPIN(BYTE) arrayDst = 
&(pVertexRecordStatus->m_statusBlock[0]); + memcpy(arrayDst, prop->pPropertyBlock, prop->PropertyBlockSize); + } + } + else + { + DrLogW("Process fetch returned unexpected property %s", prop->pPropertyLabel); + } + } + } + } + + m_messagePump->EnQueue(message); +} + +void DrXComputeInternal::SetProcessCommand(DrProcessHandlePtr p, + UINT64 newVersion, DrString propertyName, + DrString propertyDescription, + DrByteArrayRef propertyBlock, + DrErrorListenerPtr listener) +{ + DrXComputeProcessHandlePtr process = dynamic_cast<DrXComputeProcessHandlePtr>(p); + DrAssert(process != DrNull); + + DrErrorMessageRef message = DrNew DrErrorMessage(listener, DrNull); + + DrXComputeSetCommandOverlapped* overlapped = new DrXComputeSetCommandOverlapped(this, message); + + XC_PROCESSPROPERTY_INFO prop; + memset(&prop, 0, sizeof(prop)); + prop.Size = sizeof(prop); + prop.PropertyVersion = newVersion; + prop.pPropertyLabel = (PSTR) propertyName.GetChars(); + prop.pPropertyString = (PSTR) propertyDescription.GetChars(); + DRPIN(BYTE) blockArray = &(propertyBlock[0]); + prop.pPropertyBlock = (char *) blockArray; + prop.PropertyBlockSize = propertyBlock->Allocated(); + PXC_PROCESSPROPERTY_INFO pProp = &prop; + + XC_SETANDGETPROCESSINFO_REQINPUT requestInputs; + memset(&requestInputs, 0, sizeof(requestInputs)); + requestInputs.Size = sizeof(requestInputs); + requestInputs.pAppProcessConstraints = NULL; + requestInputs.NumberOfProcessPropertiesToSet = 1; + requestInputs.ppPropertiesToSet = &pProp; + + XC_ASYNC_INFO asyncInfo; + memset(&asyncInfo, 0, sizeof(asyncInfo)); + asyncInfo.Size = sizeof(asyncInfo); + asyncInfo.pOperationState = overlapped->GetOperationStatePtr(); + asyncInfo.IOCP = m_messagePump->GetCompletionPort(); + asyncInfo.pOverlapped = overlapped; + + m_messagePump->NotifySubmissionToCompletionPort(overlapped); + + XCERROR err = XcSetAndGetProcessInfo(process->m_handle, &requestInputs, + overlapped->GetResultsPointer(), &asyncInfo); + if (err != HRESULT_FROM_WIN32(ERROR_IO_PENDING)) + { + 
DrLogA("XcSetAndGetProcessInfo failed with error %s", DRERRORSTRING(err)); + } +} + +void DrXComputeInternal::ProcessCommandResult(HRESULT status, DrErrorMessagePtr message) +{ + if (!SUCCEEDED(status)) + { + /* we really don't want to add cascading retries all through the stack, + so we are going to take the bold stance that if XCompute says there + was an error here, it has diligently retried in as sensible a way as + we would have, and really, things are not going to improve. So the + process is now dead to us. */ + + /* We use the Failed state here instead of the Zombie state because we + are relying on XCompute returning success here, but setting the process + state to some zombie value, if the process has been garbage collected */ + DrString reason; + reason.SetF("Set command RPC failed with error %s", DRERRORSTRING(status)); + message->SetPayload(DrNew DrError(status, "XCompute", reason)); + } + else if (status != S_OK) + { + DrString reason; + reason.SetF("Process command send succeeded with non-error status %s", DRERRORSTRING(status)); + message->SetPayload(DrNew DrError(status, "XCompute", reason)); + } + + m_messagePump->EnQueue(message); +} + +void DrXComputeInternal::TerminateProcess(DrProcessHandlePtr /* unused p*/, + DrErrorListenerPtr /* unused listener */) +{ + /* we have no way to do this right now */ +} + +void DrXComputeInternal::ResetProgress(UINT32 totalSteps, bool update) +{ + XcResetProgress(m_session, totalSteps, update); +} + +void DrXComputeInternal::IncrementTotalSteps(bool update) +{ + XcIncrementTotalSteps(m_session, update); +} + +void DrXComputeInternal::DecrementTotalSteps(bool update) +{ + XcDecrementTotalSteps(m_session, update); +} + +void DrXComputeInternal::IncrementProgress(PCSTR message) +{ + XcIncrementProgress(m_session, message); +} + +void DrXComputeInternal::CompleteProgress(PCSTR message) +{ + XcCompleteProgress(m_session, message); +} + diff --git a/GraphManager/kernel/drxcomputeinternal.h 
b/GraphManager/kernel/drxcomputeinternal.h new file mode 100644 index 0000000..e8156f1 --- /dev/null +++ b/GraphManager/kernel/drxcomputeinternal.h @@ -0,0 +1,164 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include + +DRCLASS(DrXComputeResource) : public DrResource +{ +public: + DrXComputeResource(DrResourceLevel level, DrString name, DrString locality, DrResourcePtr parent, + XCPROCESSNODEID node); + + XCPROCESSNODEID GetNode(); + +private: + XCPROCESSNODEID m_node; +}; + +DRCLASS(DrXComputeProcessHandle) : public DrProcessHandle +{ +public: + virtual void CloseHandle() DROVERRIDE; + virtual DrString GetHandleIdAsString() DROVERRIDE; + virtual DrProcessState GetState(HRESULT& reason) DROVERRIDE; + virtual DrString GetFileURIBase() DROVERRIDE; + + /* this is public so the managed DrXComputeInternal class can make a pin_ptr out of it */ + XCPROCESSHANDLE m_handle; +}; +DRREF(DrXComputeProcessHandle); + +class DrXComputeWaitForStateChangeOverlapped : public DrXComputeOverlapped +{ +public: + DrXComputeWaitForStateChangeOverlapped(DrXComputePtr parent, DrPSRMessagePtr message); + + void Process(); +}; + +class DrXComputeCancelScheduleProcessOverlapped : public DrXComputeOverlapped +{ +public: + DrXComputeCancelScheduleProcessOverlapped(DrXComputePtr parent); + + void Process(); +}; + +class DrXComputeGetSetOverlapped : public 
DrXComputeOverlapped +{ +public: + DrXComputeGetSetOverlapped(DrXComputePtr parent, DrMessageBasePtr message); + virtual ~DrXComputeGetSetOverlapped(); + + PXC_SETANDGETPROCESSINFO_REQRESULTS* GetResultsPointer(); + +protected: + PXC_SETANDGETPROCESSINFO_REQRESULTS m_results; +}; + +DRBASECLASS(DrHeapString) +{ +public: + DrString m_payload; +}; +DRREF(DrHeapString); + +class DrXComputeGetPropertyOverlapped : public DrXComputeGetSetOverlapped +{ +public: + DrXComputeGetPropertyOverlapped(DrXComputePtr parent, DrString propertyName, DrPPSMessagePtr message); + + void Process(); + +private: + DrRefHolder m_propertyName; +}; + +class DrXComputeSetCommandOverlapped : public DrXComputeGetSetOverlapped +{ +public: + DrXComputeSetCommandOverlapped(DrXComputePtr parent, DrErrorMessagePtr message); + + void Process(); +}; + +DRCLASS(DrXComputeInternal) : public DrXCompute +{ +public: + DrXComputeInternal(); + ~DrXComputeInternal(); + + virtual HRESULT Initialize(DrUniversePtr universe, DrMessagePumpPtr pump) DROVERRIDE; + virtual void Shutdown() DROVERRIDE; + + virtual DrUniversePtr GetUniverse() DROVERRIDE; + virtual DrMessagePumpPtr GetMessagePump() DROVERRIDE; + virtual DrDateTime GetCurrentTimeStamp() DROVERRIDE; + + virtual void ScheduleProcess(DrAffinityListRef affinities, + DrString name, DrString commandLine, + DrProcessTemplatePtr processTemplate, + DrPSRListenerPtr listener) DROVERRIDE; + virtual void CancelScheduleProcess(DrProcessHandlePtr process) DROVERRIDE; + + virtual void WaitUntilStart(DrProcessHandlePtr process, DrPSRListenerPtr listener) DROVERRIDE; + virtual void WaitUntilCompleted(DrProcessHandlePtr process, DrPSRListenerPtr listener) DROVERRIDE; + + virtual void GetProcessProperty(DrProcessHandlePtr process, + UINT64 lastSeenVersion, DrString propertyName, + DrTimeInterval maxBlockTime, + DrPPSListenerPtr processListener, + DrPropertyListenerPtr propertyListener) DROVERRIDE; + + virtual void SetProcessCommand(DrProcessHandlePtr p, + UINT64 newVersion, 
DrString propertyName, + DrString propertyDescription, + DrByteArrayRef propertyBlock, + DrErrorListenerPtr listener) DROVERRIDE; + + virtual void TerminateProcess(DrProcessHandlePtr p, + DrErrorListenerPtr listener) DROVERRIDE; + + void ProcessStateChange(HRESULT status, DrPSRMessagePtr message); + void ProcessPropertyFetch(HRESULT status, DrString propertyName, + PXC_SETANDGETPROCESSINFO_REQRESULTS xcStatus, DrPPSMessagePtr message); + void ProcessCommandResult(HRESULT status, DrErrorMessagePtr message); + + virtual void ResetProgress(UINT32 totalSteps, bool update) DROVERRIDE; + virtual void IncrementTotalSteps(bool update) DROVERRIDE; + virtual void DecrementTotalSteps(bool update) DROVERRIDE; + virtual void IncrementProgress(PCSTR message) DROVERRIDE; + virtual void CompleteProgress(PCSTR message) DROVERRIDE; + + +private: + void AddNodeToUniverse(XCPROCESSNODEID node); + HRESULT FetchListOfComputers(); + void WaitForStateChange(DrXComputeProcessHandlePtr p, XCPROCESSSTATE targetState, + DrPSRListenerPtr listener); + + DrUniverseRef m_universe; + DrMessagePumpRef m_messagePump; + XCSESSIONHANDLE m_session; + XCPROCESSHANDLE m_localProcess; +}; +DRREF(DrXComputeInternal); \ No newline at end of file diff --git a/GraphManager/reporting/DrArtemisLegacyReporting.cpp b/GraphManager/reporting/DrArtemisLegacyReporting.cpp new file mode 100644 index 0000000..4d5e82a --- /dev/null +++ b/GraphManager/reporting/DrArtemisLegacyReporting.cpp @@ -0,0 +1,350 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include + +// +// Prints a timestamp as MM/DD/YYYY HH:MM:SS.MS +// +static void PrintTimestamp() +{ + SYSTEMTIME utc, local; + FILETIME ft; + GetSystemTimeAsFileTime(&ft); + FileTimeToSystemTime(&ft, &utc); + SystemTimeToTzSpecificLocalTime(NULL, &utc, &local); + + printf("[%02d/%02d/%04d %02d:%02d:%02d.%03u] ", + local.wMonth, + local.wDay, + local.wYear, + local.wHour, + local.wMinute, + local.wSecond, + local.wMilliseconds); +} + +/* + +// Abandoning duplicate scheduling of vertex 83.1 (InputTable__26[5]) +vertexAbandonedRegex = new Regex(@"Abandoning duplicate scheduling of vertex (\d+)\.(\d+) \((.+)\)", RegexOptions.Compiled); + +// BasicAggregate__10(1),(Super__3[0].0*Super__3[1].0*Super__3[2].0*Super__3[3].0*Super__3[4].0*Super__3[5].0*Super__3[6].0),(sherwood-068),0.1718728s +topologyRegex = new Regex(@"(.*),\((.*)\),\((\S+)\),([.0-9]+)s$", + RegexOptions.Compiled); + +// total=951722563162 local=37817665237 intrapod=189765117248 crosspod=724139780677 +datareadRegex = new Regex(@"total=(\d+) local=(\d+) intrapod=(\d+) crosspod=(\d+)", RegexOptions.Compiled); + +Also: + +Timing Information Graph Start Time (\d+) +JM Finish Time: (\d+) +ABORTING: +\(.*)\ +Application completed successfully. 
+Application failed with error code (.*) +Total running time in vertices successful/failed: ([.0-9]+)s\/([.0-9]+)s +Average job parallelism +*/ + +/* + +// Created process execution record for vertex 33 (Super__0[0]) v.0 GUID {B0FC788F-1FFC-4D74-AFC4-3EDFF03AF11A} +vertexCreatedRegex = new Regex(@"\[(.*)\] Created process execution record for vertex (\d+) \((.*)\) v.(\d+) GUID \{?([-A-F0-9]+)\}?", + RegexOptions.Compiled); + +// Process started for vertex 5 (Super__0[1]) v.0 GUID {73EA55E0-0326-43C4-AD61-CB0B8CF8FE49} machine sherwood-025 +// Process started for vertices 23 (Merge__29) 24 (Apply__33) v.0 GUID {E945DC5D-9AF6-4732-8770-2A6BF7FA3041} machine sherwood-237 +vertexStartRegex = new Regex(@"\[(.*)\] Process started for vert(\w+) (.*) v\.(.*) GUID \{?([-A-F0-9]+)\}? machine (\S+)", + RegexOptions.Compiled); + +*/ + +void DrArtemisLegacyReporter::ReceiveMessage(DrProcessInfoRef info) +{ + DrString processName; + { + DrLockBoxKey process(info->m_process); + processName = process->GetName(); + } + + if (info->m_state->m_state == DPS_Starting) + { + PrintTimestamp(); + printf("Created process execution record for %s GUID {%s}\n", + processName.GetChars(), info->m_state->m_process->GetHandleIdAsString().GetChars()); + PrintTimestamp(); + printf("Process started for %s GUID {%s} machine %s\n", + processName.GetChars(), info->m_state->m_process->GetHandleIdAsString().GetChars(), + info->m_state->m_process->GetAssignedNode()->GetName().GetChars()); + fflush(stdout); + } +} + +/* + +// Vertex 5.0 (Super__0[1]) machine sherwood-025 guid {73EA55E0-0326-43C4-AD61-CB0B8CF8FE49} status Vertex Has Completed, +terminationRegex = new Regex(@"Vertex (\d+)\.(\d+) \((.+)\) machine (\S+) guid \{?([-0-9A-F]+)\}? 
status (.*)", + RegexOptions.Compiled); + +// Canceling vertex 1461.0 (Merge__13[258]) due to dependent failure +cancelRegex = new Regex(@"\[(.*)\] Canceling vertex (\d+)\.(\d+) \((.+)\) due to (.*)", RegexOptions.Compiled); + +// Process was terminated Vertex 11.0 (Select__6[1]) GUID {C1E35A88-F5AD-4A26-BE5F-46B6D515623F} machine sherwood-118 status The operation succeeded +terminatedRegex = new Regex(@"\[(.*)\] Process was terminated Vertex (\d+)\.(\d+) \((.+)\) GUID \{?([-A-F0-9]+)\}? machine (\S+) status (.*)", + RegexOptions.Compiled); + +// Process has failed Vertex 11.0 (Select__6[1]) GUID {C1E35A88-F5AD-4A26-BE5F-46B6D515623F} machine sherwood-118 status The operation succeeded +failedRegex = new Regex(@"\[(.*)\] Process has failed Vertex (\d+)\.(\d+) \((.+)\) GUID \{?([-A-F0-9]+)\}? machine (\S+) Exitcode (.*)", + RegexOptions.Compiled); + +// Timing Information 5 1 Super__0[1] 128654556602334453 0.0000 0.0000 0.0000 0.0000 0.2969 +timingInfoRegex = new Regex(@"Timing Information (\d+) (\d+) (.+) (\d+) ([-.0-9]+) ([-.0-9]+) ([-.0-9]+) ([-.0-9]+) ([-.0-9]+)", + RegexOptions.Compiled); + +*/ + +void DrArtemisLegacyReporter::ReceiveMessage(DrVertexInfoRef info) +{ + DrString processGuid = "(no guid)"; + DrString machineName = "(no computer)"; + + DrDateTime processSchedule = DrDateTime_Never; + DrDateTime processStartCreate = DrDateTime_Never; + DrDateTime processFinishCreate = DrDateTime_Never; + + if (info->m_process.IsEmpty() == false) + { + DrLockBoxKey process(info->m_process); + DrProcessHandlePtr handle = process->GetInfo()->m_state->m_process; + if (handle != DrNull) + { + processGuid = handle->GetHandleIdAsString(); + if (handle->GetAssignedNode() != DrNull) + { + machineName = handle->GetAssignedNode()->GetName(); + } + } + + processSchedule = process->GetInfo()->m_jmProcessScheduledTime; + processStartCreate = process->GetInfo()->m_statistics->m_createdTime; + processFinishCreate = process->GetInfo()->m_statistics->m_beginExecutionTime; + } + + if 
(info->m_state == DVS_Completed) + { + PrintTimestamp(); + printf("Vertex %d.%d (%s) machine %s guid {%s} status Vertex Has Completed\n", + info->m_info->GetProcessStatus()->GetVertexId(), + info->m_info->GetProcessStatus()->GetVertexInstanceVersion(), + info->m_name.GetChars(), machineName.GetChars(), processGuid.GetChars()); + } + else if (info->m_state == DVS_Failed) + { + if (info->m_statistics->m_exitStatus == DrError_CohortShutdown) + { + PrintTimestamp(); + printf("Canceling vertex %d.%d (%s) due to dependent failure\n", + info->m_info->GetProcessStatus()->GetVertexId(), + info->m_info->GetProcessStatus()->GetVertexInstanceVersion(), + info->m_name.GetChars()); + } + + if (info->m_statistics->m_exitStatus == DrError_Unexpected) + { + PrintTimestamp(); + printf("Process was terminated Vertex %d.%d (%s) GUID {%s} machine %s status The operation succeeded\n", + info->m_info->GetProcessStatus()->GetVertexId(), + info->m_info->GetProcessStatus()->GetVertexInstanceVersion(), + info->m_name.GetChars(), processGuid.GetChars(), machineName.GetChars()); + } + else + { + PrintTimestamp(); + printf("Process has failed Vertex %d.%d (%s) GUID {%s} machine %s Exitcode %x\n", + info->m_info->GetProcessStatus()->GetVertexId(), + info->m_info->GetProcessStatus()->GetVertexInstanceVersion(), + info->m_name.GetChars(), processGuid.GetChars(), machineName.GetChars(), + info->m_statistics->m_exitStatus); + } + } + + if (info->m_state == DVS_Completed || info->m_state == DVS_Failed) + { + if ((info->m_statistics != DrNull) + && (info->m_statistics->m_totalInputData != DrNull) + && (info->m_statistics->m_totalOutputData != DrNull)) + { + printf("Io information %d %d %s read %I64u wrote %I64u\n", + info->m_info->GetProcessStatus()->GetVertexId(), + info->m_info->GetProcessStatus()->GetVertexInstanceVersion(), + info->m_name.GetChars(), + info->m_statistics->m_totalInputData->m_dataRead, + info->m_statistics->m_totalOutputData->m_dataWritten); + + printf("Io locality information %d 
%d %s read %I64u ( %I64u local )\n", + info->m_info->GetProcessStatus()->GetVertexId(), + info->m_info->GetProcessStatus()->GetVertexInstanceVersion(), + info->m_name.GetChars(), + info->m_statistics->m_totalInputData->m_dataRead, + info->m_statistics->m_totalLocalInputData); + } + } + + if (info->m_state >= DVS_Completed) + { + DrVertexExecutionStatisticsPtr eStats = info->m_statistics; + DrDateTime execRunning = eStats->m_runningTime; + DrDateTime completion = eStats->m_completionTime; + + if ((processSchedule == DrDateTime_Never) || (processSchedule < eStats->m_creationTime)) + { + /* this vertex is part of a cohort, and was created after the + process had already started */ + processSchedule = eStats->m_creationTime; + } + + if ((processStartCreate == DrDateTime_Never) || (processStartCreate < processSchedule)) + { + /* this vertex is part of a cohort, and was created after the + process had already started, or clock skew means it appears to have started before it + was scheduled */ + processStartCreate = processSchedule; + } + + if ((processFinishCreate == DrDateTime_Never) || (processFinishCreate < processStartCreate)) + { + /* this vertex is part of a cohort, and was created after the + process creation had already completed */ + processFinishCreate = processStartCreate; + } + + if ((execRunning == DrDateTime_Never) || (execRunning < processFinishCreate)) + { + /* this vertex had never run */ + execRunning = processFinishCreate; + } + + if ((completion == DrDateTime_Never) || (completion < execRunning)) + { + /* clock skew??? 
*/ + completion = execRunning; + } + + double creatToScheduleTime = + (double) (processSchedule - eStats->m_creationTime) / (double) DrTimeInterval_Second; + double schedToStartProcessTime = + (double) (processStartCreate - processSchedule) / (double) DrTimeInterval_Second; + double pStartToCreatedProcessTime = + (double) (processFinishCreate - processStartCreate) / (double) DrTimeInterval_Second; + double cProcessToRunTime = + (double) (execRunning - processFinishCreate) / (double) DrTimeInterval_Second; + double runToCompTime = + (double) (completion - execRunning) / (double) DrTimeInterval_Second; + + // No need to print timestamp for timing report + printf("Timing Information %u %u %s %I64u %.4f %.4f %.4f %.4f %.4f\n", + info->m_info->GetProcessStatus()->GetVertexId(), + info->m_info->GetProcessStatus()->GetVertexInstanceVersion(), + info->m_name.GetChars(), + eStats->m_creationTime, + creatToScheduleTime, schedToStartProcessTime, pStartToCreatedProcessTime, + cProcessToRunTime, runToCompTime); + fflush(stdout); + } +} + +void DrArtemisLegacyReporter::ReportFinalTopology(DrVertexPtr vertex, DrResourcePtr runningMachine, + DrTimeInterval runningTime) +{ + DrString machineName = "nowhere"; + if (runningMachine != DrNull) + { + machineName = runningMachine->GetName(); + } + + // No need to print timestamp for topology reporting + printf("%s(%d),(", vertex->GetName().GetChars(), vertex->GetOutputs()->GetNumberOfEdges()); + + int i; + for (i=0; i < vertex->GetInputs()->GetNumberOfEdges(); ++i) + { + DrEdge e = vertex->GetInputs()->GetEdge(i); + + if (e.m_remoteVertex == DrNull) + { + printf("%sNULL", (i > 0) ? "*" : ""); + } + else + { + printf("%s%s.%d", (i > 0) ? 
"*" : "", e.m_remoteVertex->GetName().GetChars(), e.m_remotePort); + } + } + + printf("),(%s),%lfs\n", machineName.GetChars(), (double) runningTime / (double) DrTimeInterval_Second); + fflush(stdout); +} + + +void DrArtemisLegacyReporter::ReportStart(DrDateTime startTime) +{ + union { + FILETIME ft; + DrDateTime ts; + }; + ts = startTime; + SYSTEMTIME utc, local; + FileTimeToSystemTime(&ft, &utc); + SystemTimeToTzSpecificLocalTime(NULL, &utc, &local); + + printf("Start time: %02d/%02d/%04d %02d:%02d:%02d.%03u\n", + local.wMonth, + local.wDay, + local.wYear, + local.wHour, + local.wMinute, + local.wSecond, + local.wMilliseconds); + + fflush(stdout); +} + +void DrArtemisLegacyReporter::ReportStop(UINT exitCode) +{ + SYSTEMTIME utc, local; + FILETIME ft; + GetSystemTimeAsFileTime(&ft); + FileTimeToSystemTime(&ft, &utc); + SystemTimeToTzSpecificLocalTime(NULL, &utc, &local); + + printf("Stop time (Exit code = %u): %02d/%02d/%04d %02d:%02d:%02d.%03u\n", + exitCode, + local.wMonth, + local.wDay, + local.wYear, + local.wHour, + local.wMinute, + local.wSecond, + local.wMilliseconds + ); + + fflush(stdout); +} \ No newline at end of file diff --git a/GraphManager/reporting/DrArtemisLegacyReporting.h b/GraphManager/reporting/DrArtemisLegacyReporting.h new file mode 100644 index 0000000..c73cfd9 --- /dev/null +++ b/GraphManager/reporting/DrArtemisLegacyReporting.h @@ -0,0 +1,41 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +DRCLASS(DrArtemisLegacyReporter) : public DrCritSec, public DrProcessListener, public DrVertexListener, + public DrVertexTopologyReporter +{ +public: + /* the DrProcessListener implementation */ + virtual void ReceiveMessage(DrProcessInfoRef info); + + /* the DrVertexListener implementation */ + virtual void ReceiveMessage(DrVertexInfoRef info); + + /* the DrVertexTopologyReporter implementation */ + virtual void ReportFinalTopology(DrVertexPtr vertex, DrResourcePtr runningMachine, + DrTimeInterval runningTime); + + static void ReportStart(DrDateTime startTime); + + static void ReportStop(UINT exitCode); +}; +DRREF(DrArtemisLegacyReporter); \ No newline at end of file diff --git a/GraphManager/reporting/DrReporting.h b/GraphManager/reporting/DrReporting.h new file mode 100644 index 0000000..1ad1db7 --- /dev/null +++ b/GraphManager/reporting/DrReporting.h @@ -0,0 +1,25 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include + +#include diff --git a/GraphManager/shared/DrArray.h b/GraphManager/shared/DrArray.h new file mode 100644 index 0000000..14731ca --- /dev/null +++ b/GraphManager/shared/DrArray.h @@ -0,0 +1,170 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include "DrRef.h" + +#ifdef _MANAGED + +template DRBASECLASS(DrArray) +{ +public: + DrArray(int s) + { + m_array = DrNew array(s); + } + + T% operator[](int element) + { + return m_array[element]; + } + + int Allocated() + { + return m_array->Length; + } + + array^ GetArray() + { + return m_array; + } + +protected: + array^ m_array; +}; + +template DRBASECLASS(DrValueWrapperArray) +{ +public: + DrValueWrapperArray(int s) + { + m_array = DrNew array< DrValueWrapper >(s); + } + + T% operator[](int element) + { + return m_array[element].T(); + } + + int Allocated() + { + return m_array->Length; + } + + array< DrValueWrapper >^ GetArray() + { + return m_array; + } + +protected: + array< DrValueWrapper >^ m_array; +}; + +#define DRMAKEARRAY(T_) typedef DrArray T_##Array; \ + typedef T_##Array^ T_##ArrayRef; typedef T_##Array^ T_##ArrayPtr; typedef T_##Array% T_##ArrayR; \ + template ref class DrArray + +#define DRMAKEVWARRAY(T_) typedef DrValueWrapperArray T_##Array; \ + typedef T_##Array^ T_##ArrayRef; typedef T_##Array^ T_##ArrayPtr; typedef T_##Array% T_##ArrayR; \ + template ref class DrValueWrapperArray + +#else + +template DRBASECLASS(DrArray) +{ +public: + DrArray(int s) + { + m_allocated = s; + m_array = new T[m_allocated]; + +#ifdef _DEBUG_DRREF + EnterCriticalSection(&DrRefCounter::s_debugCS); + bool inserted = 
DrRefCounter::s_arrayStorage.insert(std::make_pair(m_array,this)).second; + DrAssert(inserted); + LeaveCriticalSection(&DrRefCounter::s_debugCS); +#endif + } + + ~DrArray() + { + delete [] m_array; + +#ifdef _DEBUG_DRREF + EnterCriticalSection(&DrRefCounter::s_debugCS); + size_t nRemoved = DrRefCounter::s_arrayStorage.erase(m_array); + DrAssert(nRemoved == 1); + LeaveCriticalSection(&DrRefCounter::s_debugCS); +#endif + } + + T& operator[](int element) + { + DrAssert(element < m_allocated); + return m_array[element]; + } + + int Allocated() + { + return m_allocated; + } + + T* GetPtr() + { + return m_array; + } + +protected: + int m_allocated; + T* m_array; +}; + +#ifdef _MANAGED +#define DRMAKEARRAY(T_) typedef DrArray T_##Array; \ + typedef DrArrayRef T_##ArrayRef; typedef T_##Array* T_##ArrayPtr; typedef T_##Array& T_##ArrayR; \ + template class DrArray +#else +#define DRMAKEARRAY(T_) typedef DrArray T_##Array; \ + typedef DrArrayRef T_##ArrayRef; typedef T_##Array* T_##ArrayPtr; typedef T_##Array& T_##ArrayR; +#endif + +#define DRMAKEVWARRAY(T_) typedef DrArray T_##Array; \ + typedef DrArrayRef T_##ArrayRef; typedef T_##Array* T_##ArrayPtr; typedef T_##Array& T_##ArrayR; \ + template class DrArray + +#endif + +typedef DrArray DrByteArray; +DRAREF(DrByteArray,BYTE); + +typedef DrArray DrIntArray; +DRAREF(DrIntArray,int); + +typedef DrArray DrUINT32Array; +DRAREF(DrUINT32Array,UINT32); + +typedef DrArray DrUINT64Array; +DRAREF(DrUINT64Array,UINT64); + +typedef DrArray DrFloatArray; +DRAREF(DrFloatArray,float); + +DRMAKEVWARRAY(DrString); diff --git a/GraphManager/shared/DrArrayList.h b/GraphManager/shared/DrArrayList.h new file mode 100644 index 0000000..f727221 --- /dev/null +++ b/GraphManager/shared/DrArrayList.h @@ -0,0 +1,469 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#ifdef _MANAGED + + +template DRBASECLASS(DrArrayList), public System::Collections::Generic::List +{ +public: + DrArrayList() : System::Collections::Generic::List() + { + } + + DrArrayList(int initialSize) : System::Collections::Generic::List(initialSize) + { + } + + int Allocated() + { + return Capacity; + } + + int Size() + { + return Count; + } +}; + +template DRBASECLASS(DrValueWrapperArrayList) +{ +public: + DrValueWrapperArrayList() + { + m_list = DrNew System::Collections::Generic::List< DrValueWrapper >(); + } + + DrValueWrapperArrayList(int initialSize) + { + m_list = DrNew System::Collections::Generic::List< DrValueWrapper >(initialSize); + } + + void Add(T t) + { + DrValueWrapper tt; + tt.T() = t; + m_list->Add(tt); + } + + void RemoveAt(int i) + { + m_list->RemoveAt(i); + } + + T% operator[](int index) + { + return m_list[index].T(); + } + + T% Get(int index) + { + return m_list[index].T(); + } + + int Allocated() + { + return m_list->Capacity; + } + + int Size() + { + return m_list->Count; + } + +private: + System::Collections::Generic::List< DrValueWrapper >^ m_list; +}; + +#define DRMAKEARRAYLIST(T_) typedef DrArrayList T_##List; \ + typedef T_##List^ T_##ListRef; typedef T_##List^ T_##ListPtr; typedef T_##List% T_##ListR; \ + template ref class DrArrayList + +#define DRMAKEVWARRAYLIST(T_) typedef DrValueWrapperArrayList T_##List; \ + typedef T_##List^ T_##ListRef; typedef T_##List^ T_##ListPtr; typedef T_##List% T_##ListR; \ + template ref class 
DrValueWrapperArrayList + +#else + +#include +#include +#include + +template DRBASECLASS(DrArrayList) +{ +public: + DrArrayList() + { + Initialize(1); + } + + DrArrayList(int initialSize) + { + Initialize(initialSize); + } + + ~DrArrayList() + { +#ifdef _DEBUG_DRREF + EnterCriticalSection(&DrRefCounter::s_debugCS); + if (m_vector.size() > 0) + { + size_t nRemoved = DrRefCounter::s_arrayStorage.erase(&(m_vector[0])); + DrAssert(nRemoved == 1); + } + LeaveCriticalSection(&DrRefCounter::s_debugCS); +#endif + } + + void Add(T t) + { +#ifdef _DEBUG_DRREF + EnterCriticalSection(&DrRefCounter::s_debugCS); + if (m_vector.size() > 0) + { + size_t nRemoved = DrRefCounter::s_arrayStorage.erase(&(m_vector[0])); + DrAssert(nRemoved == 1); + } + LeaveCriticalSection(&DrRefCounter::s_debugCS); +#endif + + m_vector.push_back(t); + +#ifdef _DEBUG_DRREF + EnterCriticalSection(&DrRefCounter::s_debugCS); + bool inserted = DrRefCounter::s_arrayStorage.insert(std::make_pair(&(m_vector[0]),this)).second; + DrAssert(inserted); + LeaveCriticalSection(&DrRefCounter::s_debugCS); +#endif + } + + void RemoveAt(int element) + { +#ifdef _DEBUG_DRREF + EnterCriticalSection(&DrRefCounter::s_debugCS); + size_t nRemoved = DrRefCounter::s_arrayStorage.erase(&(m_vector[0])); + DrAssert(nRemoved == 1); + LeaveCriticalSection(&DrRefCounter::s_debugCS); +#endif + + m_vector.erase(m_vector.begin() + element); + +#ifdef _DEBUG_DRREF + EnterCriticalSection(&DrRefCounter::s_debugCS); + if (m_vector.size() > 0) + { + bool inserted = DrRefCounter::s_arrayStorage.insert(std::make_pair(&(m_vector[0]),this)).second; + DrAssert(inserted); + } + LeaveCriticalSection(&DrRefCounter::s_debugCS); +#endif + } + + bool Remove(T t) + { + std::vector::iterator i; + for (i=m_vector.begin(); i!=m_vector.end(); ++i) + { + if (*i == t) + { + m_vector.erase(i); + return true; + } + } + return false; + } + + T& operator[](int element) + { + return m_vector[element]; + } + + T& Get(int element) + { + return m_vector[element]; + } 
+ + + int Size() + { + return (int) m_vector.size(); + } + + int Allocated() + { + return (int) m_vector.capacity(); + } + + void Sort(DrComparer* comparer) + { + Comparer c(comparer); + std::sort(m_vector.begin(), m_vector.end(), c); + } + +private: + class Comparer : std::binary_function + { + public: + Comparer(DrComparer* comparer) + { + m_comparer = comparer; + } + + bool operator() (T& a, T& b) + { + return (m_comparer->Compare(a, b) < 0); + } + + private: + DrComparer* m_comparer; + }; + + void Initialize(int initialSize) + { + if (initialSize == 0) + { + initialSize = 1; + } + m_vector.reserve(initialSize); +#ifdef _DEBUG_DRREF + DrAssert(m_vector.size() == 0); +#endif + } + +protected: + std::vector m_vector; +}; + +#if 0 +template DRBASECLASS(DrArrayList) +{ +public: + DrArrayList() + { + Initialize(1); + } + + DrArrayList(int initialSize) + { + Initialize(initialSize); + } + + ~DrArrayList() + { + delete [] m_array; + } + + void Add(T t) + { + if (m_used == m_allocated) + { + m_allocated *= 2; + T* newArray = new T[m_allocated]; + + int i; + for (i=0; i* comparer) + { + ::qsort_s(m_array, m_used, sizeof(T), DrComparer::CompareUntyped, comparer); + } + +private: + void Initialize(int initialSize) + { + if (initialSize == 0) + { + initialSize = 1; + } + m_allocated = initialSize; + m_used = 0; + m_array = new T[m_allocated]; + } + +protected: + int m_allocated; + int m_used; + T* m_array; +}; + +template DRBASECLASS(DrArrayList) +{ +public: + DrArrayList() + { + Initialize(1); + } + + DrArrayList(int initialSize) + { + Initialize(initialSize); + } + + ~DrArrayList() + { + Clear(); + delete [] m_array; + } + + void Add(T t) + { + if (m_used == m_allocated) + { + m_allocated *= 2; + unsigned char* newArray = new unsigned char[m_allocated * (int) sizeof(T)]; + + memcpy(newArray, m_array, m_used * (int) sizeof(T)); + + delete [] m_array; + m_array = newArray; + } + + DrAssert(m_used < m_allocated); + void* insertLocation = &(m_array[m_used * (int) sizeof(T)]); 
+ ::new (insertLocation) T(t); + + ++m_used; + } + + void RemoveAt(int element) + { + DrAssert(element < m_used); + int i; + for (i=element+1; i 0) + { + RemoveAt(m_used - 1); + } + } + + int Size() + { + return m_used; + } + + void Sort(DrComparer* comparer) + { + ::qsort_s(m_array, m_used, sizeof(T), DrComparer::CompareUntyped, comparer); + } + +private: + void Initialize(int initialSize) + { + if (initialSize == 0) + { + initialSize = 1; + } + m_allocated = initialSize; + m_used = 0; + m_array = new unsigned char[m_allocated * (int) sizeof(T)]; + } + +protected: + int m_allocated; + int m_used; + unsigned char* m_array; +}; +#endif + +#define DRMAKEARRAYLIST(T_) typedef DrArrayList T_##List; \ + typedef DrArrayRef T_##ListRef; typedef T_##List* T_##ListPtr; typedef T_##List& T_##ListR; \ + template class DrArrayList + +#define DRMAKEVWARRAYLIST(T_) typedef DrArrayList T_##List; \ + typedef DrArrayRef T_##ListRef; typedef T_##List* T_##ListPtr; typedef T_##List& T_##ListR; \ + template class DrArrayList + +#endif + +typedef DrArrayList DrByteArrayList; +DRAREF(DrByteArrayList,BYTE); +template DRDECLARECLASS(DrArrayList); + +typedef DrArrayList DrIntArrayList; +DRAREF(DrIntArrayList,int); +template DRDECLARECLASS(DrArrayList); + +DRMAKEVWARRAYLIST(DrString); diff --git a/GraphManager/shared/DrAssert.h b/GraphManager/shared/DrAssert.h new file mode 100644 index 0000000..2e14e87 --- /dev/null +++ b/GraphManager/shared/DrAssert.h @@ -0,0 +1,23 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#define DrAssert(_c) if (!(_c)) DrLogA("Assertion failed: %s", #_c) diff --git a/GraphManager/shared/DrCritSec.h b/GraphManager/shared/DrCritSec.h new file mode 100644 index 0000000..616642a --- /dev/null +++ b/GraphManager/shared/DrCritSec.h @@ -0,0 +1,379 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include "DrRef.h" + +DRDECLARECLASS(DrCritSec); +DRREF(DrCritSec); + +DRINTERFACE(DrICritSec) +{ +public: + virtual void Enter() DRABSTRACT; + virtual void Leave() DRABSTRACT; + virtual DrCritSecPtr GetBaseLock() DRABSTRACT; +}; +DRIREF(DrICritSec); + +#ifdef _MANAGED + +DRBASECLASS(DrCritSec), public DrICritSec +{ +public: + virtual void Enter() + { + System::Threading::Monitor::Enter(this); + } + + virtual void Leave() + { + System::Threading::Monitor::Exit(this); + } + + virtual DrCritSecPtr GetBaseLock() + { + return this; + } +}; + +ref class DrAutoCriticalSection +{ +public: + DrAutoCriticalSection(DrICritSecPtr critSec) + { + m_critSec = critSec; + m_critSec->Enter(); + } + ~DrAutoCriticalSection() + { + m_critSec->Leave(); + } + +private: + DrICritSecIRef m_critSec; +}; + +template ref class DrLockBox +{ +public: + DrLockBox() + { + } + + DrLockBox(T^ t) + { + m_obj = t; + } + + DrLockBox(DrLockBox% t) + { + m_obj = t.m_obj; + } + + DrLockBox% operator=(T^ t) + { + return Set(t); + } + + DrLockBox% operator=(DrLockBox% t) + { + return Set(t.m_obj); + } + + bool operator== (DrLockBox% t) + { + return (m_obj == t.m_obj); + } + + bool IsNull() + { + return (m_obj == DrNull); + } + + DrLockBox% Set(T^ t) + { + m_obj = t; + return *this; + } + + operator DrLockBox() + { + return DrLockBox(m_obj); + } + + bool IsEmpty() + { + return (m_obj == nullptr); + } + + /* this is only to be used by the DrLockBoxKey class, and if there were friend classes + in managed code this would be private */ + T^ DoNotUse() + { + return dynamic_cast(m_obj); + } + +private: + T^ m_obj; +}; + +template ref class DrLockBoxKey +{ +public: + DrLockBoxKey(DrLockBox box) + { + p = box.DoNotUse(); + this->Enter(); + } + + ~DrLockBoxKey() + { + this->Leave(); + } + + operator T^() + { + return p; + } + + T^ operator->() + { + return p; + } + + bool operator!() + { + return (p == nullptr); + } + + bool operator==(T^ pT) + { + return (p == pT); + } + + bool 
operator!=(T^ pT) + { + return (p != pT); + } + +private: + T^ p; +}; + +#else + +//#include + +DRBASECLASS(DrCritSec), public DrICritSec +{ +public: + DrCritSec() + { + InitializeCriticalSection(&m_critsec); + } + + ~DrCritSec() + { + DeleteCriticalSection(&m_critsec); + } + + virtual void Enter() + { +// printf("thread %u acquiring lock %p\n", GetCurrentThreadId(), this); + EnterCriticalSection(&m_critsec); +// printf("thread %u acquired lock %p\n", GetCurrentThreadId(), this); + } + + virtual void Leave() + { + LeaveCriticalSection(&m_critsec); +// printf("thread %u released lock %p\n", GetCurrentThreadId(), this); + } + + virtual DrCritSecPtr GetBaseLock() + { + return this; + } + +private: + CRITICAL_SECTION m_critsec; +}; +DRREF(DrCritSec); + +class DrAutoCriticalSection +{ +public: + DrAutoCriticalSection(DrICritSecPtr critSec) + { + m_critSec = critSec; + m_critSec->Enter(); + } + ~DrAutoCriticalSection() + { + m_critSec->Leave(); + } + +private: + /* ensure this can't be allocated on the heap: stack only! 
*/ + void* operator new(size_t); + + DrICritSecIRef m_critSec; +}; + +template class DrLockBox +{ +public: + DrLockBox() + { + } + + DrLockBox(const DrLockBox& t) + { + m_obj = t.m_obj; + } + + DrLockBox(T* t) + { + m_obj = t; + } + + DrLockBox& operator=(T* t) + { + return Set(t); + } + + DrLockBox& operator=(const DrLockBox& t) + { + return Set((T*) (t.m_obj)); + } + + bool operator== (const DrLockBox& t) + { + return (m_obj == t.m_obj); + } + + bool IsNull() + { + return (m_obj == DrNull); + } + + DrLockBox& Set(T* t) + { + m_obj = t; + return *this; + } + + operator DrLockBox() + { + return DrLockBox(m_obj); + } + + bool IsEmpty() + { + return (m_obj == NULL); + } + + /* this is only to be used by the DrLockBoxKey class, and if there were friend classes + in managed code this would be private */ + T* DoNotUse() + { + return dynamic_cast((DrICritSecPtr) m_obj); + } + +private: + DrInterfaceRef m_obj; +}; + +template class DrLockBoxKey +{ +public: + DrLockBoxKey(DrLockBox box) + { + p = box.DoNotUse(); + p->Enter(); + } + ~DrLockBoxKey() + { + p->Leave(); + } + + operator T*() + { + return p; + } + + T* operator->() + { + return p; + } + + bool operator!() + { + return (p == NULL); + } + + bool operator==(T* pT) + { + return (p == pT); + } + + bool operator!=(T* pT) + { + return (p != pT); + } + +private: + /* ensure this can't be allocated on the heap: stack only! 
*/ + void* operator new(size_t); + + DrInterfaceRef p; +}; + +#endif + +DRBASECLASS(DrSharedCritSec), public DrICritSec +{ +public: + DrSharedCritSec(DrICritSecPtr parent) + { + m_cs = parent; + } + + virtual void Enter() + { + m_cs->Enter(); + } + + virtual void Leave() + { + m_cs->Leave(); + } + + virtual DrCritSecPtr GetBaseLock() + { + return m_cs->GetBaseLock(); + } + +private: + DrICritSecIRef m_cs; +}; +DRREF(DrSharedCritSec); diff --git a/GraphManager/shared/DrDictionary.h b/GraphManager/shared/DrDictionary.h new file mode 100644 index 0000000..b3116d2 --- /dev/null +++ b/GraphManager/shared/DrDictionary.h @@ -0,0 +1,451 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#ifdef _MANAGED + +template DRBASECLASS(DrDictionary), public System::Collections::Generic::Dictionary +{ +public: + value class DrEnumerator + { + public: + DrEnumerator(Dictionary^ dict) + { + m_iterator = dict->GetEnumerator(); + } + + K GetKey() + { + return m_iterator.Current.Key; + } + + V GetValue() + { + return m_iterator.Current.Value; + } + + bool MoveNext() + { + return m_iterator.MoveNext(); + } + + private: + Dictionary::Enumerator m_iterator; + }; + + DrEnumerator GetDrEnumerator() + { + return DrEnumerator(this); + } + + int GetSize() + { + return Count; + } + + void Replace(K key, V value) + { + bool removed = Remove(key); + DrAssert(removed); + Add(key, value); + } +}; + +template DRCLASS(DrStringDictionary) : public DrDictionary +{ +}; + +#if 0 +/* unfortunately this crashes the compiler with an internal error */ +template DRCLASS(DrDictionaryForValueWrapper) + : public System::Collections::Generic::Dictionary< K, DrValueWrapper > +{ +public: + value class DrEnumerator + { + public: + DrEnumerator(Dictionary< K, DrValueWrapper >^ dict) + { + m_iterator = dict->GetEnumerator(); + } + + K GetKey() + { + return m_iterator.Current.Key; + } + + V GetValue() + { + return m_iterator.Current.Value.T(); + } + + bool MoveNext() + { + return m_iterator.MoveNext(); + } + + private: + Dictionary< K, DrValueWrapper >::Enumerator m_iterator; + }; + + DrEnumerator GetDrEnumerator() + { + return DrEnumerator(this); + } + + int GetSize() + { + return Count; + } + + void Add(K key, V value) + { + DrValueWrapper v; + v.T() = value; + Add(key, v); + } + + bool TryGetValue(K key, /*out*/ V% value) + { + DrValueWrapper v; + bool found = TryGetValue(key, v); + if (found) + { + value = v.T(); + } + return found; + } + + void Replace(K key, V value) + { + bool removed = Remove(key); + DrAssert(removed); + Add(key, value); + } +}; + +typedef DrDictionaryForValueWrapper DrStringStringDictionary; +DRREF(DrStringStringDictionary); +#endif + +template 
DRCLASS(DrDictionaryForString) + : public System::Collections::Generic::Dictionary< K, DrValueWrapper > +{ +public: + value class DrEnumerator + { + public: + DrEnumerator(Dictionary< K, DrValueWrapper >^ dict) + { + m_iterator = dict->GetEnumerator(); + } + + K GetKey() + { + return m_iterator.Current.Key; + } + + DrString GetValue() + { + return m_iterator.Current.Value.T(); + } + + bool MoveNext() + { + return m_iterator.MoveNext(); + } + + private: + Dictionary< K, DrValueWrapper >::Enumerator m_iterator; + }; + + DrEnumerator GetDrEnumerator() + { + return DrEnumerator(this); + } + + int GetSize() + { + return Count; + } + + void Add(K key, DrString value) + { + DrValueWrapper v; + v.T() = value; + Add(key, v); + } + + bool TryGetValue(K key, /*out*/ DrString% value) + { + DrValueWrapper v; + bool found = TryGetValue(key, v); + if (found) + { + value = v.T(); + } + return found; + } + + void Replace(K key, DrString value) + { + bool removed = Remove(key); + DrAssert(removed); + Add(key, value); + } +}; + +typedef DrDictionaryForString DrStringStringDictionary; +DRREF(DrStringStringDictionary); +template ref class DrDictionaryForString; + + +#else + +#include +#include + +template DRBASECLASS(DrDictionary) +{ + typedef std::map Map; + +public: + class DrEnumerator + { + public: + DrEnumerator(Map* m) + { + m_iterator = m->begin(); + m_end = m->end(); + m_moved = false; + } + + const K& GetKey() + { + DrAssert(m_moved); + return m_iterator->first; + } + + V& GetValue() + { + DrAssert(m_moved); + return m_iterator->second; + } + + bool MoveNext() + { + if (m_moved == false) + { + m_moved = true; + } + else if (m_iterator != m_end) + { + ++m_iterator; + } + + if (m_iterator == m_end) + { + return false; + } + else + { + return true; + } + } + + private: + typename Map::iterator m_iterator; + typename Map::iterator m_end; + bool m_moved; + }; + + void Add(K key, V value) + { + bool inserted = + m_map.insert(std::make_pair(key, value)).second; + DrAssert(inserted); 
+ } + + bool TryGetValue(K key, /*out*/ V& value) + { + Map::const_iterator i = m_map.find(key); + if (i == m_map.end()) + { + return false; + } + else + { + value = i->second; + return true; + } + } + + bool Remove(K key) + { + size_t nRemoved = m_map.erase(key); + if (nRemoved == 1) + { + return true; + } + else + { + DrAssert(nRemoved == 0); + return false; + } + } + + void Replace(K key, V value) + { + Map::iterator i = m_map.find(key); + DrAssert(i != m_map.end()); + i->second = value; + } + + DrEnumerator GetDrEnumerator() + { + return DrEnumerator(&m_map); + } + + int GetSize() + { + return (int) m_map.size(); + } + +protected: + Map m_map; +}; + +template DRBASECLASS(DrStringDictionary) +{ + typedef std::map Map; + +public: + class DrEnumerator + { + public: + DrEnumerator(Map* m) + { + m_iterator = m->begin(); + m_end = m->end(); + m_moved = false; + } + + const char* GetKey() + { + DrAssert(m_moved); + return m_iterator->first.c_str(); + } + + V& GetValue() + { + DrAssert(m_moved); + return m_iterator->second; + } + + bool MoveNext() + { + if (m_moved == false) + { + m_moved = true; + } + else if (m_iterator != m_end) + { + ++m_iterator; + } + + if (m_iterator == m_end) + { + return false; + } + else + { + return true; + } + } + + private: + typename Map::iterator m_iterator; + typename Map::iterator m_end; + bool m_moved; + }; + + void Add(const char* key, V value) + { + bool inserted = + m_map.insert(std::make_pair(std::string(key), value)).second; + DrAssert(inserted); + } + + bool TryGetValue(const char* key, /*out*/ V& value) + { + Map::const_iterator i = m_map.find(std::string(key)); + if (i == m_map.end()) + { + return false; + } + else + { + value = i->second; + return true; + } + } + + bool Remove(const char* key) + { + size_t nRemoved = m_map.erase(std::string(key)); + if (nRemoved == 1) + { + return true; + } + else + { + DrAssert(nRemoved == 0); + return false; + } + } + + DrEnumerator GetDrEnumerator() + { + return DrEnumerator(&m_map); + } + 
+ int GetSize() + { + return (int) m_map.size(); + } + +protected: + Map m_map; +}; + +template DRCLASS(DrDictionaryForValueWrapper) : public DrDictionary +{ +}; + +template DRCLASS(DrDictionaryForString) : public DrDictionary +{ +}; + +typedef DrStringDictionary DrStringStringDictionary; +DRREF(DrStringStringDictionary); + +#endif diff --git a/GraphManager/shared/DrError.cpp b/GraphManager/shared/DrError.cpp new file mode 100644 index 0000000..583aa48 --- /dev/null +++ b/GraphManager/shared/DrError.cpp @@ -0,0 +1,72 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include + +DrError::DrError(HRESULT code, DrNativeString component, DrString explanation) +{ + m_code = code; + m_component = component; + m_explanation = explanation; +} + +DrError::DrError(HRESULT code, DrString component, DrString explanation) +{ + m_code = code; + m_component = component; + m_explanation = explanation; +} + +DrString DrError::ToShortText() +{ + DrString s; + if (m_explanation.GetString() == DrNull) + { + s.SetF("%s:%x:%s", m_component.GetChars(), m_code, DRERRORSTRING(m_code)); + } + else + { + s.SetF("%s:%x:%s. 
%s", m_component.GetChars(), m_code, DRERRORSTRING(m_code), m_explanation.GetChars()); + } + return s; +} + +void DrError::AddProvenance(DrErrorPtr previousError) +{ + if (m_errorProvenance == DrNull) + { + m_errorProvenance = DrNew DrErrorList(); + } + m_errorProvenance->Add(previousError); +} + +DrString DrError::ToShortText(DrErrorPtr errorOrNull) +{ + DrString s; + if (errorOrNull == DrNull) + { + s = "No error"; + } + else + { + s = errorOrNull->ToShortText(); + } + return s; +} diff --git a/GraphManager/shared/DrError.h b/GraphManager/shared/DrError.h new file mode 100644 index 0000000..e0d015f --- /dev/null +++ b/GraphManager/shared/DrError.h @@ -0,0 +1,104 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +DRDECLARECLASS(DrError); +DRREF(DrError); + +typedef DrArrayList DrErrorList; +DRAREF(DrErrorList,DrErrorRef); + +DRBASECLASS(DrError) +{ +public: + DrError(HRESULT code, DrNativeString component, DrString explanation); + DrError(HRESULT code, DrString component, DrString explanation); + + void AddProvenance(DrErrorPtr previousError); + + DrString ToShortText(); + static DrString ToShortText(DrErrorPtr errorOrNull); + + HRESULT m_code; + DrString m_component; + DrString m_explanation; + + DrErrorListRef m_errorProvenance; +}; +DRREF(DrError); + +#define FACILITY_COSMOS 777 +#define FACILITY_DRYAD 778 +#define FACILITY_DSC 779 + +#define DRYAD_ERROR(n) ((HRESULT)(0x80000000 + (FACILITY_DRYAD << 16) + n)) +#define COSMOS_ERROR(n) ((HRESULT)(0x80000000 + (FACILITY_COSMOS << 16) + n)) +#define DSC_ERROR(n) ((HRESULT)(0x80000000 + (FACILITY_DSC << 16) + n)) + +const HRESULT DrError_BadMetaData = DRYAD_ERROR (0x0001); +const HRESULT DrError_InvalidCommand = DRYAD_ERROR (0x0002); +const HRESULT DrError_VertexReceivedTermination = DRYAD_ERROR (0x0003); +const HRESULT DrError_InvalidChannelURI = DRYAD_ERROR (0x0004); +const HRESULT DrError_ChannelOpenError = DRYAD_ERROR (0x0005); +const HRESULT DrError_ChannelRestartError = DRYAD_ERROR (0x0006); +const HRESULT DrError_ChannelWriteError = DRYAD_ERROR (0x0007); +const HRESULT DrError_ChannelReadError = DRYAD_ERROR (0x0008); +const HRESULT DrError_ItemParseError = DRYAD_ERROR (0x0009); +const HRESULT DrError_ItemMarshalError = DRYAD_ERROR (0x0010); +const HRESULT DrError_BufferHole = DRYAD_ERROR (0x0011); +const HRESULT DrError_ItemHole = DRYAD_ERROR (0x0012); +const HRESULT DrError_ChannelRestart = DRYAD_ERROR (0x0013); +const HRESULT DrError_ChannelAbort = DRYAD_ERROR (0x0014); +const HRESULT DrError_VertexRunning = DRYAD_ERROR (0x0015); +const HRESULT DrError_VertexCompleted = DRYAD_ERROR (0x0016); +const HRESULT DrError_VertexError = DRYAD_ERROR (0x0017); +const HRESULT DrError_ProcessingError = 
DRYAD_ERROR (0x0018); +const HRESULT DrError_VertexInitialization = DRYAD_ERROR (0x0019); +const HRESULT DrError_ProcessingInterrupted = DRYAD_ERROR (0x001a); +const HRESULT DrError_VertexChannelClose = DRYAD_ERROR (0x001b); +const HRESULT DrError_AssertFailure = DRYAD_ERROR (0x001c); +const HRESULT DrError_ExternalChannel = DRYAD_ERROR (0x001d); +const HRESULT DrError_AlreadyInitialized = DRYAD_ERROR (0x001e); +const HRESULT DrError_DuplicateVertices = DRYAD_ERROR (0x001f); +const HRESULT DrError_ComposeRHSNeedsInput = DRYAD_ERROR (0x0020); +const HRESULT DrError_ComposeLHSNeedsOutput = DRYAD_ERROR (0x0021); +const HRESULT DrError_ComposeStagesMustBeDifferent = DRYAD_ERROR (0x0022); +const HRESULT DrError_ComposeStageEmpty = DRYAD_ERROR (0x0023); +const HRESULT DrError_VertexNotInGraph = DRYAD_ERROR (0x0024); +const HRESULT DrError_HardConstraintCannotBeMet = DRYAD_ERROR (0x0025); +const HRESULT DrError_XComputeError = DRYAD_ERROR (0x0026); +const HRESULT DrError_CohortShutdown = DRYAD_ERROR (0x0027); +const HRESULT DrError_Unexpected = DRYAD_ERROR (0x0028); +const HRESULT DrError_DependentVertexFailure = DRYAD_ERROR (0x0029); +const HRESULT DrError_BadOutputReported = DRYAD_ERROR (0x002a); +const HRESULT DrError_InputUnavailable = DRYAD_ERROR (0x002b); + +const HRESULT DrError_EndOfStream = COSMOS_ERROR (0x000b); + +const HRESULT DrError_CannotConnectToDsc = DSC_ERROR (0x0100); +const HRESULT DrError_DscOperationFailed = DSC_ERROR (0x0101); +const HRESULT DrError_FailedToDeleteFileset = DSC_ERROR (0x0102); +const HRESULT DrError_FailedToCreateFileset = DSC_ERROR (0x0103); +const HRESULT DrError_FailedToAddFile = DSC_ERROR (0x0104); +const HRESULT DrError_FailedToSetMetadata = DSC_ERROR (0x0105); +const HRESULT DrError_FailedToSealFileset = DSC_ERROR (0x0106); +const HRESULT DrError_FailedToSetLease = DSC_ERROR (0x0107); +const HRESULT DrError_FailedToOpenFileset = DSC_ERROR (0x0108); diff --git a/GraphManager/shared/DrErrorInternal.h 
b/GraphManager/shared/DrErrorInternal.h new file mode 100644 index 0000000..57eb895 --- /dev/null +++ b/GraphManager/shared/DrErrorInternal.h @@ -0,0 +1,35 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +typedef DrDictionaryForString DrErrorDictionary; +DRREF(DrErrorDictionary); + +DRCLASS(DrErrorText) +{ +public: + static void Initialize(); + static void Discard(); + static DrString GetErrorText(HRESULT err); + +private: + static DrErrorDictionaryRef s_dictionary; +}; diff --git a/GraphManager/shared/DrFileWriter.cpp b/GraphManager/shared/DrFileWriter.cpp new file mode 100644 index 0000000..bb4f55e --- /dev/null +++ b/GraphManager/shared/DrFileWriter.cpp @@ -0,0 +1,208 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include + + +DrFileWriter::DrFileWriter() +{ + m_fileHandle = INVALID_HANDLE_VALUE; + m_dataBufferSize = 64 * 1024; + m_bufferedData = new char[m_dataBufferSize]; + m_bufferedDataLength = 0; + +} + +#ifdef _MANAGED +DrFileWriter::~DrFileWriter() +{ + this->!DrFileWriter(); +} + +DrFileWriter::!DrFileWriter() +{ + if (m_fileHandle != INVALID_HANDLE_VALUE) + { + FlushInternal(); + CloseHandle(m_fileHandle); + } + delete [] m_bufferedData; + +} +#else +DrFileWriter::~DrFileWriter() +{ + if (m_fileHandle != INVALID_HANDLE_VALUE) + { + FlushInternal(); + CloseHandle(m_fileHandle); + } + delete [] m_bufferedData; + +} +#endif + +bool DrFileWriter::ReOpen(DrString fileName) +{ + DrAutoCriticalSection acs(this); + + DrAssert(m_fileHandle == INVALID_HANDLE_VALUE); + + m_fileHandle = CreateFileA( + fileName.GetChars(), + GENERIC_WRITE, + FILE_SHARE_READ, + NULL, + OPEN_ALWAYS, + FILE_ATTRIBUTE_NORMAL, + NULL); + + if (m_fileHandle == INVALID_HANDLE_VALUE) + { + DrLogW("File open failed for %s with error %s", fileName.GetChars(), + DRERRORSTRING(HRESULT_FROM_WIN32(GetLastError()))); + return false; + } + + return true; +} + +bool DrFileWriter::Open(DrString fileName) +{ + DrAutoCriticalSection acs(this); + + DrAssert(m_fileHandle == INVALID_HANDLE_VALUE); + + m_fileHandle = CreateFileA( + fileName.GetChars(), + GENERIC_WRITE, + FILE_SHARE_READ, + NULL, + CREATE_ALWAYS, + FILE_ATTRIBUTE_NORMAL, + NULL); + + if (m_fileHandle == INVALID_HANDLE_VALUE) + { + DrLogW("File open failed for %s with error %s", fileName.GetChars(), + DRERRORSTRING(HRESULT_FROM_WIN32(GetLastError()))); + return false; + } + + return true; +} + +void DrFileWriter::FlushInternal() +{ + if (m_fileHandle != INVALID_HANDLE_VALUE && m_bufferedDataLength > 0) + { + DWORD bytesWritten; + WriteFile(m_fileHandle, m_bufferedData, m_bufferedDataLength, &bytesWritten, NULL); + m_bufferedDataLength = 0; + } +} + +void DrFileWriter::Flush() +{ + DrAutoCriticalSection acs(this); + + FlushInternal(); +} + +void 
DrFileWriter::Close() +{ + DrAutoCriticalSection acs(this); + + if (m_fileHandle != INVALID_HANDLE_VALUE) + { + FlushInternal(); + CloseHandle(m_fileHandle); + m_fileHandle = INVALID_HANDLE_VALUE; + } +} + +void DrFileWriter::Append(const char *data, int dataLength) +{ + DrAutoCriticalSection acs(this); + + if (m_fileHandle != INVALID_HANDLE_VALUE) + { + if (m_bufferedDataLength + dataLength > m_dataBufferSize) + { + FlushInternal(); + } + + if (m_bufferedDataLength + dataLength > m_dataBufferSize) + { + DrAssert(m_bufferedDataLength == 0); + m_dataBufferSize = dataLength; + delete [] m_bufferedData; + m_bufferedData = new char[m_dataBufferSize]; + } + + memcpy(m_bufferedData + m_bufferedDataLength, data, dataLength); + m_bufferedDataLength += dataLength; + } +} + +typedef DrArrayList DrFWArray; +DRAREF(DrFWArray,DrFileWriterRef); + +DRCLASS(DrSFWInternal) +{ +public: + static DrCritSecRef s_cs; + static DrFWArrayRef s_file; +}; + +#ifndef _MANAGED +DrCritSecRef DrSFWInternal::s_cs; +DrFWArrayRef DrSFWInternal::s_file; +#endif + +void DrStaticFileWriters::Initialize() +{ + DrSFWInternal::s_cs = DrNew DrCritSec(); + DrSFWInternal::s_file = DrNew DrFWArray(); +} + +void DrStaticFileWriters::AddWriter(DrFileWriterPtr writer) +{ + DrAutoCriticalSection acs(DrSFWInternal::s_cs); + + DrSFWInternal::s_file->Add(writer); +} + +void DrStaticFileWriters::FlushWriters() +{ + DrAutoCriticalSection acs(DrSFWInternal::s_cs); + + int i; + for (i=0; iSize(); ++i) + { + DrSFWInternal::s_file[i]->Flush(); + } +} + +void DrStaticFileWriters::Discard() +{ + DrSFWInternal::s_cs = DrNull; + DrSFWInternal::s_file = DrNull; +} \ No newline at end of file diff --git a/GraphManager/shared/DrFileWriter.h b/GraphManager/shared/DrFileWriter.h new file mode 100644 index 0000000..c537792 --- /dev/null +++ b/GraphManager/shared/DrFileWriter.h @@ -0,0 +1,62 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +/* when there's time, I will write a high-performance file writer so we can get decent logging performance, + but for now we are using regular WriteFile */ + +#include + +DRCLASS(DrFileWriter) : public DrCritSec +{ +public: + DrFileWriter(); + ~DrFileWriter(); +#ifdef _MANAGED + !DrFileWriter(); +#endif + + bool Open(DrString fileName); + bool ReOpen(DrString fileName); + void Flush(); + void Close(); + + void Append(const char* data, int dataLength); + +private: + void FlushInternal(); + + HANDLE m_fileHandle; + int m_dataBufferSize; + char* m_bufferedData; + int m_bufferedDataLength; + +}; +DRREF(DrFileWriter); + +DRCLASS(DrStaticFileWriters) +{ +public: + static void Initialize(); + static void AddWriter(DrFileWriterPtr writer); + static void FlushWriters(); + static void Discard(); +}; \ No newline at end of file diff --git a/GraphManager/shared/DrLogging.cpp b/GraphManager/shared/DrLogging.cpp new file mode 100644 index 0000000..176f70d --- /dev/null +++ b/GraphManager/shared/DrLogging.cpp @@ -0,0 +1,763 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "DrShared.h" + +#include "DrErrorInternal.h" + +#include + +#ifdef _MANAGED +QueryAssertException::QueryAssertException(System::String ^message, System::Diagnostics::StackTrace ^st) : System::Exception(message) +{ + HResult = DrError_AssertFailure; + m_stackTrace = st; +} +#endif + +/* there are race conditions in the accessor functions for these global bools, however +we are ignoring them since their effects will merely be to change the number of log +statements printed if the logging level is dynamically changed during a program run, +which is benign compared to the overhead of acquiring a lock while testing for logging +eing enabled */ + +static DrLogType s_loggingType = DrLog_Warning; + +void DrLogging::SetLoggingLevel(DrLogType type) +{ + s_loggingType = type; +} + +bool DrLogging::Enabled(DrLogType type) +{ + return ((s_loggingType & type) == type); +} + +DRCLASS(DrLoggingInternal) +{ +public: + static void Initialize(); + static void FlushLogs(); + DrLoggingInternal(); + +#ifdef _MANAGED + static __declspec(noreturn) void Terminate(UINT exitCode, QueryAssertException ^e); +#else + static __declspec(noreturn) void Terminate(UINT exitCode); +#endif + static void Stop(); + + static void Append(const char* data, int dataLength); + + static void RolloverLog(); + +private: + + static void FlushThread(); + static void ForkFlushThread(); + static void WaitForFlushThread(); + + static DrFileWriterRef s_logFile; + static volatile int s_flag; + static const DWORD s_maxLogFileBytes = 10 * 1024 
* 1024; // 10 MBytes + static volatile DWORD s_currentLogFileBytes; + static int s_currentLogFileCounter; + static DrString s_logFileName; + static DrString s_archiveLogFileFormat; + + static DrCritSecRef m_rolloverLock; +#ifdef _MANAGED + static System::Threading::Thread^ s_flushThread; +#else + static unsigned __stdcall FlushThreadFunc(void* arg); + static HANDLE s_flushThread; +#endif +}; + +#ifndef _MANAGED +DrFileWriterRef DrLoggingInternal::s_logFile; +volatile int DrLoggingInternal::s_flag; +volatile DWORD DrLoggingInternal::s_currentLogFileBytes = 0; +int DrLoggingInternal::s_currentLogFileCounter = 0; +DrString DrLoggingInternal::s_logFileName = "default.log"; +DrString DrLoggingInternal::s_archiveLogFileFormat = "default%03d.log"; +DrCritSecRef DrLoggingInternal::m_rolloverLock = DrNull; +HANDLE DrLoggingInternal::s_flushThread; +#endif + +DrLoggingInternal::DrLoggingInternal() +{ + // Nothing needed +} + +#pragma warning (push) +// _wfreopen_s tries to open the file with exclusive access, which fails since HpcQueryGraphManager.exe already has it open +#pragma warning (disable: 4996 ) // _wfreopen : This function or variable may be unsafe. Consider using _wfreopen_s instead. 
+void DrLoggingInternal::Initialize() +{ + WCHAR szOut[MAX_PATH + 1] = {0}; + WCHAR szDir[MAX_PATH + 1] = {0}; + + s_currentLogFileBytes = 0; + s_currentLogFileCounter = 0; + s_logFileName = "default.log"; + s_archiveLogFileFormat = "default%03d.log"; + m_rolloverLock = DrNew DrCritSec(); + + if (GetCurrentDirectory(MAX_PATH, szDir) > 0) + { + FILE *f = 0; + if (_snwprintf_s(szOut, MAX_PATH, MAX_PATH, L"%s\\stdout.txt", szDir) != -1) + { + f = _wfreopen(szOut, L"a", stdout); + } + } + + s_flag = 0; + s_logFile = DrNew DrFileWriter(); + if (s_logFile->Open(s_logFileName)) + { + DrStaticFileWriters::AddWriter(s_logFile); + } + else + { + fprintf(stderr, "Failed to open log file %s: no logging!\n", s_logFileName.GetChars()); + } + + ForkFlushThread(); + + DrLogWithType(DrLog_Info)("Logging started"); +} +#pragma warning(pop) + +void DrLoggingInternal::FlushThread() +{ + do + { + Sleep(1000); + FlushLogs(); + } while (s_flag == 0); +} + +void DrLoggingInternal::FlushLogs() +{ + // We don't want to flush if we're rolling over the logs + DrAutoCriticalSection acs(m_rolloverLock); + + fflush(stdout); + fflush(stderr); + DrStaticFileWriters::FlushWriters(); +} + +#ifdef _MANAGED +__declspec(noreturn) void DrLoggingInternal::Terminate(UINT exitCode, QueryAssertException ^e) +{ + DrLogWithType(DrLog_Info)("------------- Terminating the process ------------ ExitCode=%u", exitCode); + + Stop(); + + if (exitCode == 0) + { + // + // application says that everything completed normally; do clean shutdown + // + exit(0); + } + else if (e != DrNull) + { + // + // something went wrong; we flushed the logs and must commit a suicide now + // as we cannot do clean shutdown as it may hang (e.g. because data structures were corrupted, + // a thread failed while holding a lock, etc.) 
+ // + + throw e; + } + else + { + // + // We weren't passed an exception, so just terminate the process + // + + TerminateProcess(GetCurrentProcess(), exitCode); + } +} +#else +__declspec(noreturn) void DrLoggingInternal::Terminate(UINT exitCode) +{ + DrLogI("------------- Terminating the process ------------ ExitCode=%u", exitCode); + + Stop(); + + if (exitCode == 0) + { + // + // application says that everything completed normally; do clean shutdown + // + exit(0); + } + else + { + // + // something went wrong; we flushed the logs and must commit a suicide now + // as we cannot do clean shutdown as it may hung (e.g. because data structures were corrupted, + // a thread failed while holding a lock, etc.) + // + TerminateProcess(GetCurrentProcess(), exitCode); + } + +} +#endif + +void DrLoggingInternal::Stop() +{ + FlushLogs(); + + { + DrAutoCriticalSection acs(m_rolloverLock); + + s_flag = 1; + } + WaitForFlushThread(); + s_logFile = DrNull; +} + +void DrLoggingInternal::RolloverLog() +{ + DrAutoCriticalSection acs(m_rolloverLock); + + ++s_currentLogFileCounter; + s_logFile->Close(); + + s_currentLogFileBytes = 0; + DrString archiveFileName; + archiveFileName.SetF(s_archiveLogFileFormat.GetChars(), s_currentLogFileCounter); + if (MoveFileA(s_logFileName.GetChars(), archiveFileName.GetChars())) + { + if (!s_logFile->Open(s_logFileName)) + { + fprintf(stderr, "Failed to open log file %s: no logging!\n", s_logFileName.GetChars()); + } + } + else + { + fprintf(stderr, "Failed to archive log file %s to %s with error %s\n", s_logFileName.GetChars(), archiveFileName.GetChars(), + DRERRORSTRING(HRESULT_FROM_WIN32(GetLastError()))); + if (!s_logFile->ReOpen(s_logFileName)) + { + fprintf(stderr, "Failed to reopen log file %s: no logging!\n", s_logFileName.GetChars()); + } + } +} + +void DrLoggingInternal::Append(const char *data, int dataLength) +{ + // TODO: this might slow logging down, but we don't want to miss any log stms + DrAutoCriticalSection acs(m_rolloverLock); + 
+ if (s_flag == 1) return; + + if (s_currentLogFileBytes + dataLength > s_maxLogFileBytes) + { + RolloverLog(); + } + s_logFile->Append(data, dataLength); + s_currentLogFileBytes += dataLength; +} + +#include +static int s_miniDumpTimeoutMilliseconds = 15 * 60 * 1000; +static LONG volatile s_startedMiniDump = 0; + +// +// Only include data sections for our assemblies and ntdll +// + +const WCHAR* c_szIncludeModules[] = +{ + L"HpcQueryGraphManager", + L"Microsoft.Hpc.Query.GraphManager", + L"Microsoft.Hpc.Query.ClusterAdapter", + L"HpcQueryNativeClusterAdapter", + L"Microsoft.Hpc.Dsc", + L"HpcDscNativeClient" +}; + +static BOOL IncludeDataSection(const WCHAR* pModule) +{ + if (pModule == NULL) + { + return FALSE; + } + + WCHAR szFileName[_MAX_FNAME] = L""; + _wsplitpath_s(pModule, NULL, 0, NULL, 0, szFileName, _MAX_FNAME, NULL, 0); + + DWORD numMods = sizeof(c_szIncludeModules) / sizeof(c_szIncludeModules[0]); + for (DWORD i = 0; i < numMods; i++) + { + if (_wcsicmp(c_szIncludeModules[i], szFileName) == 0) + { + return TRUE; + } + } + + return FALSE; +} + +// +// Callback for MiniDumpWriteDump +// +static BOOL MiniDumpCallback( + PVOID, + const PMINIDUMP_CALLBACK_INPUT pInput, + PMINIDUMP_CALLBACK_OUTPUT pOutput +) +{ + BOOL bRet = FALSE; + + + // Check parameters + if( pInput == 0 ) + return FALSE; + + if( pOutput == 0 ) + return FALSE; + + + // Process the callbacks + switch( pInput->CallbackType ) + { + case IncludeModuleCallback: + { + // Include the module into the dump + bRet = TRUE; + } + break; + + case IncludeThreadCallback: + { + // Skip the minidump thread + if (pInput->Thread.ThreadId == ::GetCurrentThreadId()) + { + bRet = FALSE; + } + else + { + bRet = TRUE; + } + } + break; + + case ModuleCallback: + { + if( pOutput->ModuleWriteFlags & ModuleWriteDataSeg ) + { + if( !IncludeDataSection( pInput->Module.FullPath ) ) + { + pOutput->ModuleWriteFlags &= (~ModuleWriteDataSeg); + } + } + bRet = TRUE; + } + break; + + case ThreadCallback: + { + // Include all 
thread information into the minidump + bRet = TRUE; + } + break; + + case ThreadExCallback: + { + // Include this information + bRet = TRUE; + } + break; + + case MemoryCallback: + { + // We do not include any information here -> return FALSE + bRet = FALSE; + } + break; + + case CancelCallback: + break; + } + + return bRet; + +} + + +// +// Write out a mini dump +// +void DrLogging::WriteMiniDumpImpl() +{ + // If this function actually causes exceptions itself, we will be + // doomed with deadlock. So, we catch all the possible exceptions + // generated by this function and do nothing. + try + { + CHAR szDumpFile[MAX_PATH + 1] = {0}; + if (!GetCurrentDirectoryA(MAX_PATH, szDumpFile)) + { + fprintf(stderr, "Failed to get current directory: %s\n", DRERRORSTRING(GetLastError())); + return; + } + + strcat_s(szDumpFile, MAX_PATH, "\\minidump.dmp"); + + // write the actual dump + HANDLE hMiniDumpFile = CreateFileA ( + szDumpFile, + GENERIC_READ | GENERIC_WRITE, + FILE_SHARE_WRITE | FILE_SHARE_READ, + 0, + CREATE_ALWAYS, + 0, + 0 + ); + + if (hMiniDumpFile != INVALID_HANDLE_VALUE) + { + /* don't log here in case it's logging that's causing the problem */ + + MINIDUMP_CALLBACK_INFORMATION mci = {0}; + mci.CallbackRoutine = (MINIDUMP_CALLBACK_ROUTINE)MiniDumpCallback; + mci.CallbackParam = NULL; + + MINIDUMP_TYPE mdt = (MINIDUMP_TYPE) ( + MiniDumpWithFullMemory | + MiniDumpWithHandleData | + MiniDumpWithUnloadedModules | + MiniDumpWithThreadInfo | + MiniDumpWithDataSegs ); + + BOOL err = MiniDumpWriteDump( + GetCurrentProcess(), + GetCurrentProcessId(), + hMiniDumpFile, + mdt, + NULL, + NULL, + &mci + ); + if (!err) + { + fprintf(stderr, "Failed to write minidump last Error %s\n", DRERRORSTRING(GetLastError())); + } + else + { + CHAR szComputer[MAX_COMPUTERNAME_LENGTH + 1] = {0}; + DWORD cchSize = MAX_COMPUTERNAME_LENGTH + 1; + + // enable for DNS hostnames + //CHAR szComputer[DNS_MAX_LABEL_BUFFER_LENGTH] = {0}; + //DWORD cchSize = DNS_MAX_LABEL_BUFFER_LENGTH; + + 
fprintf(stderr, "Wrote minidump to %s", szDumpFile); + if (GetComputerNameA(szComputer, &cchSize)) + //if (GetComputerNameExA(ComputerNameDnsHostname, szComputer, &cchSize)) + { + fprintf(stderr, " on node %s", szComputer); + } + fprintf(stderr, "\n"); + } + + CloseHandle(hMiniDumpFile); + } + else + { + fprintf(stderr, "Failed to open dump file %s error %s\n", szDumpFile, DRERRORSTRING(GetLastError())); + } + } + catch (...) + { + // do nothing + } + + fflush(stderr); + + return; +} + +#ifdef _MANAGED + +void DrLoggingInternal::ForkFlushThread() +{ + s_flushThread = DrNew System::Threading::Thread( + DrNew System::Threading::ThreadStart(&FlushThread)); + s_flushThread->Start(); +} + +void DrLoggingInternal::WaitForFlushThread() +{ + s_flushThread->Join(); +} + +void DrLogging::MiniDumpThread() +{ + DrLogging::WriteMiniDumpImpl(); +} + +#undef GetEnvironmentVariable + +bool DrLogging::WriteMiniDump() +{ + //Since writing a dump is a one time occurrence dont write it if its already + //started. 
+ if (InterlockedExchange(&s_startedMiniDump, 1) != 0) + { + // The dump is already started + return false; + } + + // Check to see whether we are even supposed to write a minidump + System::String ^dumpEnvVal = System::Environment::GetEnvironmentVariable(L"HPCQUERY_GM_CREATEDUMP"); + if (System::String::IsNullOrEmpty(dumpEnvVal) || dumpEnvVal->Equals(L"0")) + return false; + + // Write the minidump in a new thread so that we capture the state of the + // current thread correctly + System::Threading::Thread ^dumpThread = + DrNew System::Threading::Thread( + DrNew System::Threading::ThreadStart(&MiniDumpThread)); + dumpThread->Start(); + if (!dumpThread->Join(s_miniDumpTimeoutMilliseconds)) + { + // Timed out waiting for minidump thread + return false; + } + + return true; +} + + +#else + +#include + +static unsigned __stdcall MiniDumpThreadFunc(void* /* unused arg */) +{ +// WriteMiniDumpImpl(NULL); + return 0; +} + +unsigned __stdcall DrLoggingInternal::FlushThreadFunc(void* /* unused arg */) +{ + DrLoggingInternal::FlushThread(); + return 0; +} + +#pragma warning (push) +#pragma warning (disable: 4505) // Unreferenced local function parameter +static void ForkMiniDumpThread() +{ + fprintf(stderr, "About to create dump thread\n"); + unsigned threadAddr; + HANDLE handle = + (HANDLE) ::_beginthreadex(NULL, + 0, + MiniDumpThreadFunc, + NULL, + 0, + &threadAddr); + assert(handle != 0); + fprintf(stderr, "Waiting for dump thread\n"); + ::WaitForSingleObject(handle, INFINITE); + + ::CloseHandle(handle); + fprintf(stderr, "Finished waiting for dump thread\n"); +} +#pragma warning (pop) + +void DrLoggingInternal::ForkFlushThread() +{ + unsigned threadAddr; + s_flushThread = + (HANDLE) ::_beginthreadex(NULL, + 0, + FlushThreadFunc, + NULL, + 0, + &threadAddr); + assert(s_flushThread != 0); +} + +void DrLoggingInternal::WaitForFlushThread() +{ + ::WaitForSingleObject(s_flushThread, INFINITE); + ::CloseHandle(s_flushThread); +} + +#endif + +#ifndef _MANAGED +#pragma warning 
(push) +#pragma warning (disable: 4715 ) // 'Logger::LogAndExitProcess' : not all control paths return a value +#pragma warning (disable: 4702 ) // unreachable code +#pragma warning (disable: 4100 ) // unreferenced formal parameter +static LONG WINAPI LogAndExitProcess(EXCEPTION_POINTERS *exceptionPointers) +{ + fprintf(stderr, "An unhandled exception was thrown -- exiting process\n"); + DrLoggingInternal::FlushLogs(); + + //if (WriteMiniDumpImpl(exceptionPointers)) + { + DrLoggingInternal::Terminate(1); + } + + return (EXCEPTION_CONTINUE_SEARCH); +} +#pragma warning (pop) +#endif + +void DrLogHelper::operator()(const char* format, ...) +{ + va_list args; + va_start(args, format); + + DrString s; + s.VSetF(format, args); + + SYSTEMTIME utc, local; + FILETIME ft; + GetSystemTimeAsFileTime(&ft); + FileTimeToSystemTime(&ft, &utc); + SystemTimeToTzSpecificLocalTime(NULL, &utc, &local); + + // For DrLog_Assert and DrLog_Error, write the message to stderr + // so that it is displayed in task's output + bool logToConsole = false; + + char initial = 0; + switch (m_type) + { + case DrLog_Assert: + initial = 'a'; + logToConsole = true; + break; + case DrLog_Error: + initial = 'e'; + logToConsole = true; + break; + case DrLog_Warning: + initial = 'w'; + break; + case DrLog_Info: + initial = 'i'; + break; + case DrLog_Debug: + initial = 'd'; + break; + } + + DrString logEntry; + logEntry.SetF( + "%c," + "%02d/%02d/%04d %02d:%02d:%02d.%03u," + "TID=%d,%s,%s:%d,%s\r\n", + initial, + local.wMonth, + local.wDay, + local.wYear, + local.wHour, + local.wMinute, + local.wSecond, + local.wMilliseconds, + GetCurrentThreadId(), + m_function, m_file, m_line, + s.GetChars() + ); + + + if (logToConsole) + { + fprintf(stderr, s.GetChars()); + fflush(stderr); + } + + if (m_type != DrLog_Assert) + { + DrLoggingInternal::Append(logEntry.GetChars(), logEntry.GetCharsLength()); + } + else + { +#ifdef _MANAGED + // Get a stack trace, excluding the current frame + System::Diagnostics::StackTrace 
^st = gcnew System::Diagnostics::StackTrace(1, true); + char *pszStackTrace = (char*)(void*)System::Runtime::InteropServices::Marshal::StringToHGlobalAnsi(st->ToString()); + + // Log it + logEntry = logEntry.AppendF("%s", pszStackTrace); + DrLoggingInternal::Append(logEntry.GetChars(), logEntry.GetCharsLength()); + + // Terminate the graph manager + DrLoggingInternal::Terminate((UINT)DrError_AssertFailure, gcnew QueryAssertException(s.GetString(), st)); +#else + +#endif + } +} + +void DrLogging::Initialize() +{ +#ifdef _DEBUG_DRREF + InitializeCriticalSection(&DrRefCounter::s_debugCS); +#endif + + DrStaticFileWriters::Initialize(); + DrLoggingInternal::Initialize(); + DrErrorText::Initialize(); + +#ifdef _MANAGED +#else + ::SetUnhandledExceptionFilter(LogAndExitProcess); +#endif +} + +void DrLogging::ShutDown(UINT code) +{ + DrLogWithType(DrLog_Info)("------------- Shutting down logging ------------ ExitCode=%u", code); + DrLoggingInternal::Stop(); + DrStaticFileWriters::Discard(); + DrErrorText::Discard(); +} + +void DrLogging::ShutDown(int code) +{ + DrLogWithType(DrLog_Info)("------------- Shutting down logging ------------ ExitCode=%d", code); + DrLoggingInternal::Stop(); + DrStaticFileWriters::Discard(); + DrErrorText::Discard(); +} + +bool DrLogging::DebuggerIsPresent() +{ +#ifdef _MANAGED + return System::Diagnostics::Debugger::IsAttached; +#else + return (::IsDebuggerPresent()) ? true : false; +#endif +} \ No newline at end of file diff --git a/GraphManager/shared/DrLogging.h b/GraphManager/shared/DrLogging.h new file mode 100644 index 0000000..3a58037 --- /dev/null +++ b/GraphManager/shared/DrLogging.h @@ -0,0 +1,99 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#define DrLogWithType(_x) DrLogHelper(_x,__FILE__,__FUNCTION__,__LINE__) + +#define DRMAKELOGTYPE(_type,_initial) \ + + +#define DrLogD if (DrLogging::Enabled(DrLog_Debug)) DrLogWithType(DrLog_Debug) +#define DrLogI if (DrLogging::Enabled(DrLog_Info)) DrLogWithType(DrLog_Info) +#define DrLogW if (DrLogging::Enabled(DrLog_Warning)) DrLogWithType(DrLog_Warning) +#define DrLogE if (DrLogging::Enabled(DrLog_Error)) DrLogWithType(DrLog_Error) +#define DrLogA if (DrLogging::Enabled(DrLog_Assert)) DrLogWithType(DrLog_Assert) + +DRPUBLICENUM(DrLogType) +{ + DrLog_Off = 0, + DrLog_Assert = 1, + DrLog_Error = 3, + DrLog_Warning = 7, + DrLog_Info = 15, + DrLog_Debug = 31 +}; + +DRCLASS(DrLogging) +{ +public: + static void Initialize(); + static void ShutDown(UINT code); + static void ShutDown(int code); + static void SetLoggingLevel(DrLogType type); + static bool Enabled(DrLogType type); + + static bool DebuggerIsPresent(); + + static bool WriteMiniDump(); + +private: + static void MiniDumpThread(); + static void WriteMiniDumpImpl(); + +}; + +DRCLASS(DrLogHelper) +{ +public: + DrLogHelper(DrLogType type, const char* file, const char* function, int line) + { + m_type = type; + m_file = file; + m_function = function; + m_line = line; + } + + void operator()(const char* format, ...); + +private: + DrLogType m_type; + const char* m_file; + const char* m_function; + int m_line; +}; + + +#ifdef _MANAGED + +DRCLASS(QueryAssertException) : System::Exception +{ +public: + 
QueryAssertException(System::String ^message, System::Diagnostics::StackTrace ^st); + + virtual property System::String^ StackTrace + { + System::String^ get() override {return m_stackTrace->ToString(); } + } + +private: + System::Diagnostics::StackTrace ^m_stackTrace; +}; +#endif \ No newline at end of file diff --git a/GraphManager/shared/DrMultiMap.h b/GraphManager/shared/DrMultiMap.h new file mode 100644 index 0000000..5f134c2 --- /dev/null +++ b/GraphManager/shared/DrMultiMap.h @@ -0,0 +1,88 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#ifdef _MANAGED + +#include + +template DRBASECLASS(DrMultiMap) +{ + typedef cliext::multimap Map; +#else + +#include + +template DRBASECLASS(DrMultiMap) +{ + typedef std::multimap Map; +#endif + +public: + typedef typename Map::iterator Iter; + + Iter Insert(K key, V value) + { +#ifdef _MANAGED + return m_map.insert(Map::make_value(key, value)); +#else + return m_map.insert(std::make_pair(key, value)); +#endif + } + + Iter Find(K key) + { + return m_map.find(key); + } + + Iter Erase(Iter i) + { + return m_map.erase(i); + } + + int Erase(K key) + { + return (int) m_map.erase(key); + } + + Iter Begin() + { + return m_map.begin(); + } + + Iter End() + { + return m_map.end(); + } + + int GetSize() + { + return (int) m_map.size(); + } + + void Clear() + { + m_map.clear(); + } + +protected: + Map m_map; +}; diff --git a/GraphManager/shared/DrRef.cpp b/GraphManager/shared/DrRef.cpp new file mode 100644 index 0000000..4ceb3c3 --- /dev/null +++ b/GraphManager/shared/DrRef.cpp @@ -0,0 +1,126 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include "DrShared.h" + +#ifdef _MANAGED + +#else + +#define DRREF_MAGIC_CONSTRUCTOR_VALUE (-666) + +#ifdef _DEBUG_DRREF +std::set DrRefCounter::s_refsAllocated; +std::map DrRefCounter::s_arrayStorage; +CRITICAL_SECTION DrRefCounter::s_debugCS; +#endif + +DrRefCounter::DrRefCounter() : m_iRefCount( DRREF_MAGIC_CONSTRUCTOR_VALUE ) +{ +#ifdef _DEBUG_DRREF + EnterCriticalSection(&s_debugCS); + bool inserted = s_refsAllocated.insert(this).second; + DrAssert(inserted); + LeaveCriticalSection(&s_debugCS); +#endif +} + +DrRefCounter::~DrRefCounter() +{ + DrAssert(m_iRefCount == 0); +#ifdef _DEBUG_DRREF + DrAssert(m_holders.empty()); + EnterCriticalSection(&s_debugCS); + size_t nRemoved = s_refsAllocated.erase(this); + DrAssert(nRemoved == 1); + LeaveCriticalSection(&s_debugCS); +#endif +} + +void DrRefCounter::FreeMemory() // Called when the refcount becomes zero. +{ + delete this; +} + + +#ifdef _DEBUG_DRREF +void DrRefCounter::IncRef(void* h) +{ + EnterCriticalSection(&s_debugCS); + ++m_iRefCount; + if (m_iRefCount <= 1) + { + // this is the first assignment of a newly constructed object */ + DrAssert(m_iRefCount == (DRREF_MAGIC_CONSTRUCTOR_VALUE + 1)); + m_iRefCount = 1; + } + bool inserted = m_holders.insert(h).second; + DrAssert(inserted); + DrAssert(m_holders.size() == (size_t) m_iRefCount); + LeaveCriticalSection(&s_debugCS); +} +#else +void DrRefCounter::IncRef() +{ + LONG i; + i = InterlockedIncrement(&m_iRefCount); + if (i <= 1) + { + // this is the first assignment of a newly constructed object */ + DrAssert(i == (DRREF_MAGIC_CONSTRUCTOR_VALUE + 1)); + m_iRefCount = 1; + } +} +#endif + +#ifdef _DEBUG_DRREF +void DrRefCounter::DecRef(void* h) +{ + EnterCriticalSection(&s_debugCS); + DrAssert(m_holders.size() == (size_t) m_iRefCount); + size_t nRemoved = m_holders.erase(h); + DrAssert(nRemoved == 1); + --m_iRefCount; + if (m_iRefCount <= 0) + { + DrAssert(m_iRefCount == 0); + FreeMemory(); + } + LeaveCriticalSection(&s_debugCS); +} +#else +void 
DrRefCounter::DecRef() +{ + LONG i; + i = InterlockedDecrement(&m_iRefCount); + if (i <= 0) + { + DrAssert(i == 0); + FreeMemory(); + } +} +#endif + +void DrInterfaceRefBase::AssertTypeCast() +{ + DrLogA("Type cast failed"); +} + +#endif \ No newline at end of file diff --git a/GraphManager/shared/DrRef.h b/GraphManager/shared/DrRef.h new file mode 100644 index 0000000..9631f95 --- /dev/null +++ b/GraphManager/shared/DrRef.h @@ -0,0 +1,544 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +/* + +DrRef is the class that is the basis for managing object storage in Dryad. +Under both the managed and unmanaged compilation, all object references are +wrapped in a DrRef, and the object is accessed and copied using overloads on +DrRef. In the managed version DrRef simply holds a managed reference to the +object on the garbage-collected heap. In the native version the object +inherits from a basic reference-counter class, and DrRef manages incrementing +and decrementing the reference count as the object is passed around the program. + +This header also includes macros for allocating and freeing objects and arrays +so that they are allocated on the managed heap or the native heap depending on +compilation type. + +A base type that needs to be reference-counted/managed is declared as follows: + +DRBASECLASS(Foo) +{ + Foo() { ... } + ... 
+}; +DRREF(Foo); + +Now type FooRef (which is just DrRef) is used to store the object +throughout the program, and FooPtr (which is a Foo^ or Foo* depending) +can be used for a 'temporary' copy that does not modify the reference count in +the native world. + +If you want to inherit from one of these classes it works like this: + +DRCLASS(Bar) : public Foo +{ + ... +}; +DRREF(Bar); + +FooRef x = DrNew Foo(); +{ + FooRef y; + y = x; // refcount is now 2 +} // refcount is back to 1 when y goes out of scope + +FooPtr z = x; // refcount is still 1: be careful x doesn't go out of scope! + +x->Wombat(); // can access members directly through the overload of -> +z->Wombat(); // or of course on the temporary copy directly + +*/ + +#ifdef _MANAGED + +#define DRABSTRACT abstract +#define DROVERRIDE override +#define DRSEALED sealed + +#define DRENUM(T_) enum T_ +#define DRPUBLICENUM(T_) public enum T_ +#define DRBASECLASS(T_) public ref class T_ : public IDrRefCounter +#define DRINTERNALBASECLASS(T_) ref class T_ : public IDrRefCounter +#define DRINTERFACE(T_) public interface class T_ +#define DRCLASS(T_) public ref class T_ +#define DRINTERNALCLASS(T_) ref class T_ +#define DRDECLARECLASS(T_) ref class T_ +#define DRVALUECLASS(T_) public value class T_ +#define DRDECLAREVALUECLASS(T_) value class T_ +#define DRINTERNALVALUECLASS(T_) value class T_ +#define DRTEMPLATE template ref class +#define DRREF(T_) typedef T_^ T_##Ref; typedef T_^ T_##Ptr; typedef T_% T_##R +#define DRAREF(A_,T_) typedef A_^ A_##Ref; typedef A_^ A_##Ptr; typedef A_% A_##R +#define DRIREF(T_) typedef T_^ T_##IRef; typedef T_^ T_##Ptr; typedef T_% T_##R +#define DRRREF(T_) typedef T_% T_##R +#define DrNew gcnew +#define DrNull nullptr +#define DRPIN(T_) pin_ptr< T_ > +#define DrObjectPtr Object^ + +interface class IDrRefCounter +{ +}; + +template class DrRefHolder +{ +public: + DrRefHolder() + { + /* The CLR may initialize System::IntPtr members to IntPtr::Zero, + but let's err on the safe side */ + 
m_storedValue = System::IntPtr::Zero; + } + + void Store(T^ obj) + { + System::Runtime::InteropServices::GCHandle gch = + System::Runtime::InteropServices::GCHandle::Alloc(obj); + m_storedValue = System::Runtime::InteropServices::GCHandle::ToIntPtr(gch); + } + + T^ Extract() + { + /* In some cases, Extract may be called on a DrRefHolder for which Store was + never called. Specifically, DrXComputeCancelScheduleProcessOverlapped + does not use the m_message member of its base class DrXComputeOverlapped, but + DrXComputeOverlapped::Discard calls m_message.Extract. */ + if (m_storedValue == System::IntPtr::Zero) + { + return DrNull; + } + + System::Runtime::InteropServices::GCHandle gch = + System::Runtime::InteropServices::GCHandle::FromIntPtr(m_storedValue); + m_storedValue = System::IntPtr::Zero; + + T^ obj = (T^) gch.Target; + gch.Free(); + + return obj; + } + +private: + System::IntPtr m_storedValue; +}; + +template public value class DrValueWrapper +{ +public: + T_% T() + { + if (p == DrNull) + { + p = DrNew T_(); + } + + return *p; + } + +private: + T_^ p; +}; + + +#else + +#define DRABSTRACT =0 +#define DROVERRIDE +#define DRSEALED + +#define DRENUM(T_) enum T_ +#define DRPUBLICENUM(T_) enum T_ +#define DRBASECLASS(T_) class T_ : public DrRefCounter +#define DRINTERNALBASECLASS(T_) class T_ : public DrRefCounter +#define DRINTERFACE(T_) class T_ +#define DRCLASS(T_) class T_ +#define DRINTERNALCLASS(T_) class T_ +#define DRDECLARECLASS(T_) class T_ +#define DRVALUECLASS(T_) class T_ +#define DRDECLAREVALUECLASS(T_) class T_ +#define DRINTERNALVALUECLASS(T_) class T_ +#define DRTEMPLATE template class +#define DRREF(T_) typedef DrRef T_##Ref; typedef T_* T_##Ptr; typedef T_& T_##R +#define DRAREF(A_,T_) typedef DrArrayRef A_##Ref; typedef A_* A_##Ptr; typedef A_& A_##R +#define DRIREF(T_) typedef DrInterfaceRef T_##IRef; typedef T_* T_##Ptr; typedef T_& T_##R +#define DRRREF(T_) typedef T_& T_##R +#define DrNew new +#define DrNull NULL +#define DRPIN(T_) T_* 
+#define DrObjectPtr DrRefCounter* + +//#define _DEBUG_DRREF + +#ifdef _DEBUG_DRREF +#include +#include +#endif + +class DrRefCounter +{ +public: +#ifdef _DEBUG_DRREF + void IncRef(void* holder); + void DecRef(void* holder); +#else + void IncRef(); + void DecRef(); +#endif + +protected: + mutable volatile LONG m_iRefCount; + + DrRefCounter(); + virtual ~DrRefCounter(); + void FreeMemory(); // Called when the refcount becomes zero. + +#ifdef _DEBUG_DRREF +public: + static std::set s_refsAllocated; + static std::map s_arrayStorage; + static CRITICAL_SECTION s_debugCS; + std::set m_holders; +#endif +}; + +#ifdef _DEBUG_DRREF +#define DRINCREF(p_) p_->IncRef(this) +#else +#define DRINCREF(p_) p_->IncRef() +#endif + +#ifdef _DEBUG_DRREF +#define DRDECREF(p_) p_->DecRef(this) +#else +#define DRDECREF(p_) p_->DecRef() +#endif + +template class DrRef +{ +public: + DrRef() + { + p = NULL; + } + + explicit DrRef(const DrRef& ref) + { + p = ref.p; + + if (p != NULL) + { + DRINCREF(p); + } + } + + DrRef(T* lp) + { + p = lp; + + if (p != NULL) + { + DRINCREF(p); + } + } + + virtual ~DrRef() + { + if (p != NULL) + { + DRDECREF(p); + p = NULL; // Make sure we AV in case someone is using DrRef after DecRef + } + } + + operator T*() const + { + return p; + } + + T* operator->() const + { + return p; + } + + bool operator==(T* pT) const + { + return (p == pT); + } + + bool operator!=(T* pT) const + { + return (p != pT); + } + + bool operator<(T* pT) const + { + return p < pT; + } + + DrRef& Set(T* lp) + { + if (p != lp) + { + if (lp != NULL) + { + DRINCREF(lp); + } + + if (p != NULL) + { + DRDECREF(p); + } + + p = lp; + } + + return *this; + } + + DrRef& operator=(const DrRef& ref) + { + return Set(ref.p); + } + + DrRef& operator=(T* lp) + { + return Set(lp); + } + +private: + T* p; +}; + +template class DrArrayRef +{ +public: + DrArrayRef() + { + p = NULL; + } + + explicit DrArrayRef(const DrArrayRef& ref) + { + p = ref.p; + + if (p != NULL) + { + DRINCREF(p); + } + } + + 
DrArrayRef(A* lp) + { + p = lp; + + if (p != NULL) + { + DRINCREF(p); + } + } + + virtual ~DrArrayRef() + { + if (p != NULL) + { + DRDECREF(p); + p = NULL; // Make sure we AV in case someone is using DrRef after DecRef + } + } + + operator A*() const + { + return p; + } + + A* operator->() const + { + return p; + } + + T& operator[](int element) + { + return p->operator[](element); + } + + bool operator==(A* pT) const + { + return (p == pT); + } + + bool operator!=(A* pT) const + { + return (p != pT); + } + + bool operator<(A* pT) const + { + return p < pT; + } + + DrArrayRef& Set(A* lp) + { + if (p != lp) + { + if (lp != NULL) + { + DRINCREF(lp); + } + + if (p != NULL) + { + DRDECREF(p); + } + + p = lp; + } + + return *this; + } + + DrArrayRef& operator=(const DrArrayRef& ref) + { + return Set(ref.p); + } + + DrArrayRef& operator=(A* lp) + { + return Set(lp); + } + +private: + A* p; +}; + +/* the only reason for this class is so that we can get DrInterfaceRef to assert using the normal + logging interface if the type cast fails. We can't call DrLogA from this header since it hasn't + been defined yet, but generic classes need to have their implementations in the header files... */ +class DrInterfaceRefBase +{ +protected: + static void AssertTypeCast(); +}; + +/* DrInterfaceRef is a reference counter holder for an object of type I that is an interface. + There is a runtime check that the actual object inherits from DrRefCounter. 
We can't check + this statically due to the limitations on multiple inheritance imposed by allowing cross-compilation + to managed code */ +template class DrInterfaceRef : public DrInterfaceRefBase +{ +public: + DrInterfaceRef() + { + } + + explicit DrInterfaceRef(const DrInterfaceRef& ref) + { + Set((I*) ref); + } + + DrInterfaceRef(I* lp) + { + Set(lp); + } + + operator I*() const + { + return dynamic_cast((DrRefCounter *) p); + } + + I* operator->() const + { + return dynamic_cast((DrRefCounter *) p); + } + + bool operator!() const + { + return (!p); + } + + bool operator==(I* pT) const + { + I* pI = dynamic_cast((DrRefCounter *) p); + return (pI == pT); + } + + bool operator!=(I* pT) const + { + return ((I*) this != pT); + } + + DrInterfaceRef& Set(I* obj) + { + p = dynamic_cast(obj); + if (obj != DrNull && p == DrNull) + { + AssertTypeCast(); + } + return *this; + } + + DrInterfaceRef& operator=(const DrInterfaceRef& ref) + { + return Set((I*) ref); + } + + DrInterfaceRef& operator=(I* lp) + { + return Set(lp); + } + +private: + DrRef p; +}; + +template class DrRefHolder +{ +public: + void Store(T* obj) + { + m_storedValue = obj; + } + + DrRef Extract() + { + DrRef obj = m_storedValue; + m_storedValue = NULL; + return obj; + } + +private: + DrRef m_storedValue; +}; + +template class DrValueWrapper +{ +public: + T_& T() + { + return t; + } + +private: + T_ t; +}; + +#endif diff --git a/GraphManager/shared/DrSet.h b/GraphManager/shared/DrSet.h new file mode 100644 index 0000000..aa1b534 --- /dev/null +++ b/GraphManager/shared/DrSet.h @@ -0,0 +1,220 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#ifdef _MANAGED + +using namespace System::Threading; + +public ref class SimpleLock +{ +public: + SimpleLock(System::Object ^object) + { + m_object = object; + if (Monitor::TryEnter(m_object) == false) + { + System::Diagnostics::Debugger::Break(); + } + } + + ~SimpleLock() + { + Monitor::Exit(m_object); + } + +private: + System::Object ^m_object; +}; + + +// this inherits from Dictionary instead of HashSet because my (and Ulfar's) install of VS08 +// does not permit HashSet in C++/CLI. I don't know why. If I figure it out, this can easily +// be fixed +template DRBASECLASS(DrSet), public System::Collections::Generic::Dictionary +{ +public: + value class DrEnumerator + { + public: + DrEnumerator(Dictionary^ hset) + { + m_iterator = hset->GetEnumerator(); + } + + T GetElement() + { + return m_iterator.Current.Key; + } + + bool MoveNext() + { + return m_iterator.MoveNext(); + } + + private: + Dictionary::Enumerator m_iterator; + }; + + bool Add(T key) + { + if (Contains(key)) + { + return false; + } + else + { + Add(key, false); + return true; + } + } + + virtual bool Remove(T element) new + { + bool bRet = System::Collections::Generic::Dictionary::Remove(element); + return bRet; + } + + bool Contains(T key) + { + bool dummy; + try + { + dummy = TryGetValue(key, dummy); + } + catch (...) 
+ { + return true; + } + return dummy; + } + + DrEnumerator GetDrEnumerator() + { + return DrEnumerator(this); + } + + int GetSize() + { + return Count; + } +}; + +#else + +#include + +template DRBASECLASS(DrSet) +{ + typedef std::set Set; + +public: + class DrEnumerator + { + public: + DrEnumerator(Set* s) + { + m_iterator = s->begin(); + m_end = s->end(); + m_moved = false; + } + + const T& GetElement() + { + DrAssert(m_moved); + return *m_iterator; + } + + bool MoveNext() + { + if (m_moved == false) + { + m_moved = true; + } + else if (m_iterator != m_end) + { + ++m_iterator; + } + + if (m_iterator == m_end) + { + return false; + } + else + { + return true; + } + } + + private: + typename Set::iterator m_iterator; + typename Set::iterator m_end; + bool m_moved; + }; + + void Add(T element) + { + bool inserted = + m_set.insert(element).second; + DrAssert(inserted); + } + + bool Contains(T element) + { + Set::const_iterator i = m_set.find(element); + return (i != m_set.end()); + } + + bool Remove(T element) + { + size_t nRemoved = m_set.erase(element); + if (nRemoved == 1) + { + return true; + } + else + { + DrAssert(nRemoved == 0); + return false; + } + } + + DrEnumerator GetDrEnumerator() + { + return DrEnumerator(&m_set); + } + + int GetSize() + { + return (int) m_set.size(); + } + +protected: + Set m_set; +}; + +#endif + +typedef DrSet DrIntSet; +DRREF(DrIntSet); + +typedef DrSet DrUInt64Set; +DRREF(DrUInt64Set); \ No newline at end of file diff --git a/GraphManager/shared/DrShared.h b/GraphManager/shared/DrShared.h new file mode 100644 index 0000000..e48c7d9 --- /dev/null +++ b/GraphManager/shared/DrShared.h @@ -0,0 +1,56 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +/* TODO understand public classes */ +#pragma warning( disable: 4677 ) + +#define _HAS_ITERATOR_DEBUGGING 0 + +//#define _CRTDBG_MAP_ALLOC +//#include +//#include + +#define WIN32_LEAN_AND_MEAN // Exclude rarely-used stuff from Windows headers +#include + +#include "DrTypes.h" + +#include "DrAssert.h" + +#include "DrRef.h" + +#include "DrCritSec.h" + +#include "DrString.h" +#include "DrFileWriter.h" +#include "DrLogging.h" + +#include "DrSort.h" +#include "DrArray.h" +#include "DrArrayList.h" +#include "DrDictionary.h" +#include "DrSet.h" +#include "DrMultiMap.h" + +#include "DrError.h" + +#include "DrStringUtil.h" \ No newline at end of file diff --git a/GraphManager/shared/DrSort.h b/GraphManager/shared/DrSort.h new file mode 100644 index 0000000..019ef8e --- /dev/null +++ b/GraphManager/shared/DrSort.h @@ -0,0 +1,46 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#ifdef _MANAGED + +template public ref class DrComparer abstract : public IDrRefCounter, public System::Collections::Generic::IComparer +{ +public: + virtual int Compare(T x, T y) = 0; +}; + +#else + +template DRBASECLASS(DrComparer) +{ +public: + static int __cdecl CompareUntyped(void* context, const void* x, const void* y) + { + DrComparer* self = (DrComparer *) context; + return self->Compare(*((T *) x), *((T *) y)); + } + + virtual int Compare(T x, T y) = 0; +}; + +#endif + diff --git a/GraphManager/shared/DrString.cpp b/GraphManager/shared/DrString.cpp new file mode 100644 index 0000000..2a167ce --- /dev/null +++ b/GraphManager/shared/DrString.cpp @@ -0,0 +1,539 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include + +static int ReverseIndexOfCharInBuf(char c, const char* buf, int length) +{ + while (length > 0) + { + --length; + if (buf[length] == c) + { + return length; + } + } + return DrStr_InvalidIndex; +}; + +static int IndexOfCharInBuf(char c, const char* buf, int length, int startPos) +{ + int i; + for (i=startPos; i!DrStringBuffer(); +} + +DrStringBuffer::!DrStringBuffer() +{ + delete [] m_buffer; +} + +void DrStringBuffer::SetString(System::String^ string) +{ + DrAssert(m_string == nullptr); + DrAssert(m_buffer == nullptr); + + m_string = string; +} + +System::String^ DrStringBuffer::GetString() +{ + return m_string; +} + +const char* DrStringBuffer::GetChars() +{ + System::Threading::Monitor::Enter(this); + + if (m_string == nullptr) + { + System::Threading::Monitor::Exit(this); + return NULL; + } + + if (m_buffer == NULL) + { + /* cache a copy of the UTF8 char* representation */ + pin_ptr c = PtrToStringChars(m_string); + + int newSize = ::WideCharToMultiByte(CP_UTF8, 0, c, -1, NULL, 0, NULL, NULL); + m_buffer = new char[newSize]; + + int convertedChars = ::WideCharToMultiByte(CP_UTF8, 0, c, -1, m_buffer, newSize, NULL, NULL); + if (convertedChars == 0) + { + delete[] m_buffer; + m_buffer = NULL; + DrLogA("Failed to convert %S to UTF8: %d", c, GetLastError()); + } + } + + const char* buffer = m_buffer; + System::Threading::Monitor::Exit(this); + return buffer; +} + +int DrStringBuffer::GetCharsLength() +{ + const char* chars = GetChars(); + if (chars == nullptr) + { + return 0; + } + else + { + return (int) strlen(chars); + } +} + +DrString::DrString() +{ + m_buffer = DrNew DrStringBuffer(); +} + +DrString::DrString(System::String^ s) +{ + Set(s); +} + +DrString::DrString(DrString% s) +{ + m_buffer = s.m_buffer; +} + +int DrString::GetCharsLength() +{ + return m_buffer->GetCharsLength(); +} + +System::String^ DrString::GetString() +{ + return m_buffer->GetString(); +} + +DrString% DrString::operator=(System::String^ s) +{ + Set(s); + return 
*this; +} + +DrString% DrString::operator=(DrString% s) +{ + Set(s); + return *this; +} + +void DrString::Set(System::String^ newString) +{ + m_buffer = DrNew DrStringBuffer(); + m_buffer->SetString(newString); +} + +void DrString::Set(DrString otherString) +{ + m_buffer = otherString.m_buffer; +} + +int DrString::Compare(DrString otherString) +{ + if (m_buffer == otherString.m_buffer) + { + return 0; + } + else + { + return Compare(otherString.GetString()); + } +} + +int DrString::Compare(System::String^ otherString) +{ + System::String^ thisString = m_buffer->GetString(); + + if (otherString == nullptr) + { + if (thisString == nullptr) + { + return 0; + } + else + { + return 1; + } + } + else if (thisString == nullptr) + { + return -1; + } + + return System::String::Compare(thisString, otherString); +} + +int DrString::Compare(System::String^ otherString, int charsToCompare) +{ + System::String^ thisString = m_buffer->GetString(); + + if (otherString == nullptr) + { + if (thisString == nullptr) + { + return 0; + } + else + { + return 1; + } + } + else if (thisString == nullptr) + { + return -1; + } + + return System::String::Compare(thisString, 0, otherString, 0, charsToCompare); +} + + +void DrString::SetF(const char *format, ...) 
+{ + va_list args; + va_start(args, format); + + VSetF(format, args); +} + +const char* DrString::GetChars() +{ + return m_buffer->GetChars(); +} + +void DrString::VSetF(const char* format, va_list args) +{ + size_t bufferLength = 2 * strlen(format) + 1; + + for ( ; ; ) + { + char* buffer = new char[bufferLength]; + int ret = _vsnprintf_s(buffer, bufferLength, bufferLength-1, format, args); + if (ret >= 0) + { + System::String^ s = gcnew System::String(buffer); + delete [] buffer; + m_buffer = DrNew DrStringBuffer(); + m_buffer->SetString(s); + break; + } + + delete [] buffer; + bufferLength *= 2; + } +} + +int DrString::IndexOfChar(char c) +{ + return IndexOfChar(c, 0); +} + +int DrString::IndexOfChar(char c, int startPos) +{ + const char* buf = GetChars(); + if (buf == nullptr) + { + return DrStr_InvalidIndex; + } + + int length = (int) strlen(buf); + return IndexOfCharInBuf(c, buf, length, startPos); +} + + +int DrString::ReverseIndexOfChar(char c) +{ + const char* buf = GetChars(); + if (buf == nullptr) + { + return DrStr_InvalidIndex; + } + + int length = (int) strlen(buf); + return ReverseIndexOfCharInBuf(c, buf, length); +} + +#else + +DrStringBuffer::DrStringBuffer(int length) +{ + m_buffer = new char[length+1]; + m_length = length; +} + +DrStringBuffer::~DrStringBuffer() +{ + delete [] m_buffer; +} + +char* DrStringBuffer::GetString() +{ + return m_buffer; +} + +void DrStringBuffer::SetLength(int length) +{ + m_length = length; +} + +int DrStringBuffer::GetLength() +{ + return m_length; +} + + +DrString::DrString() +{ +} + +DrString::DrString(const char* s) +{ + Set(s); +} + +DrString::DrString(const DrString& s) +{ + m_string = s.m_string; +} + +int DrString::GetCharsLength() +{ + if (m_string == NULL) + { + return 0; + } + else + { + return m_string->GetLength(); + } +} + +const char* DrString::GetString() +{ + if (m_string == NULL) + { + return NULL; + } + else + { + return m_string->GetString(); + } +} + +const char* DrString::GetChars() +{ + return 
GetString(); +} + +void DrString::Set(const char* newString) +{ + if (newString != NULL) + { + int newLength = (int) strlen(newString); + DrStringBufferRef newBuffer = DrNew DrStringBuffer(newLength); + errno_t err = strcpy_s(newBuffer->GetString(), newLength+1, newString); + DrAssert(err == 0); + m_string = newBuffer; + } +} + +void DrString::Set(DrString otherString) +{ + m_string = otherString.m_string; +} + +DrString& DrString::operator=(const DrString& other) +{ + m_string = other.m_string; + return *this; +} + +int DrString::Compare(DrString otherString) +{ + if (m_string == otherString.m_string) + { + return 0; + } + else + { + return Compare(otherString.GetString()); + } +} + +int DrString::Compare(const char* otherString) +{ + if (otherString == NULL) + { + if (m_string == NULL) + { + return 0; + } + else + { + return 1; + } + } + else if (m_string == NULL) + { + return -1; + } + + return strcmp(m_string->GetString(), otherString); +} + +int DrString::Compare(const char* otherString, int charsToCompare, bool caseSensitive) +{ + if (otherString == NULL) + { + if (m_string == NULL) + { + return 0; + } + else + { + return 1; + } + } + else if (m_string == NULL) + { + return -1; + } + + if (caseSensitive) + { + return strncmp(m_string->GetString(), otherString, charsToCompare); + } + else + { + return _strnicmp(m_string->GetString(), otherString, charsToCompare); + } +} + +void DrString::SetF(const char *format, ...) 
+{ + va_list args; + va_start(args, format); + + VSetF(format, args); +} + +void DrString::VSetF(const char* format, va_list args) +{ + int bufferLength = 2 * (int) strlen(format); + + for ( ; ; ) + { + m_string = DrNew DrStringBuffer(bufferLength); + int ret = _vsnprintf_s(m_string->GetString(), bufferLength+1, bufferLength, + format, args); + if (ret >= 0) + { + m_string->SetLength(ret); + break; + } + + bufferLength *= 2; + } +} + +int DrString::IndexOfChar(char c, int startPos) +{ + return IndexOfCharInBuf(c, m_string->GetString(), m_string->GetLength(), startPos); +} + +int DrString::ReverseIndexOfChar(char c) +{ + return ReverseIndexOfCharInBuf(c, m_string->GetString(), m_string->GetLength()); +} + +#endif + +bool DrString::operator==(DrStringR other) +{ + return (Compare(other) == 0); +} + +DrString DrString::AppendF(const char* format, ...) +{ + va_list args; + va_start(args, format); + + DrString newSubString; + newSubString.VSetF(format, args); + + DrString newString; + newString.SetF("%s%s", GetChars(), newSubString.GetChars()); + + return newString; +} + +#include + +void DrString::SetSubString(const char* otherString, int length) +{ + DrAssert((int)strlen(otherString) >= length); + char* copy = _strdup(otherString); + copy[length] = '\0'; + + SetF("%s", copy); + + free(copy); +} + +DrString DrString::ToUpperCase() +{ + char* copy = _strdup(GetChars()); + char* s = copy; + while (*s != '\0') + { + if (__isascii(*s) && islower(*s)) + { + *s = (char)toupper(*s); + } + ++s; + } + + DrString upcased; + upcased.SetF("%s", copy); + return upcased; +} \ No newline at end of file diff --git a/GraphManager/shared/DrString.h b/GraphManager/shared/DrString.h new file mode 100644 index 0000000..4ec54c7 --- /dev/null +++ b/GraphManager/shared/DrString.h @@ -0,0 +1,161 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include + +static const int DrStr_InvalidIndex = -1; + +DRDECLARECLASS(DrString); +DRRREF(DrString); + +#ifdef _MANAGED + +#include + +#define DrNullString ((System::String^) nullptr) +#define DrNativeString System::String^ + +ref class DrStringBuffer +{ +public: + DrStringBuffer(); + ~DrStringBuffer(); + !DrStringBuffer(); + + void SetString(System::String^ string); + System::String^ GetString(); + + const char* GetChars(); + int GetCharsLength(); + +private: + System::String^ m_string; + char* m_buffer; + +}; + +public ref class DrString +{ +public: + DrString(); + DrString(System::String^ newString); + DrString(DrString% otherString); + + DrString% operator=(System::String^ other); + DrString% operator=(DrString% other); + + bool operator==(DrString% other); + + int GetCharsLength(); + + System::String^ GetString(); + + const char* GetChars(); + + void Set(System::String^ newString); + void Set(DrString otherString); + + void SetSubString(const char* otherString, int length); + + int Compare(DrString otherString); + int Compare(System::String^ otherString); + int Compare(System::String^ otherString, int charsToCompare); + + void SetF(const char *format, ...); + void VSetF(const char* format, va_list args); + + DrString AppendF(const char *format, ...); + DrString ToUpperCase(); + + int IndexOfChar(char c); + int IndexOfChar(char c, int startPos); + int ReverseIndexOfChar(char c); + +private: + DrStringBuffer^ m_buffer; +}; +DRRREF(DrString); + +#else + +#define DrNullString 
((const char*) NULL) +#define DrNativeString const char* + +DRBASECLASS(DrStringBuffer) +{ +public: + DrStringBuffer(int length); + ~DrStringBuffer(); + + char* GetString(); + + int GetLength(); + void SetLength(int length); + +private: + char* m_buffer; + int m_length; +}; +DRREF(DrStringBuffer); + +DRVALUECLASS(DrString) +{ +public: + DrString(); + DrString(const char* newString); + DrString(const DrString& otherString); + + DrString& operator=(const DrString& other); + + bool operator==(DrString& other); + + int GetCharsLength(); + + const char* GetString(); + + const char* GetChars(); + + void Set(const char* newString); + void Set(DrString otherString); + + void SetSubString(const char* otherString, int length); + + int Compare(DrString otherString); + int Compare(const char* otherString); + int Compare(const char* otherString, int charsToCompare, bool caseSensitive=true); + + void SetF(const char *format, ...); + void VSetF(const char* format, va_list args); + + DrString AppendF(const char *format, ...); + DrString ToUpperCase(); + + int IndexOfChar(char c, int startPos=0); + int ReverseIndexOfChar(char c); + +private: + DrStringBufferRef m_string; +}; + +#endif + +DRRREF(DrString); \ No newline at end of file diff --git a/GraphManager/shared/DrStringUtil.cpp b/GraphManager/shared/DrStringUtil.cpp new file mode 100644 index 0000000..5c924d2 --- /dev/null +++ b/GraphManager/shared/DrStringUtil.cpp @@ -0,0 +1,278 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "DrShared.h" +#include "DrErrorInternal.h" + +#include +#include + +struct DrErrorDescription +{ + HRESULT m_error; + const char* m_text; +}; + +static const DrErrorDescription s_errorTable[] = +{ + { DrError_BadMetaData, "Bad MetaData XML" }, + { DrError_InvalidCommand, "Invalid Command" }, + { DrError_VertexReceivedTermination, "Vertex Received Termination" }, + { DrError_InvalidChannelURI, "Invalid Channel URI syntax" }, + { DrError_ChannelOpenError, "Channel Open Error" }, + { DrError_ChannelRestartError, "Channel Restart Error" }, + { DrError_ChannelWriteError, "Channel Write Error" }, + { DrError_ChannelReadError, "Channel Read Error" }, + { DrError_ItemParseError, "Item Parse Error" }, + { DrError_ItemMarshalError, "Item Marshal Error" }, + { DrError_BufferHole, "Buffer Hole" }, + { DrError_ItemHole, "Item Hole" }, + { DrError_ChannelRestart, "Channel Sent Restart" }, + { DrError_ChannelAbort, "Channel Sent Abort" }, + { DrError_VertexRunning, "Vertex Is Running" }, + { DrError_VertexCompleted, "Vertex Has Completed" }, + { DrError_VertexError, "Vertex Had Errors" }, + { DrError_ProcessingError, "Error While Processing" }, + { DrError_VertexInitialization, "Vertex Could Not Initialize" }, + { DrError_ProcessingInterrupted, "Processing was interrupted before completion" }, + { DrError_VertexChannelClose, "Errors during channel close" }, + { DrError_AssertFailure, "Assertion Failure" }, + { DrError_ExternalChannel, "External Channel" }, + { DrError_AlreadyInitialized, "Dryad Already Initialized" }, + { DrError_DuplicateVertices, "Duplicate Vertices" }, + { DrError_ComposeRHSNeedsInput, "RHS of composition must have at least one input" }, + { DrError_ComposeLHSNeedsOutput, "LHS of composition must have at least one output" }, + { DrError_ComposeStagesMustBeDifferent, "Stages for composition must be different" }, + { 
DrError_ComposeStageEmpty, "Stage for composition is empty" }, + { DrError_VertexNotInGraph, "Vertex not in graph" }, + { DrError_HardConstraintCannotBeMet, "Hard constraint cannot be met" }, + { DrError_XComputeError, "XCompute error" }, + { DrError_CohortShutdown, "Cohort shutdown" }, + { DrError_Unexpected, "Unexpected" }, + { DrError_DependentVertexFailure, "Dependent vertex failure" }, + { DrError_BadOutputReported, "Bad output reported" }, + { DrError_InputUnavailable, "Input unavailabled" }, + + { DrError_EndOfStream, "End of stream" }, + + { DrError_CannotConnectToDsc, "Failed to connect to DSC" }, + { DrError_DscOperationFailed, "DSC operation failed" }, + { DrError_FailedToDeleteFileset, "Failed to delete DSC fileset" }, + { DrError_FailedToCreateFileset, "Failed to create DSC fileset" }, + { DrError_FailedToAddFile, "Failed to add file to DSC fileset" }, + { DrError_FailedToSetMetadata, "Failed to set metadata for DSC fileset" }, + { DrError_FailedToSealFileset, "Failed to seal DSC fileset" }, + { DrError_FailedToSetLease, "Failed to set lease for DSC fileset" }, + { DrError_FailedToOpenFileset, "Failed to open DSC fileset" }, + +}; + +#ifdef _MANAGED + +DRCLASS(DrSystemErrorText) +{ +public: + static void Initialize() + { + } + + static void Discard() + { + } + + static DrString GetSystemErrorText(DWORD dwError) + { + HRESULT hr = HRESULT_FROM_WIN32(dwError); + System::Exception ^e = System::Runtime::InteropServices::Marshal::GetExceptionForHR(hr); + return DrString(e->Message); + } +}; + +#else + +DRCLASS(DrSystemErrorText) +{ +public: + static void Initialize() + { + s_cs = DrNew DrCritSec(); + s_hModuleNetMsg = NULL; + s_hModuleWinHttp = NULL; + } + + static void Discard() + { + s_cs = DrNull; + } + + static DrString GetSystemErrorText(DWORD dwError) + { + LPSTR MessageBuffer; + DWORD dwBufferLength; + HANDLE hModule = NULL; + + DWORD dwUse = dwError; + DWORD dwNormalized = dwError; + if ((dwNormalized & 0xFFFF0000) == ((FACILITY_WIN32 << 16) | 
0x80000000)) { + dwNormalized = dwNormalized & 0xFFFF; + } + + DWORD dwFormatFlags = FORMAT_MESSAGE_ALLOCATE_BUFFER | + FORMAT_MESSAGE_IGNORE_INSERTS | + FORMAT_MESSAGE_FROM_SYSTEM ; + + // + // If dwLastError is in the network range, + // load the message source. + // + + if (dwNormalized >= NERR_BASE && dwNormalized <= MAX_NERR) { + { + DrAutoCriticalSection((DrCritSecPtr)s_cs); + + if (s_hModuleNetMsg == NULL) { + s_hModuleNetMsg = LoadLibraryEx( + TEXT("netmsg.dll"), + NULL, + LOAD_LIBRARY_AS_DATAFILE + ); + } + + if(s_hModuleNetMsg != NULL) { + dwFormatFlags |= FORMAT_MESSAGE_FROM_HMODULE; + hModule = s_hModuleNetMsg; + } + } + } else if (dwNormalized >= WINHTTP_ERROR_BASE && dwNormalized <= WINHTTP_ERROR_LAST) { + { + DrAutoCriticalSection((DrCritSecPtr)s_cs); + + if (s_hModuleWinHttp == NULL) { + s_hModuleWinHttp = LoadLibraryEx( + TEXT("winhttp.dll"), + NULL, + LOAD_LIBRARY_AS_DATAFILE + ); + } + + if(s_hModuleWinHttp != NULL) { + dwFormatFlags |= FORMAT_MESSAGE_FROM_HMODULE; + hModule = s_hModuleWinHttp; + dwUse = dwNormalized; + } + } + } + + // + // Call FormatMessage() to allow for message + // text to be acquired from the system + // or from the supplied module handle. + // + // For perf, we assume here that all ANSI error messages are also valid UTF-8. 
If + // this turns out not to be true, we'll have to get the unicode message and convert to UTF-8 + if ((dwBufferLength = FormatMessageA( + dwFormatFlags, + hModule, + dwUse, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), // default language + (LPSTR) &MessageBuffer, + 0, + NULL + )) != 0) { + char *pszBuffer = (char *)malloc(dwBufferLength + 1); + DrAssert(pszBuffer != NULL); + memcpy(pszBuffer, MessageBuffer, dwBufferLength+1); + LocalFree(MessageBuffer); + + DWORD i; + for (i = 0; i < dwBufferLength; ++i) + if (pszBuffer[i] == '\r' || pszBuffer[i] == '\n') + pszBuffer[i] = ' '; + DrString s = DrString(pszBuffer); + free((void*)pszBuffer); + return s; + } + + DrString s = DrString(); + s.SetF("Error code %u (0x%08x)", dwError, dwError); + return s; + } + +private: + static DrCritSecRef s_cs; + static HMODULE s_hModuleNetMsg; + static HMODULE s_hModuleWinHttp; +}; + +DrCritSecRef DrSystemErrorText::s_cs; +HMODULE DrSystemErrorText::s_hModuleNetMsg; +HMODULE DrSystemErrorText::s_hModuleWinHttp; + +DrErrorDictionaryRef DrErrorText::s_dictionary; + +#endif + +void DrErrorText::Initialize() +{ + DrSystemErrorText::Initialize(); + + s_dictionary = DrNew DrErrorDictionary(); + int i; + for (i=0; iAdd(s_errorTable[i].m_error, s); + } +} + +void DrErrorText::Discard() +{ + DrSystemErrorText::Discard(); + s_dictionary = DrNull; +} + +DrString DrErrorText::GetErrorText(HRESULT err) +{ + DrString text; + if (s_dictionary->TryGetValue(err, text)) + { + return text; + } + else + { + return DrSystemErrorText::GetSystemErrorText(err); + } +} + +const char* DrErrorString::GetChars(HRESULT err) +{ + if (err == S_OK) + { + /* common case */ + m_string.T() = "No Error"; + } + else + { + m_string.T() = DrErrorText::GetErrorText(err); + } + + return m_string.T().GetChars(); +} diff --git a/GraphManager/shared/DrStringUtil.h b/GraphManager/shared/DrStringUtil.h new file mode 100644 index 0000000..290f99b --- /dev/null +++ b/GraphManager/shared/DrStringUtil.h @@ -0,0 +1,32 @@ +/* 
+Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +DRVALUECLASS(DrErrorString) +{ +public: + const char* GetChars(HRESULT err); + +private: + DrValueWrapper m_string; +}; + +#define DRERRORSTRING(err_) (DrErrorString().GetChars(err_)) \ No newline at end of file diff --git a/GraphManager/shared/DrTypes.h b/GraphManager/shared/DrTypes.h new file mode 100644 index 0000000..7cd79f2 --- /dev/null +++ b/GraphManager/shared/DrTypes.h @@ -0,0 +1,50 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#define MAX_UINT8 ((UINT8)-1) +#define MAX_UINT16 ((UINT16)-1) +#define MAX_UINT32 ((UINT32)-1) +#define MIN_INT32 ((INT32)0x80000000) +#define MAX_INT32 ((INT32)0x7FFFFFFF) // 2147483647 +#define MAX_UINT64 ((UINT64)-1) +#define MIN_INT64 ((INT64)0x8000000000000000I64) +#define MAX_INT64 0x7FFFFFFFFFFFFFFFi64 + +typedef INT64 DrTimeInterval; + +static const DrTimeInterval DrTimeInterval_Infinite = (DrTimeInterval)MAX_INT64; +static const DrTimeInterval DrTimeInterval_NegativeInfinite = (DrTimeInterval)MIN_INT64; +static const DrTimeInterval DrTimeInterval_Zero = (DrTimeInterval)0; +static const DrTimeInterval DrTimeInterval_Quantum = (DrTimeInterval)1; +static const DrTimeInterval DrTimeInterval_100ns = DrTimeInterval_Quantum; +static const DrTimeInterval DrTimeInterval_Microsecond = DrTimeInterval_100ns * 10; +static const DrTimeInterval DrTimeInterval_Millisecond = DrTimeInterval_Microsecond * 1000; +static const DrTimeInterval DrTimeInterval_Second = DrTimeInterval_Millisecond * 1000; +static const DrTimeInterval DrTimeInterval_Minute = DrTimeInterval_Second * 60; +static const DrTimeInterval DrTimeInterval_Hour = DrTimeInterval_Minute * 60; +static const DrTimeInterval DrTimeInterval_Day = DrTimeInterval_Hour * 24; +static const DrTimeInterval DrTimeInterval_Week = DrTimeInterval_Day * 7; + +typedef UINT64 DrDateTime; + +static const DrDateTime DrDateTime_Never = (DrDateTime) MAX_UINT64; +static const DrDateTime DrDateTime_LongAgo = (DrDateTime) 0; diff --git a/GraphManager/stagemanager/DrDefaultManager.cpp b/GraphManager/stagemanager/DrDefaultManager.cpp new file mode 100644 index 0000000..8791bf2 --- /dev/null +++ b/GraphManager/stagemanager/DrDefaultManager.cpp @@ -0,0 +1,973 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include + +DrStageManager::DrStageManager(DrGraphPtr graph) : DrSharedCritSec(graph) +{ +} + +DrStageManager::~DrStageManager() +{ +} + + +DrConnectionManager::DrConnectionManager(bool manageVerticesIndividually) +{ + m_manageVerticesIndividually = manageVerticesIndividually; +} + +DrConnectionManager::~DrConnectionManager() +{ +} + +bool DrConnectionManager::ManageVerticesIndividually() +{ + return m_manageVerticesIndividually; +} + +void DrConnectionManager::SetParent(DrManagerBasePtr parent) +{ + m_parent = parent; +} + +DrManagerBasePtr DrConnectionManager::GetParent() +{ + return m_parent; +} + +void DrConnectionManager::AddUpstreamStage(DrManagerBasePtr /* unused upstreamStage */) +{ +} + +DrConnectionManagerRef DrConnectionManager::MakeManagerForVertex(DrVertexPtr /* unused vertex */, + bool /* ununsed splitting */) +{ + return this; +} + +void DrConnectionManager::RegisterVertex(DrVertexPtr /* unused vertex */, + bool /* unused splitting */) +{ +} + +void DrConnectionManager::UnRegisterVertex(DrVertexPtr /* unused vertex */) +{ +} + +void DrConnectionManager::NotifyUpstreamSplit(DrVertexPtr upstreamVertex, + DrVertexPtr baseNewVertexSplitFrom, + int outputPortOfSplitBase, + int /* unused upstreamSplitIndex */) +{ + /* the default behaviour is to add a new edge between + upstreamVertex and localVertex. 
*/ + DefaultDealWithUpstreamSplit(upstreamVertex, baseNewVertexSplitFrom, + outputPortOfSplitBase, DCT_File); +} + +void DrConnectionManager::NotifyUpstreamVertexRemoval(DrVertexPtr upstreamVertex, + int outputPortOfRemovedVertex) +{ + /* the default behaviour is to remove the old edge between + upstreamVertex and localVertex but not modify any + vertices. */ + DefaultDealWithUpstreamRemoval(upstreamVertex, + outputPortOfRemovedVertex); +} + +void DrConnectionManager::NotifyUpstreamVertexCompleted(DrActiveVertexPtr /* unused vertex */, + int /* unused outputPort */, + int /* unused executionVersion */, + DrResourcePtr /* unused machine */, + DrVertexExecutionStatisticsPtr /* unused statistics */) +{ +} + +void DrConnectionManager::NotifyUpstreamLastVertexCompleted(DrManagerBasePtr /* unused upstreamStage */) +{ +} + +void DrConnectionManager::NotifyUpstreamInputReady(DrStorageVertexPtr /* unused vertex */, + int /* unused outputPort */, + DrAffinityPtr /* unused affinity */) +{ +} + +void DrConnectionManager::DefaultDealWithUpstreamSplit(DrVertexPtr upstreamVertex, + DrVertexPtr baseNewVertexSplitFrom, + int outputPortOfSplitBase, + DrConnectorType type) +{ + DrVertexPtr localVertex = baseNewVertexSplitFrom->RemoteOutputVertex(outputPortOfSplitBase); + + int upstreamVertexBasePort = upstreamVertex->GetOutputs()->GetNumberOfEdges(); + int localVertexBasePort = localVertex->GetInputs()->GetNumberOfEdges(); + + upstreamVertex->ConnectOutput(upstreamVertexBasePort, + localVertex, localVertexBasePort, type); +} + +void DrConnectionManager::DefaultDealWithUpstreamRemoval(DrVertexPtr upstreamVertex, + int outputPortOfRemovedVertex) +{ + DrVertexPtr localVertex = upstreamVertex->RemoteOutputVertex(outputPortOfRemovedVertex); + upstreamVertex->DisconnectOutput(outputPortOfRemovedVertex, true); + + /* shrink the local edge list to get rid of the empty slot we just + left. 
The upstream vertex will be dealt with by its own stage + manager */ + localVertex->GetInputs()->Compact(localVertex); +} + + +bool RTIter::operator==(RTIterR other) +{ + return + (m_version == other.m_version && + m_iter == other.m_iter); +} + + +DrManagerBase::Holder::Holder(DrConnectionManagerPtr manager) +{ + m_manager = manager; + m_stageSet = DrNew DrStageSet(); +} + +DrConnectionManagerPtr DrManagerBase::Holder::GetConnectionManager() +{ + return m_manager; +} + +void DrManagerBase::Holder::AddUpstreamStage(DrManagerBasePtr upstreamStage) +{ + m_stageSet->Add(upstreamStage); + m_manager->AddUpstreamStage(upstreamStage); +} + +bool DrManagerBase::Holder::IsManagingUpstreamStage(DrManagerBasePtr upstreamStage) +{ + return m_stageSet->Contains(upstreamStage); +} + +DrConnectionManagerPtr DrManagerBase::Holder::GetManagerForVertex(DrVertexPtr /* unused vertex */) +{ + return m_manager; +} + +void DrManagerBase::Holder::AddManagedVertex(DrVertexPtr vertex, bool splitting) +{ + m_manager->RegisterVertex(vertex, splitting); +} + +void DrManagerBase::Holder::RemoveManagedVertex(DrVertexPtr vertex) +{ + m_manager->UnRegisterVertex(vertex); +} + +void DrManagerBase::Holder::NotifyUpstreamLastVertexCompleted(DrManagerBasePtr upstreamStage) +{ + m_manager->NotifyUpstreamLastVertexCompleted(upstreamStage); +} + +DrManagerBase::IndividualHolder::IndividualHolder(DrConnectionManagerPtr manager) + : DrManagerBase::Holder(manager) +{ + DrAssert(manager->ManageVerticesIndividually()); + m_map = DrNew Map(); +} + +DrConnectionManagerPtr DrManagerBase::IndividualHolder::GetManagerForVertex(DrVertexPtr vertex) +{ + DrConnectionManagerRef manager; + if (m_map->TryGetValue(vertex, manager)) + { + return manager; + } + else + { + return DrNull; + } +} + +void DrManagerBase::IndividualHolder::AddManagedVertex(DrVertexPtr vertex, bool splitting) +{ + DrConnectionManagerRef manager = GetConnectionManager()->MakeManagerForVertex(vertex, splitting); + m_map->Add(vertex, manager); + 
manager->RegisterVertex(vertex, splitting); +} + +void DrManagerBase::IndividualHolder::RemoveManagedVertex(DrVertexPtr vertex) +{ + DrConnectionManagerRef manager; + if (m_map->TryGetValue(vertex, manager)) + { + manager->UnRegisterVertex(vertex); + m_map->Remove(vertex); + } +} + +void DrManagerBase::IndividualHolder::NotifyUpstreamLastVertexCompleted(DrManagerBasePtr upstreamStage) +{ + Map::DrEnumerator i = m_map->GetDrEnumerator(); + while (i.MoveNext()) + { + i.GetValue()->NotifyUpstreamLastVertexCompleted(upstreamStage); + } +} + + +DrManagerBase::DrManagerBase(DrGraphPtr graph, DrNativeString stageName) : DrStageManager(graph) +{ + m_graph = graph; + m_graph->AddStage(this); + + m_includeInJobStageList = true; + m_stillAddingVertices = false; + m_verticesNotYetCompleted = 0; + m_weHaveCompleted = false; + m_stageStatistics = DrNew DrStageStatistics(); + m_downStreamStages = DrNew DrStageSet(); + + m_vertices = DrNew DrVertexList(); + m_holder = DrNew HolderList(); + m_runningMap = DrNew RunningMap(); + m_runningTimeMap = DrNew RunningTimeMap(); + + m_stageName = stageName; + m_stageStatistics->SetName(m_stageName); +} + +DrManagerBase::~DrManagerBase() +{ +} + +void DrManagerBase::Discard() +{ + if (m_runningMap != DrNull) + { + DrAssert(m_runningMap->GetSize() == 0); + DrAssert(m_runningTimeMap->GetSize() == 0); + + } + + m_graph = DrNull; + m_stageStatistics = DrNull; + + if (m_vertices != DrNull) + { + int i; + for (i=0; iSize(); ++i) + { + m_vertices[i]->Discard(); + } + } + m_vertices = DrNull; + + m_holder = DrNull; + m_downStreamStages = DrNull; + m_runningMap = DrNull; + m_runningTimeMap = DrNull; +} + +DrGraphPtr DrManagerBase::GetGraph() +{ + return m_graph; +} + +void DrManagerBase::InitializeForGraphExecution() +{ + int i; + for (i=0; iSize(); ++i) + { + m_vertices[i]->InitializeForGraphExecution(); + } +} + +void DrManagerBase::KickStateMachine() +{ + int i; + for (i=0; iSize(); ++i) + { + m_vertices[i]->KickStateMachine(); + } +} + +DrString 
DrManagerBase::GetStageName() +{ + return m_stageName; +} + +DrStageStatisticsPtr DrManagerBase::GetStageStatistics() +{ + return m_stageStatistics; +} + +void DrManagerBase::SetStageStatistics(DrStageStatisticsPtr statistics) +{ + m_stageStatistics = statistics; +} + +bool DrManagerBase::GetIncludeInJobStageList() +{ + return m_includeInJobStageList; +} + +void DrManagerBase::SetIncludeInJobStageList(bool includeInJobStageList) +{ + m_includeInJobStageList = includeInJobStageList; +} + +DrVertexListPtr DrManagerBase::GetVertexVector() +{ + return m_vertices; +} + + +DrManagerBase::HolderPtr DrManagerBase::AddDynamicConnectionManagerInternal(DrManagerBasePtr upstreamStage, + DrConnectionManagerPtr manager) +{ + if (manager->GetParent() == DrNull) + { + manager->SetParent(this); + } + + DrAssert(manager->GetParent() == this); + + int i; + for (i=0; iSize(); ++i) + { + if (m_holder[i]->GetConnectionManager() == manager) + { + /* we've seen this connection manager before */ + break; + } + } + + if (i == m_holder->Size()) + { + if (manager->ManageVerticesIndividually()) + { + m_holder->Add(DrNew IndividualHolder(manager)); + } + else + { + m_holder->Add(DrNew Holder(manager)); + } + } + + m_holder[i]->AddUpstreamStage(upstreamStage); + return m_holder[i]; +} + +void DrManagerBase::AddDynamicConnectionManager(DrStageManagerPtr upstreamStage, + DrConnectionManagerPtr manager) +{ + AddDynamicConnectionManagerInternal(dynamic_cast(upstreamStage), manager); +} + +void DrManagerBase::AddDynamicConnectionManagerAtRuntime(DrStageManagerPtr upstreamStage, + DrConnectionManagerPtr connector) +{ + HolderPtr b = AddDynamicConnectionManagerInternal(dynamic_cast(upstreamStage), connector); + DrAssert(b != DrNull); + + int i; + for (i=0; iSize(); ++i) + { + b->AddManagedVertex(m_vertices[i], false); + } +} + +DrManagerBase::HolderPtr DrManagerBase::LookUpConnectionHolder(DrManagerBasePtr upstreamStage) +{ + int i; + for (i=0; iSize(); ++i) + { + if 
(m_holder[i]->IsManagingUpstreamStage(upstreamStage)) + { + return m_holder[i]; + } + } + + return DrNull; +} + +DrConnectionManagerPtr DrManagerBase::LookUpConnectionManager(DrVertexPtr vertex, + DrManagerBasePtr upstreamStage) +{ + HolderPtr holder = LookUpConnectionHolder(upstreamStage); + + if (holder == DrNull) + { + /* there's no default connector, and no specific connector for + this stage */ + return DrNull; + } + else + { + return holder->GetManagerForVertex(vertex); + } +} + +void DrManagerBase::RegisterVertex(DrVertexPtr vertex) +{ + RegisterVertexInternal(vertex, false); +} + +void DrManagerBase::RegisterVertexInternal(DrVertexPtr vertex, bool registerSplit) +{ + DrAssert(m_weHaveCompleted == false); + ++m_verticesNotYetCompleted; + + /* make sure all necessary connection managers are in place for + this vertex */ + int b; + for (b=0; bSize(); ++b) + { + m_holder[b]->AddManagedVertex(vertex, registerSplit); + } + + m_vertices->Add(vertex); + m_stageStatistics->IncrementSampleSize(); + + /* pass this vertex on to any derived classes */ + RegisterVertexDerived(vertex); +} + +/* this is a virtual method and the default does nothing */ +void DrManagerBase::RegisterVertexDerived(DrVertexPtr /* unused vertex */) +{ +} + +void DrManagerBase::RegisterVertexSplit(DrVertexPtr vertex, DrVertexPtr baseToSplitFrom, int splitIndex) +{ + RegisterVertexInternal(vertex, true); + + /* tell all the downstream stage managers that we are splitting a + vertex here */ + int numberOfOutputs = baseToSplitFrom->GetOutputs()->GetNumberOfEdges(); + int i; + for (i=0; iRemoteOutputVertex(i); + DrManagerBasePtr remoteManager = dynamic_cast(remote->GetStageManager()); + remoteManager->NotifyUpstreamSplit(vertex, baseToSplitFrom, i, splitIndex); + } + + RegisterVertexSplitDerived(vertex, baseToSplitFrom, splitIndex); +} + +/* this is a virtual method and the default does nothing */ +void DrManagerBase::RegisterVertexSplitDerived(DrVertexPtr /* unused vertex */, + DrVertexPtr /* unused 
baseToSplitFrom */, + int /* unused splitIndex */) +{ +} + +void DrManagerBase::NotifyUpstreamSplit(DrVertexPtr upstreamVertex, DrVertexPtr baseNewVertexSplitFrom, + int outputPortOfSplitBase, int upstreamSplitIndex) +{ + DrVertexPtr localVertex = baseNewVertexSplitFrom->RemoteOutputVertex(outputPortOfSplitBase); + DrManagerBasePtr upstreamStage = dynamic_cast(baseNewVertexSplitFrom->GetStageManager()); + + DrConnectionManagerPtr manager = LookUpConnectionManager(localVertex, upstreamStage); + if (manager != DrNull) + { + /* if we have a connection manager for this upstream stage, + tell it about the split */ + manager->NotifyUpstreamSplit(upstreamVertex, baseNewVertexSplitFrom, + outputPortOfSplitBase, upstreamSplitIndex); + } + else + { + /* the default behaviour is to add a new edge between + upstreamVertex and localVertex. The DrNull below says to use + the graph's default channel constructor for the new edge */ + DrConnectionManager::DefaultDealWithUpstreamSplit(upstreamVertex, baseNewVertexSplitFrom, + outputPortOfSplitBase, DCT_File); + } + + /* pass this split on to any derived classes */ + NotifyUpstreamSplitDerived(upstreamVertex, baseNewVertexSplitFrom, + outputPortOfSplitBase, upstreamSplitIndex); +} + +/* this is a virtual method and the default does nothing */ +void DrManagerBase::NotifyUpstreamSplitDerived(DrVertexPtr /* unused upstreamVertex */, + DrVertexPtr /* unused baseNewVertexSplitFrom */, + int /* unused outputPortOfSplitBase */, + int /* unused upstreamSplitIndex */) +{ +} + +void DrManagerBase::SetStillAddingVertices(bool stillAddingVertices) +{ + DrAssert(m_weHaveCompleted == false); + m_stillAddingVertices = stillAddingVertices; + CheckIfWeHaveCompleted(); +} + +void DrManagerBase::CheckIfWeHaveCompleted() +{ + if (m_verticesNotYetCompleted == 0 && m_stillAddingVertices == false) + { + DrAssert(m_weHaveCompleted == false); + m_weHaveCompleted = true; + + DrStageSet::DrEnumerator s = m_downStreamStages->GetDrEnumerator(); + /* tell all the 
downstream vertices that the last vertex in + our stage has completed, so they can tell their connection + managers */ + while (s.MoveNext()) + { + s.GetElement()->NotifyUpstreamLastVertexCompleted(this); + } + + NotifyLastVertexCompletedDerived(); + } +} + +void DrManagerBase::UnRegisterVertex(DrVertexPtr vertex) +{ + DrAssert(m_verticesNotYetCompleted > 0); + + /* pass this removal on to any derived classes */ + UnRegisterVertexDerived(vertex); + + bool removed = m_vertices->Remove(vertex); + DrAssert(removed); + + m_stageStatistics->DecrementSampleSize(); + + /* tell the downstream stage managers that something is being removed */ + int numberOfOutputs = vertex->GetOutputs()->GetNumberOfEdges(); + int i; + for (i=0; iRemoteOutputVertex(i); + DrManagerBasePtr remoteManager = dynamic_cast(remote->GetStageManager()); + remoteManager->NotifyUpstreamVertexRemoval(vertex, i); + } + vertex->GetOutputs()->Compact(DrNull); + DrAssert(vertex->GetOutputs()->GetNumberOfEdges() == 0); + + for (i=0; iSize(); ++i) + { + m_holder[i]->RemoveManagedVertex(vertex); + } + + DrAssert(m_verticesNotYetCompleted > 0); + --m_verticesNotYetCompleted; + + CheckIfWeHaveCompleted(); +} + +/* this is a virtual method and the default does nothing */ +void DrManagerBase::UnRegisterVertexDerived(DrVertexPtr /* unused vertex */) +{ +} + +void DrManagerBase::NotifyUpstreamVertexRemoval(DrVertexPtr upstreamVertex, int outputPortOfRemovedVertex) +{ + DrVertexPtr localVertex = upstreamVertex->RemoteOutputVertex(outputPortOfRemovedVertex); + DrManagerBasePtr upstreamStage = dynamic_cast(upstreamVertex->GetStageManager()); + + DrConnectionManagerPtr manager = LookUpConnectionManager(localVertex, upstreamStage); + if (manager != DrNull) + { + /* if we have a connection manager for this upstream stage, + tell it about the split */ + manager->NotifyUpstreamVertexRemoval(upstreamVertex, outputPortOfRemovedVertex); + } + else + { + /* the default behaviour is to remove the old edge between + upstreamVertex 
and localVertex but not modify any + vertices. */ + DrConnectionManager::DefaultDealWithUpstreamRemoval(upstreamVertex, outputPortOfRemovedVertex); + } + + /* pass this split on to any derived classes */ + NotifyUpstreamVertexRemovalDerived(upstreamVertex, outputPortOfRemovedVertex); +} + +/* this is a virtual method and the default does nothing */ +void DrManagerBase::NotifyUpstreamVertexRemovalDerived(DrVertexPtr /* unused upstreamVertex */, + int /* unused outputPortOfRemovedVertex */) +{ +} + +void DrManagerBase::AddToRunningMap(DrActiveVertexPtr vertex, int version, DrDateTime runningTime) +{ + VertexAndVersion vv; + vv.m_vertex = vertex; + vv.m_version = version; + RunningTimeMap::Iter i = m_runningTimeMap->Insert(runningTime, vv); + + RTIterListRef list; + if (m_runningMap->TryGetValue(vertex, list) == false) + { + list = DrNew RTIterList(); + m_runningMap->Add(vertex, list); + } + + RTIter rt; + rt.m_version = version; + rt.m_iter = i; + list->Add(rt); +} + +void DrManagerBase::RemoveFromRunningMap(DrActiveVertexPtr vertex, int version) +{ + RTIterListRef list; + if (m_runningMap->TryGetValue(vertex, list)) + { + int i; + for (i=0; iSize(); ++i) + { + if (list[i].m_version == version) + { + DrLogI("Removing vertex from running map %d.%d (%s)", + vertex->GetId(), version, vertex->GetName().GetChars()); + m_runningTimeMap->Erase(list[i].m_iter); + list->RemoveAt(i); + + if (list->Size() == 0) + { + m_runningMap->Remove(vertex); + } + return; + } + } + } +} + +void DrManagerBase::CheckForDuplicates() +{ + CheckForDuplicatesDerived(); + + DrTimeInterval threshold = m_stageStatistics->GetOutlierThreshold(m_graph->GetParameters()); + + if (threshold == DrTimeInterval_Infinite || m_runningTimeMap->GetSize() == 0) + { + /* there's not enough data yet to set a threshold, or nothing + running, so just leave */ + return; + } + + DrLogI("Checking for duplicates Stage %s threshold %lf potential duplicate count %d", + GetStageName().GetChars(), (double) threshold / 
(double) DrTimeInterval_Second, + m_runningTimeMap->GetSize()); + + DrDateTime now = m_graph->GetXCompute()->GetCurrentTimeStamp(); + + while (m_runningTimeMap->GetSize() > 0) + { + RunningTimeMap::Iter i = m_runningTimeMap->Begin(); + DrDateTime runningTime = i->first; + DrAssert(runningTime <= now); + + if (runningTime + threshold > now) + { + DrLogI("Exiting with vertices still running Stage %s threshold %lf potential duplicate count %d", + GetStageName().GetChars(), (double) threshold / (double) DrTimeInterval_Second, + m_runningTimeMap->GetSize()); + return; + } + + /* the oldest vertex has been running for longer than the + outlier threshold, so take it off the list and try to start + it as a duplicate */ + DrActiveVertexPtr v = i->second.m_vertex; + int version = i->second.m_version; + + DrLogI("Considering vertex for duplication Stage %s threshold %lf vertex %d.%d (%s) running for %lf", + GetStageName().GetChars(), (double) threshold / (double) DrTimeInterval_Second, + v->GetId(), version, v->GetName().GetChars(), + (double) (now - runningTime) / (double) DrTimeInterval_Second); + + v->RequestDuplicate(version+1); + + int oldSize = m_runningTimeMap->GetSize(); + RemoveFromRunningMap(v, version); + DrAssert(m_runningTimeMap->GetSize() + 1 == oldSize); + } +} + +/* this is a virtual method and the default does nothing */ +void DrManagerBase::CheckForDuplicatesDerived() +{ +} +/* handles the vertex-running notification: calls AddToRunningMap with statistics->m_runningTime, increments the started count on m_stageStatistics, then forwards every argument to NotifyVertexRunningDerived. statistics is dereferenced without a DrNull check */ +void DrManagerBase::NotifyVertexRunning(DrActiveVertexPtr vertex, int executionVersion, + DrResourcePtr machine, DrVertexExecutionStatisticsPtr statistics) +{ + AddToRunningMap(vertex, executionVersion, statistics->m_runningTime); + + m_stageStatistics->IncrementStartedCount(); + + NotifyVertexRunningDerived(vertex, executionVersion, machine, statistics); +} + +/* this is a virtual method and the default does nothing */ +void DrManagerBase::NotifyVertexRunningDerived(DrActiveVertexPtr /* unused vertex */, + int /* unused executionVersion */, + DrResourcePtr /* unused machine */, + 
DrVertexExecutionStatisticsPtr /* unused statistics */) +{ +} +/* this implementation ignores all three arguments and does nothing */ +void DrManagerBase::NotifyVertexStatus(DrActiveVertexPtr /* unused vertex */, + HRESULT /* unused completionStatus */, + DrVertexProcessStatusPtr /* unused status */) +{ +} +/* handles a vertex completion: decrements m_verticesNotYetCompleted only when GetNumberOfReportedCompletions() is 0, removes the version from the running map, records a measurement when machine is not DrNull, notifies the stage manager of each output edge, forwards to NotifyVertexCompletedDerived, and finally calls CheckIfWeHaveCompleted under the same GetNumberOfReportedCompletions() == 0 test */ +void DrManagerBase::NotifyVertexCompleted(DrActiveVertexPtr vertex, int executionVersion, + DrResourcePtr machine, + DrVertexExecutionStatisticsPtr statistics) +{ + if (vertex->GetNumberOfReportedCompletions() == 0) + { + DrAssert(m_verticesNotYetCompleted > 0); + --m_verticesNotYetCompleted; + } + + RemoveFromRunningMap(vertex, executionVersion); + + /* machine is DrNull if this is a dummy vertex continuing after + failure */ + if (machine != DrNull) + { + m_stageStatistics->AddMeasurement(m_graph->GetParameters(), machine, statistics); + } + + /* we're keeping track of all the downstream stage managers so we + know who to tell when all the inputs have notified. If there + are lots of outputs, the common case will be that they all + connect to the same stage manager, so special-case this to + avoid hammering on the stl map */ + DrManagerBasePtr previousDownstreamStage = DrNull; + + /* make sure all necessary connection managers know about this + vertex completion. NOTE(review): the loop header and the declaration of remote in the next statement are incomplete in this copy, as is the dynamic_cast target type (text that sat between angle brackets appears to be missing) - restore them from the original DrManagerBase.cpp before building */ + int numberOfOutputs = vertex->GetOutputs()->GetNumberOfEdges(); + int i; + for (i=0; iRemoteOutputVertex(i); + DrManagerBasePtr remoteManager = dynamic_cast(remote->GetStageManager()); + remoteManager->NotifyUpstreamVertexCompleted(vertex, i, executionVersion, + machine, statistics); + if (remoteManager != previousDownstreamStage) + { + if (m_downStreamStages->Contains(remoteManager) == false) + { + m_downStreamStages->Add(remoteManager); + } + previousDownstreamStage = remoteManager; + } + } + + /* pass this vertex on to any derived classes */ + NotifyVertexCompletedDerived(vertex, executionVersion, machine, statistics); + + if (vertex->GetNumberOfReportedCompletions() == 0) + { + CheckIfWeHaveCompleted(); + } +} + +/* this is a virtual method and the default 
does nothing */ +void DrManagerBase::NotifyVertexCompletedDerived(DrActiveVertexPtr /* unused vertex */, + int /* unused executionVersion */, + DrResourcePtr /* unused machine */, + DrVertexExecutionStatisticsPtr /* unused statistics */) +{ +} +/* handles a vertex failure: the running-map entry is removed and the started count decremented only when statistics is not DrNull and its m_runningTime is not DrDateTime_Never; NotifyVertexFailedDerived is called in every case */ +void DrManagerBase::NotifyVertexFailed(DrActiveVertexPtr vertex, int executionVersion, + DrResourcePtr machine, DrVertexExecutionStatisticsPtr statistics) +{ + if ((statistics != DrNull) && (statistics->m_runningTime != DrDateTime_Never)) + { + RemoveFromRunningMap(vertex, executionVersion); + + m_stageStatistics->DecrementStartedCount(); + } + + NotifyVertexFailedDerived(vertex, executionVersion, machine, statistics); +} + +/* this is a virtual method and the default does nothing */ +void DrManagerBase::NotifyVertexFailedDerived(DrActiveVertexPtr /* unused vertex */, + int /* unused executionVersion */, + DrResourcePtr /* unused machine */, + DrVertexExecutionStatisticsPtr /* unused statistics */) +{ +} + +/* this is a virtual method and the default does nothing */ +void DrManagerBase::NotifyLastVertexCompletedDerived() +{ +} +/* invoked from NotifyVertexCompleted of the upstream stage, once per output edge: looks up the connection manager for the local vertex on that edge and the upstream stage, notifies it when one exists, then forwards to NotifyUpstreamVertexCompletedDerived. NOTE(review): the dynamic_cast below has lost its target type in this copy */ +void DrManagerBase::NotifyUpstreamVertexCompleted(DrActiveVertexPtr upstreamVertex, + int outputPortOfCompletedVertex, + int executionVersion, + DrResourcePtr machine, + DrVertexExecutionStatisticsPtr statistics) +{ + DrVertexPtr localVertex = upstreamVertex->RemoteOutputVertex(outputPortOfCompletedVertex); + DrManagerBasePtr upstreamStage = dynamic_cast(upstreamVertex->GetStageManager()); + + DrConnectionManagerPtr manager = LookUpConnectionManager(localVertex, upstreamStage); + if (manager != DrNull) + { + /* if we have a connection manager for this upstream stage, + tell it about the completion */ + manager->NotifyUpstreamVertexCompleted(upstreamVertex, outputPortOfCompletedVertex, + executionVersion, machine, statistics); + } + + /* pass this completion on to any derived classes */ + NotifyUpstreamVertexCompletedDerived(upstreamVertex, outputPortOfCompletedVertex, + executionVersion, machine, 
statistics); +} + +/* this is a virtual method and the default does nothing */ +void DrManagerBase::NotifyUpstreamVertexCompletedDerived(DrActiveVertexPtr /* unused upstreamVertex */, + int /* unused outputPortOfCompletedVertex */, + int /* unused executionVersion */, + DrResourcePtr /* unused machine */, + DrVertexExecutionStatisticsPtr /* unused statistics */) +{ +} + +void DrManagerBase::NotifyUpstreamLastVertexCompleted(DrManagerBasePtr upstreamStage) +{ + HolderPtr holder = LookUpConnectionHolder(upstreamStage); + if (holder != DrNull) + { + /* if we have a connection holder for this upstream stage, + pass the notification on to it */ + holder->NotifyUpstreamLastVertexCompleted(upstreamStage); + } + + /* pass this on to any derived classes */ + NotifyUpstreamLastVertexCompletedDerived(upstreamStage); +} + +/* the default does nothing. NOTE(review): DrDefaultManager.h declares this hook without the virtual keyword, unlike the other Derived hooks, so presumably an override would not be reached from the call above - confirm whether that is intended */ +void DrManagerBase::NotifyUpstreamLastVertexCompletedDerived(DrManagerBasePtr /* unused upstreamStage */) +{ +} +/* handles an input-ready notification from a storage vertex: asserts and decrements m_verticesNotYetCompleted, notifies the stage manager of each output edge, forwards to NotifyInputReadyDerived, then calls CheckIfWeHaveCompleted */ +void DrManagerBase::NotifyInputReady(DrStorageVertexPtr vertex, DrAffinityPtr affinity) +{ + DrAssert(m_verticesNotYetCompleted > 0); + --m_verticesNotYetCompleted; + + /* we're keeping track of all the downstream stage managers so we + know who to tell when all the inputs have notified. 
If there + are lots of outputs, the common case will be that they all + connect to the same stage manager, so special-case this to + avoid hammering on the stl map */ + DrManagerBasePtr previousDownstreamStage = DrNull; + + /* make sure all necessary connection managers know about this + input. NOTE(review): the loop header, the declaration of remote and the dynamic_cast target type below are incomplete in this copy (text between angle brackets appears to be missing) */ + int numberOfOutputs = vertex->GetOutputs()->GetNumberOfEdges(); + int i; + for (i=0; iRemoteOutputVertex(i); + DrManagerBasePtr remoteManager = dynamic_cast(remote->GetStageManager()); + remoteManager->NotifyUpstreamInputReady(vertex, i, affinity); + if (remoteManager != previousDownstreamStage) + { + if (m_downStreamStages->Contains(remoteManager) == false) + { + m_downStreamStages->Add(remoteManager); + } + previousDownstreamStage = remoteManager; + } + } + + /* pass this vertex on to any derived classes */ + NotifyInputReadyDerived(vertex, affinity); + + CheckIfWeHaveCompleted(); +} + +/* this is a virtual method and the default does nothing */ +void DrManagerBase::NotifyInputReadyDerived(DrStorageVertexPtr /* unused vertex */, + DrAffinityPtr /* unused affinity */) +{ +} + +void DrManagerBase::NotifyUpstreamInputReady(DrStorageVertexPtr upstreamVertex, int upstreamVertexOutputPort, + DrAffinityPtr affinity) +{ + DrVertexPtr localVertex = upstreamVertex->RemoteOutputVertex(upstreamVertexOutputPort); + DrManagerBasePtr upstreamStage = dynamic_cast(upstreamVertex->GetStageManager()); + + DrConnectionManagerPtr manager = LookUpConnectionManager(localVertex, upstreamStage); + if (manager != DrNull) + { + /* if we have a connection manager for this upstream stage, + tell it that the input is ready */ + manager->NotifyUpstreamInputReady(upstreamVertex, upstreamVertexOutputPort, affinity); + } + + /* pass this input notification on to any derived classes */ + NotifyUpstreamInputReadyDerived(upstreamVertex, upstreamVertexOutputPort, affinity); +} + +/* this is a virtual method and the default does nothing */ +void DrManagerBase::NotifyUpstreamInputReadyDerived(DrStorageVertexPtr /* unused 
upstreamVertex */, + int /* unused upstreamVertexOutputPort */, + DrAffinityPtr /* unused affinity */) +{ +} + +/* returns true only when m_graph->IsRunning() is true and VertexIsReadyDerived(vertex) also returns true */ +bool DrManagerBase::VertexIsReady(DrActiveVertexPtr vertex) +{ + return m_graph->IsRunning() && VertexIsReadyDerived(vertex); +} + +/* this is a virtual method and the default always returns true */ +bool DrManagerBase::VertexIsReadyDerived(DrActiveVertexPtr /* unused vertex */) +{ + return true; +} +/* calls CancelAllVersions(reason) on each entry of m_vertices. NOTE(review): the loop bound below is incomplete in this copy */ +void DrManagerBase::CancelAllVertices(DrErrorPtr reason) +{ + int i; + + for (i=0; iSize(); ++i) + { + m_vertices[i]->CancelAllVersions(reason); + } +} \ No newline at end of file diff --git a/GraphManager/stagemanager/DrDefaultManager.h b/GraphManager/stagemanager/DrDefaultManager.h new file mode 100644 index 0000000..0a01bd9 --- /dev/null +++ b/GraphManager/stagemanager/DrDefaultManager.h @@ -0,0 +1,375 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +DRDECLARECLASS(DrManagerBase); +DRREF(DrManagerBase); + +DRDECLARECLASS(DrConnectionManager); +DRREF(DrConnectionManager); +/* base class for connection managers: it stores a parent stage manager and the manageVerticesIndividually flag, and declares virtual hooks for vertex registration and for upstream split, removal, completion and input-ready notifications, plus two static default handlers for upstream split and removal */ +DRBASECLASS(DrConnectionManager) +{ +public: + DrConnectionManager(bool manageVerticesIndividually); + virtual ~DrConnectionManager(); + + void SetParent(DrManagerBasePtr parent); + DrManagerBasePtr GetParent(); + + virtual void AddUpstreamStage(DrManagerBasePtr upstreamStage); + + bool ManageVerticesIndividually(); + + virtual DrConnectionManagerRef MakeManagerForVertex(DrVertexPtr vertex, bool splitting); + virtual void RegisterVertex(DrVertexPtr vertex, bool splitting); + virtual void UnRegisterVertex(DrVertexPtr vertex); + virtual void NotifyUpstreamSplit(DrVertexPtr upstreamVertex, DrVertexPtr baseNewVertexSplitFrom, + int outputPortOfSplitBase, int upstreamSplitIndex); + virtual void NotifyUpstreamVertexRemoval(DrVertexPtr upstreamVertex, + int outputPortOfRemovedVertex); + + virtual void NotifyUpstreamVertexCompleted(DrActiveVertexPtr vertex, int outputPort, int executionVersion, + DrResourcePtr machine, + DrVertexExecutionStatisticsPtr statistics); + virtual void NotifyUpstreamLastVertexCompleted(DrManagerBasePtr upstreamStage); + virtual void NotifyUpstreamInputReady(DrStorageVertexPtr vertex, int outputPort, DrAffinityPtr affinity); + + static void DefaultDealWithUpstreamSplit(DrVertexPtr upstreamVertex, + DrVertexPtr baseNewVertexSplitFrom, + int outputPortOfSplitBase, DrConnectorType type); + static void DefaultDealWithUpstreamRemoval(DrVertexPtr upstreamVertex, + int outputPortOfRemovedVertex); + +private: + bool m_manageVerticesIndividually; + DrManagerBaseRef m_parent; +}; +DRREF(DrConnectionManager); +/* NOTE(review): the element types of DrSet, DrArrayList and DrMultiMap in the typedefs below are missing in this copy (text that sat between angle brackets appears to have been dropped) - restore them from the original header before building */ +typedef DrSet DrStageSet; +DRREF(DrStageSet); + +typedef DrArrayList DrDefaultStageList; +DRAREF(DrDefaultStageList,DrManagerBaseRef); +/* pairs an active vertex with one of its execution versions; CheckForDuplicates reads these two fields from the entries of the running time map */ +DRVALUECLASS(VertexAndVersion) +{ + public: + DrActiveVertexRef m_vertex; + int m_version; +}; +typedef DrMultiMap RunningTimeMap; +DRREF(RunningTimeMap); + + 
+DRDECLAREVALUECLASS(RTIter); +DRRREF(RTIter); +DRVALUECLASS(RTIter) +{ + public: + bool operator==(RTIterR other); + + int m_version; + RunningTimeMap::Iter m_iter; +}; +DRMAKEARRAYLIST(RTIter); +/* NOTE(review): the key and value types of DrDictionary are missing in this copy */typedef DrDictionary RunningMap; +DRREF(RunningMap); + +/* default stage manager: most entry points inherited from DrStageManager are declared DROVERRIDE DRSEALED here and are paired with a virtual ...Derived hook that subclasses may override */ +DRCLASS(DrManagerBase) : public DrStageManager +{ +public: + DrManagerBase(DrGraphPtr graph, DrNativeString stageName); + virtual ~DrManagerBase(); + + virtual void Discard() DROVERRIDE DRSEALED; + + virtual DrGraphPtr GetGraph() DROVERRIDE DRSEALED; + + virtual void InitializeForGraphExecution() DROVERRIDE DRSEALED; + virtual void KickStateMachine() DROVERRIDE DRSEALED; + + /* the stage name is the friendly name that is used in job + monitoring summaries. When a stage manager is created as a side + effect of creating a vertex, it inherits that vertex's name by + default. */ + virtual DrString GetStageName() DROVERRIDE DRSEALED; + + /* the stage statistics gathers running-time statistics about all + vertices in the stage, and may be used when trying to detect + outliers. */ + DrStageStatisticsPtr GetStageStatistics(); + void SetStageStatistics(DrStageStatisticsPtr statistics); + + /* the stage will only be included in job monitoring summaries if + includeInJobStageList is true. By convention, stages that are + not active (e.g. input or output streams) are not included in + monitoring, since there's not much to say about them */ + virtual bool GetIncludeInJobStageList() DROVERRIDE DRSEALED; + virtual void SetIncludeInJobStageList(bool includeInJobStageList) DROVERRIDE DRSEALED; + + /* assign connector to manage dynamic modifications to the + subgraph edges connecting this stage from upstreamStage, for + example to manage a dynamic merge tree */ + virtual void AddDynamicConnectionManager(DrStageManagerPtr upStreamStage, + DrConnectionManagerPtr connector) DROVERRIDE DRSEALED; + + /* similar to the above, but do it not at graph-build time, but during runtime. 
+ The difference is that the vertices have to be registered */ + virtual void AddDynamicConnectionManagerAtRuntime(DrStageManagerPtr upstreamStage, + DrConnectionManagerPtr connector) DROVERRIDE DRSEALED; + + /* RegisterVertex should be called once for each vertex that is + added to the stage. RegisterVertexDerived is a virtual method + that is called automatically after other actions in + RegisterVertex so that derived classes can keep track of what + is happening, and the base class implementation does nothing. A + client wishing to let the stage know that a vertex has been + added should call RegisterVertex. RegisterVertex should not be + called if RegisterVertexSplit is also called on the new + vertex. */ + virtual void RegisterVertex(DrVertexPtr vertex) DROVERRIDE DRSEALED; + virtual void RegisterVertexDerived(DrVertexPtr vertex); + + /* Some dynamic graph modifications increase the size of a stage + by "splitting" new vertices off from a base vertex, and they + should call RegisterVertexSplit which will notify any relevant + downstream vertex stage managers about the split. In this case + RegisterVertex will automatically be called, and the client + should not call it as well. Different downstream stage managers + will generally want to deal differently with a split vertex, + for example they may also choose to split. Consequently the + split vertex should not be connected to any downstream vertices + before this method is called, and the downstream manager will + add edges as appropriate. By default it will add a new edge + between the new vertex and every downstream vertex that the + baseToSplitFrom is currently connected + to. RegisterVertexSplitDerived is a virtual method that is + called after other actions in RegisterVertexSplit so that + derived classes can keep track of what is happening, and the + base class implementation does nothing. A client wishing to let + the stage know that a vertex has been added should call + RegisterVertexSplit. 
*/ + virtual void RegisterVertexSplit(DrVertexPtr vertex, DrVertexPtr baseToSplitFrom, + int splitIndex) DROVERRIDE DRSEALED; + virtual void RegisterVertexSplitDerived(DrVertexPtr vertex, DrVertexPtr baseToSplitFrom, + int splitIndex); + + /* Some dynamic graph modifications remove vertices from + stages. UnRegisterVertex should be called before a vertex is + removed, and it will automatically call connected downstream + managers to notify them that the vertex is being + removed. UnRegisterVertexDerived is a virtual method that is + called automatically after other actions in UnRegisterVertex so + that derived classes can keep track of what is happening, and + the base class implementation does nothing. A client wishing to + let the stage know that a vertex is about to be removed should call + UnRegisterVertex. */ + virtual void UnRegisterVertex(DrVertexPtr vertex) DROVERRIDE DRSEALED; + virtual void UnRegisterVertexDerived(DrVertexPtr vertex); + + /* NotifyUpstreamSplitDerived is called automatically whenever a + vertex split is registered that is connected upstream of any + vertex with this stage manager. Stage managers can keep track + of upstream stages that are splitting dynamically this way, and + for example propagate the split forwards if they are in a + pipeline. This is a virtual method that is included so that + derived classes can keep track of what is happening, and the + base class implementation does nothing. */ + virtual void NotifyUpstreamSplitDerived(DrVertexPtr upStreamVertex, + DrVertexPtr baseNewVertexSplitFrom, + int outputPortOfSplitBase, + int upstreamSplitIndex); + + /* NotifyUpstreamVertexRemovalDerived is called automatically + whenever a vertex is unregistered that is connected upstream of + any vertex with this stage manager. Stage managers can keep + track of upstream stages that are removing vertices dynamically this + way, and for example propagate the removal forwards if they are + in a pipeline. 
This is a virtual method that is included so + that derived classes can keep track of what is happening, and + the base class implementation does nothing. */ + virtual void NotifyUpstreamVertexRemovalDerived(DrVertexPtr upstreamVertex, + int outputPortOfRemovedVertex); + + /* VertexIsReady is called by the job manager before it attempts + to run any vertex. If VertexIsReady returns false then the job + manager will not start the vertex, otherwise it will proceed as + normal and run the vertex when it sees fit. VertexIsReady may + be called many times for a given vertex. If it ever returns + false, then the application must subsequently call + vertex->NotifyVertexIsReady() once the vertex is ready + to run, otherwise the job may never make progress. VertexIsReady + itself returns m_graph->IsRunning() && VertexIsReadyDerived(vertex); + derived classes customize the decision through VertexIsReadyDerived, + whose default implementation always returns true. */ + virtual bool VertexIsReady(DrActiveVertexPtr vertex) DROVERRIDE DRSEALED; + virtual bool VertexIsReadyDerived(DrActiveVertexPtr vertex); + + virtual void SetStillAddingVertices(bool stillAddingVertices) DROVERRIDE DRSEALED; + + /* NotifyVertexStatus is called every time the job manager + receives an update on the vertex, which happens periodically + while the vertex is running, and once when it completes. If + completionStatus is DryadError_VertexRunning the vertex has not + yet completed. If completionStatus is + DryadError_VertexCompleted the vertex has successfully + completed and NotifyVertexStatus will not be called again for + this version of the vertex. Otherwise the vertex has exited + with an error. + + DVertexProcessStatus is defined in + dryad/system/common/include/dvertexcommand.h and it includes + the version of the vertex (with GetVertexInstanceVersion()), + metadata including any error information (with + GetVertexMetaData()) and information about all of its input and + output channels. + + The default implementation does nothing. 
+ */ + virtual void NotifyVertexStatus(DrActiveVertexPtr vertex, + HRESULT completionStatus, + DrVertexProcessStatusPtr status) DROVERRIDE DRSEALED; +/* the NotifyVertexRunning, NotifyVertexCompleted and NotifyVertexFailed entry points below are implemented in the base class and each one finishes by calling its ...Derived hook, whose default implementation does nothing */ + virtual void NotifyVertexRunning(DrActiveVertexPtr vertex, + int executionVersion, + DrResourcePtr machine, + DrVertexExecutionStatisticsPtr statistics) DROVERRIDE DRSEALED; + virtual void NotifyVertexRunningDerived(DrActiveVertexPtr vertex, + int executionVersion, + DrResourcePtr machine, + DrVertexExecutionStatisticsPtr statistics); + virtual void NotifyVertexCompleted(DrActiveVertexPtr vertex, + int executionVersion, + DrResourcePtr machine, + DrVertexExecutionStatisticsPtr statistics) DROVERRIDE DRSEALED; + virtual void NotifyVertexCompletedDerived(DrActiveVertexPtr vertex, + int executionVersion, + DrResourcePtr machine, + DrVertexExecutionStatisticsPtr statistics); + virtual void NotifyVertexFailed(DrActiveVertexPtr vertex, int executionVersion, + DrResourcePtr machine, DrVertexExecutionStatisticsPtr statistics) DROVERRIDE DRSEALED; + virtual void NotifyVertexFailedDerived(DrActiveVertexPtr vertex, int executionVersion, + DrResourcePtr machine, DrVertexExecutionStatisticsPtr statistics); + + virtual void CheckForDuplicates() DROVERRIDE DRSEALED; + virtual void CheckForDuplicatesDerived(); + + virtual void NotifyLastVertexCompletedDerived(); + virtual void NotifyUpstreamVertexCompletedDerived(DrActiveVertexPtr upstreamVertex, + int upstreamVertexOutputPort, + int executionVersion, + DrResourcePtr machine, + DrVertexExecutionStatisticsPtr statistics); +/* NOTE(review): unlike the other ...Derived hooks this one is declared without the virtual keyword, although the comment on its definition in DrManagerBase.cpp calls it a virtual method - confirm which is intended */ void NotifyUpstreamLastVertexCompletedDerived(DrManagerBasePtr upstreamStage); + + virtual void NotifyInputReady(DrStorageVertexPtr vertex, DrAffinityPtr affinity) DROVERRIDE DRSEALED; + virtual void NotifyInputReadyDerived(DrStorageVertexPtr vertex, DrAffinityPtr affinity); + virtual void NotifyUpstreamInputReadyDerived(DrStorageVertexPtr vertex, int upstreamVertexOutputPort, + DrAffinityPtr affinity); + + /* this returns a set containing all the vertices that 
have been + registered with this stage */ + virtual DrVertexListPtr GetVertexVector() DROVERRIDE; + + /* this tells all pending and running versions to abort */ + virtual void CancelAllVertices(DrErrorPtr reason) DROVERRIDE; + + /* this adds self's monitoring information to stats, which is a + container for the statistics of all the stages in the job. */ + //void FillInStageStatistics(CsJobExecutionStatistics* stats); +/* Holder wraps one connection manager together with the set of upstream stages it manages; IndividualHolder below overrides the per-vertex methods and keeps its own map */ + DRINTERNALBASECLASS(Holder) + { + public: + Holder(DrConnectionManagerPtr manager); + + DrConnectionManagerPtr GetConnectionManager(); + void AddUpstreamStage(DrManagerBasePtr upstreamStage); + bool IsManagingUpstreamStage(DrManagerBasePtr upstreamStage); + virtual DrConnectionManagerPtr GetManagerForVertex(DrVertexPtr vertex); + virtual void AddManagedVertex(DrVertexPtr vertex, bool splitting); + virtual void RemoveManagedVertex(DrVertexPtr vertex); + virtual void NotifyUpstreamLastVertexCompleted(DrManagerBasePtr upstreamStage); + + private: + DrConnectionManagerRef m_manager; + DrStageSetRef m_stageSet; + }; + DRREF(Holder); + + typedef DrArrayList HolderList; + DRAREF(HolderList,HolderRef); + + DRINTERNALCLASS(IndividualHolder) : public Holder + { + public: + IndividualHolder(DrConnectionManagerPtr manager); + + virtual DrConnectionManagerPtr GetManagerForVertex(DrVertexPtr vertex) DROVERRIDE; + virtual void AddManagedVertex(DrVertexPtr vertex, bool splitting) DROVERRIDE; + virtual void RemoveManagedVertex(DrVertexPtr vertex) DROVERRIDE; + virtual void NotifyUpstreamLastVertexCompleted(DrManagerBasePtr upstreamStage) DROVERRIDE; + + private: + typedef DrDictionary Map; + DRREF(Map); + + MapRef m_map; + }; + +private: + void RegisterVertexInternal(DrVertexPtr vertex, bool registerSplit); + void NotifyUpstreamSplit(DrVertexPtr upstreamVertex, DrVertexPtr baseNewVertexSplitFrom, + int outputPortOfSplitBase, int upstreamSplitIndex); + void NotifyUpstreamVertexRemoval(DrVertexPtr upstreamVertex, int outputPortOfRemovedVertex); + void 
NotifyUpstreamVertexCompleted(DrActiveVertexPtr upstreamVertex, int upstreamVertexOutputPort, + int executionVersion, DrResourcePtr machine, + DrVertexExecutionStatisticsPtr statistics); + void NotifyUpstreamLastVertexCompleted(DrManagerBasePtr upstreamStage); + void NotifyUpstreamInputReady(DrStorageVertexPtr vertex, int upstreamVertexOutputPort, DrAffinityPtr affinity); + + void CheckIfWeHaveCompleted(); + + void AddToRunningMap(DrActiveVertexPtr vertex, int version, DrDateTime runningTime); + void RemoveFromRunningMap(DrActiveVertexPtr vertex, int version); + HolderPtr AddDynamicConnectionManagerInternal(DrManagerBasePtr upstreamStage, + DrConnectionManagerPtr connector); + + HolderPtr LookUpConnectionHolder(DrManagerBasePtr upstreamStage); + DrConnectionManagerPtr LookUpConnectionManager(DrVertexPtr vertex, DrManagerBasePtr upstreamStage); +/* m_verticesNotYetCompleted is decremented by NotifyInputReady and by the first reported completion in NotifyVertexCompleted; m_runningTimeMap is the map CheckForDuplicates walks from Begin() when looking for long-running vertices */ + DrGraphRef m_graph; + DrString m_stageName; + bool m_includeInJobStageList; + bool m_stillAddingVertices; + int m_verticesNotYetCompleted; + bool m_weHaveCompleted; + + DrStageStatisticsRef m_stageStatistics; + DrVertexListRef m_vertices; + HolderListRef m_holder; + DrStageSetRef m_downStreamStages; + RunningMapRef m_runningMap; + RunningTimeMapRef m_runningTimeMap; +}; +DRREF(DrManagerBase); diff --git a/GraphManager/stagemanager/DrDynamicAggregateManager.cpp b/GraphManager/stagemanager/DrDynamicAggregateManager.cpp new file mode 100644 index 0000000..d71dbab --- /dev/null +++ b/GraphManager/stagemanager/DrDynamicAggregateManager.cpp @@ -0,0 +1,1594 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include + +DrDamCompletedVertex::DrDamCompletedVertex(DrVertexPtr vertex, DrResourcePtr machine, UINT64 outputSize, + int outputPort) +{ + m_vertex = vertex; + m_outputPort = outputPort; + + m_affinity = DrNew DrAffinity(); + m_affinity->SetWeight(outputSize); + m_affinity->GetLocalityArray()->Add(machine); + + m_numberOfLocations = 1; + + m_machine = DrNew DrResourceArray(1); + m_group = DrNew DrDamVertexGroupArray(1); + m_groupIndex = DrNew DrIntArray(1); + + m_machine[0] = machine; + m_group[0] = DrNull; + m_groupIndex[0] = -1; +} + +DrDamCompletedVertex::DrDamCompletedVertex(DrVertexPtr /* ununsed vertex */, + DrAffinityPtr affinity, + int /* ununsed outputPort */) +{ + m_affinity = affinity; + + m_numberOfLocations = m_affinity->GetLocalityArray()->Size(); + if (m_numberOfLocations == 0) + { + m_numberOfLocations = 1; + } + + m_machine = DrNew DrResourceArray(m_numberOfLocations); + m_group = DrNew DrDamVertexGroupArray(m_numberOfLocations); + m_groupIndex = DrNew DrIntArray(m_numberOfLocations); + + if (m_affinity->GetLocalityArray()->Size() == 0) + { + DrAssert(m_numberOfLocations == 1); + m_machine[0] = DrNull; + m_group[0] = DrNull; + m_groupIndex[0] = -1; + } + else + { + int i; + for (i=0; iGetLocalityArray()[i]; + m_group[i] = DrNull; + m_groupIndex[i] = -1; + } + } +} + +void DrDamCompletedVertex::RemovePodDuplicates() +{ + int newNumberOfLocations = 0; + int i; + for (i=0; iGetParent() == m_machine[i]->GetParent()) + { + break; + } + } + if (j == newNumberOfLocations) + { + 
m_machine[j] = m_machine[i]; + ++newNumberOfLocations; + } + } + + m_numberOfLocations = newNumberOfLocations; +} + +DrDamCompletedVertexRef DrDamCompletedVertex::MakeCopy() +{ + return DrNew DrDamCompletedVertex(m_vertex, m_affinity, m_outputPort); +} + +DrVertexPtr DrDamCompletedVertex::GetVertex() +{ + return m_vertex; +} + +int DrDamCompletedVertex::GetOutputPort() +{ + return m_outputPort; +} + +UINT64 DrDamCompletedVertex::GetOutputSize() +{ + return m_affinity->GetWeight(); +} + +DrResourceArrayRef DrDamCompletedVertex::GetMachineArray() +{ + return m_machine; +} + +DrDamVertexGroupArrayRef DrDamCompletedVertex::GetGroupArray() +{ + return m_group; +} + +DrIntArrayRef DrDamCompletedVertex::GetGroupIndexArray() +{ + return m_groupIndex; +} + +int DrDamCompletedVertex::GetNumberOfLocations() +{ + return m_numberOfLocations; +} + +void DrDamCompletedVertex::SetGroup(int locationIndex, + DrDamVertexGroupPtr group, + int groupIndex) +{ + DrAssert(locationIndex < m_numberOfLocations); + m_group[locationIndex] = group; + m_groupIndex[locationIndex] = groupIndex; +} + +int DrDamCompletedVertex::GetLocationIndex(DrDamVertexGroupPtr group, int groupIndex) +{ + int locationIndex; + + for (locationIndex = m_numberOfLocations; locationIndex-- > 0; ) + { + if ((m_group[locationIndex] == group) && + (m_groupIndex[locationIndex] == groupIndex)) + { + return locationIndex; + } + } + + return MAX_INT32; +} + +void DrDamCompletedVertex::MoveGroupIndex(DrDamVertexGroupPtr group, + int newIndex, int oldIndex) +{ + int i; + + for (i=0; iRemoveVertex(this, m_groupIndex[i]); + m_group[i] = DrNull; + m_groupIndex[i] = ((int) -1); + } + } + } +} + + +DrDamVertexGroup::DrDamVertexGroup(int maxVerticesInGroup) +{ + m_maxVertices = maxVerticesInGroup; + m_vertex = DrNew DrDamCompletedVertexList(s_initialArraySize); + m_numberOfVertices = 0; + m_combinedOutputSize = 0; + m_maxNumberOfLocations = 0; +} + +void DrDamVertexGroup::Discard() +{ + ClaimVertices(); + + m_vertex = DrNew 
DrDamCompletedVertexList(s_initialArraySize); + m_numberOfVertices = 0; + m_combinedOutputSize = 0; + m_maxNumberOfLocations = 0; +} + +void DrDamVertexGroup::AddVertex(DrDamCompletedVertexPtr v, + int locationIndex) +{ + DrAssert(m_numberOfVertices < m_maxVertices); + + m_vertex->Add(v); + v->SetGroup(locationIndex, this, m_vertex->Size() - 1); + + ++m_numberOfVertices; + + m_combinedOutputSize += v->GetOutputSize(); + if (v->GetNumberOfLocations() > m_maxNumberOfLocations) + { + m_maxNumberOfLocations = v->GetNumberOfLocations(); + } + + if (m_vertex->Size() == m_vertex->Allocated()) + { + Compact(); + } +} + +void DrDamVertexGroup::RemoveVertex(DrDamCompletedVertexPtr v, + int groupIndex) +{ + DrAssert(m_numberOfVertices > 0); + DrAssert(groupIndex < m_vertex->Size()); + DrAssert(m_vertex[groupIndex] == v); + DrAssert(m_combinedOutputSize >= v->GetOutputSize()); + m_combinedOutputSize -= v->GetOutputSize(); + m_vertex[groupIndex] = DrNull; + --m_numberOfVertices; + if (m_numberOfVertices == 0) + { + m_maxNumberOfLocations = 0; + } +} + +void DrDamVertexGroup::ClaimVertices() +{ + int i; + for (i=0; iSize(); ++i) + { + if (m_vertex[i] != DrNull) + { + m_vertex[i]->RemoveDuplicateVerticesFromGroups(this); + } + } + m_maxNumberOfLocations = 1; +} + +int DrDamVertexGroup::GetGroupSize() +{ + return m_numberOfVertices; +} + +DrDamCompletedVertexListPtr DrDamVertexGroup::GetGroupArray() +{ + return m_vertex; +} + +UINT64 DrDamVertexGroup::GetCombinedOutputSize() +{ + return m_combinedOutputSize; +} + +int DrDamVertexGroup::GetMaxNumberOfLocations() +{ + return m_maxNumberOfLocations; +} + +void DrDamVertexGroup::Compact() +{ + int newSlotsUsed = 0; + int i; + for (i=0; iSize(); ++i) + { + if (m_vertex[i] != DrNull) + { + if (i > newSlotsUsed) + { + m_vertex[newSlotsUsed] = m_vertex[i]; + m_vertex[i] = DrNull; + m_vertex[newSlotsUsed]->MoveGroupIndex(this, newSlotsUsed, i); + } + ++newSlotsUsed; + } + } + + DrAssert(newSlotsUsed == m_numberOfVertices); + while 
(m_vertex->Size() > m_numberOfVertices) + { + DrAssert(m_vertex[m_vertex->Size() - 1] == DrNull); + m_vertex->RemoveAt(m_vertex->Size() - 1); + } +} + +void DrDamVertexGroup::DisconnectFromSuccessor(DrVertexPtr successor) +{ + DrAssert(m_numberOfVertices > 0); + + int i; + for (i=0; iSize(); ++i) + { + if (m_vertex[i] != DrNull) + { + DrVertexPtr base = m_vertex[i]->GetVertex(); + int basePort = m_vertex[i]->GetOutputPort(); + + DrEdge edge = base->GetOutputs()->GetEdge(basePort); + DrAssert(edge.m_remoteVertex == successor); + int successorPort = edge.m_remotePort; + successor->DisconnectInput(successorPort, false); + } + } + + /* leave the removed edges as DrNull in the successor to avoid doing + n^2 operations by compacting every time we remove a group. As + long as we compact out the DrNull edges by the time the successor + is ready to execute, things will be fine */ +} + +void DrDamVertexGroup::ConnectToSuccessor(DrVertexPtr successor) +{ + DrAssert(m_numberOfVertices > 0); + successor->GetInputs()->SetNumberOfEdges(m_numberOfVertices); + + int i; + int usedBases = 0; + for (i=0; iSize(); ++i) + { + if (m_vertex[i] != DrNull) + { + DrVertexPtr base = m_vertex[i]->GetVertex(); + int basePort = m_vertex[i]->GetOutputPort(); + + DrEdge e = base->GetOutputs()->GetEdge(basePort); + DrEdge re; + re.m_remoteVertex = base; + re.m_remotePort = basePort; + re.m_type = e.m_type; + successor->GetInputs()->SetEdge(usedBases, re); + + e.m_remoteVertex = successor; + e.m_remotePort = usedBases; + base->GetOutputs()->SetEdge(basePort, e); + + ++usedBases; + + m_vertex[i] = DrNull; + } + } + + DrAssert(usedBases == m_numberOfVertices); + + m_numberOfVertices = 0; + m_vertex = DrNew DrDamCompletedVertexList(); + m_combinedOutputSize = 0; + m_maxNumberOfLocations = 0; +} + +DrDamPartiallyGroupedLayer:: + DrDamPartiallyGroupedLayer(DrGraphPtr graph, DrStageStatisticsPtr statistics, + DrDynamicAggregateManagerPtr parent, + DrVertexPtr internalVertex, + int aggregationLevel, 
DrString name) +{ + m_name = name; + m_parent = parent; + m_aggregationLevel = aggregationLevel; + m_numberOfInternalCreated = 0; + m_delayGrouping = false; + + if (m_aggregationLevel > 0) + { + DrAssert(internalVertex != DrNull); + + m_stageManager = DrNew DrManagerBase(graph, name.GetString()); + m_internalVertex = internalVertex; + + /* all the sub-managers should share the same statistics as + the parent internal vertex's manager, since they are all + running vertices of the same class */ + m_stageManager->SetStageStatistics(statistics); + + m_stageManager->SetStillAddingVertices(true); + } + else + { + DrAssert(internalVertex == DrNull); + } + + m_machineGroup = DrNew GroupMap(); + m_podGroup = DrNew GroupMap(); +} + +void DrDamPartiallyGroupedLayer::Discard() +{ + GroupMap::DrEnumerator e = m_machineGroup->GetDrEnumerator(); + while (e.MoveNext()) + { + e.GetValue()->Discard(); + } + + e = m_podGroup->GetDrEnumerator(); + while (e.MoveNext()) + { + e.GetValue()->Discard(); + } + + if (m_overallGroup != DrNull) + { + m_overallGroup->Discard(); + } +} + +void DrDamPartiallyGroupedLayer::SetDelayGrouping(bool delayGrouping) +{ + m_delayGrouping = delayGrouping; +} + +DrString DrDamPartiallyGroupedLayer::GetName() +{ + return m_name; +} + +DrDynamicAggregateManagerPtr DrDamPartiallyGroupedLayer::GetParent() +{ + return m_parent; +} + +DrManagerBasePtr DrDamPartiallyGroupedLayer::GetStageManager() +{ + return m_stageManager; +} + +void DrDamPartiallyGroupedLayer:: + ConsiderSending(DrDamVertexGroupPtr group, + int maxVertices, UINT64 additionalSize) +{ + if (m_delayGrouping) + { + return; + } + + /* if adding this new vertex to group would cause it to + overflow, send the group off and create a new empty one to + put this vertex in */ + if (group->GetGroupSize() == maxVertices || + (group->GetCombinedOutputSize() + additionalSize) > + GetParent()->GetMaxDataPerGroup()) + { + group->ClaimVertices(); + GetParent()->AcceptCompletedGroup(group, m_aggregationLevel); + 
} + + return; +} + + +// This function combines all input vertices assigned to the group into +// processing vertices by verifying grouping conditions. +// +// Input vertices which are left after such grouping remain in the group, +// so further steps can handle them appropriately. +// +// When nothing is left, it returns DrNull instead of the emptied group, otherwise it +// returns a pointer to the group with the unprocessed input vertices. + +DrDamVertexGroupPtr DrDamPartiallyGroupedLayer::SendMinimum(DrDamVertexGroupPtr group, + int minVertices, int maxVertices) +{ + // If source group is DrNull, return it as is -- nothing to do + if (group == DrNull) + { + return group; + } + + int vertexIdx = 0; + DrDamVertexGroupRef tmp_group = DrNull; + DrDamCompletedVertexListRef vertex_array = group->GetGroupArray(); + + // Move input vertices from group to tmp_group while group has at least + // the minimum number of vertices OR its combined output size is not less than + // the required minimum amount (minimum threshold). 
+ while (group->GetGroupSize() >= minVertices || + group->GetCombinedOutputSize() >= GetParent()->GetMinDataPerGroup()) + { + // Allocate temporary group if not done before + if (tmp_group == DrNull) + { + tmp_group = DrNew DrDamVertexGroup(maxVertices); + } + + // Move up to maximum limit of inputs or maximum limit of data per + // single processing vertex + for (; group->GetGroupSize() > 0; ++vertexIdx) + { + DrDamCompletedVertexRef v = vertex_array[vertexIdx]; + if (v == DrNull) + { + continue; + } + + // If adding v would take tmp_group past the maximum vertex count or data size, stop filling and send it + if ((tmp_group->GetGroupSize() == maxVertices) || + ((tmp_group->GetCombinedOutputSize() + v->GetOutputSize()) > GetParent()->GetMaxDataPerGroup())) + { + break; + } + + int locationIdx = v->GetLocationIndex(group, vertexIdx); /* read before v is taken out of group below */ + if (v->GetNumberOfLocations() > 1) + { + v->RemoveDuplicateVerticesFromGroups(DrNull); + } + else + { + group->RemoveVertex(v, vertexIdx); + } + tmp_group->AddVertex(v, locationIdx); + } + + // Generate processing vertex and clean temporary group + GetParent()->AcceptCompletedGroup(tmp_group, m_aggregationLevel); /* NOTE(review): tmp_group is never reset to DrNull here, so the next pass of the while loop adds to the same object -- confirm that accepting a group leaves it empty */ + } + + // If we do not have any vertices in source group -- destroy it too + if (group->GetGroupSize() == 0) + { + group = DrNull; + } + + return group; +} +/* Claims the vertices of group and re-homes each of them according to level: a machine-level vertex goes to a pod group when pod grouping is enabled; otherwise, and for pod level, it goes to the overall group when that is enabled; otherwise it is handed back to the parent manager through ReturnUnGrouped(). Each processed slot is set to DrNull and the number moved is asserted to equal the group size. A DrNull group is ignored. */ +void DrDamPartiallyGroupedLayer::ReturnUnGrouped(DrDamVertexGroupPtr group, + DrDamGroupingLevel level) +{ + if (group != DrNull) + { + int groupSize = group->GetGroupSize(); + + group->ClaimVertices(); + DrDamCompletedVertexListRef groupArray = group->GetGroupArray(); + int nSlots = groupArray->Size(); + + int moved = 0; + int i; /* NOTE(review): the loop header that follows is truncated in this copy of the patch (the bound and the declaration of cv are missing) -- restore it from the original source */ + for (i=0; iRemoveDuplicateVerticesFromGroups(DrNull); + + switch (level) + { + case DDGL_Machine: + if (GetParent()->GetMaxPerPod() > 0) + { + AddVertexToPodGroup(cv); + break; + } +/* no break here: when pod grouping is disabled (maximum per pod not above 0) control falls through to the DDGL_Pod case */ + case DDGL_Pod: + if (GetParent()->GetMaxOverall() > 0) + { + AddVertexToOverallGroup(cv); + break; + } +/* likewise falls through to default when overall grouping is disabled */ + default: + /* we aren't going to group it at this layer so + send it back to the 
parent to add to the next + aggregation layer */ + GetParent()->ReturnUnGrouped(cv, m_aggregationLevel); + } + ++moved; + + groupArray[i] = DrNull; + } + } + + DrAssert(moved == groupSize); + } +} +/* Adds vertex to the machine group keyed by machine[i], creating that group on first use, and passes i to AddVertex() as the location index. Each group is offered to ConsiderSending() before the vertex is added. While grouping is delayed new groups are sized with MAX_INT32 instead of the parent's per-machine maximum. */ +void DrDamPartiallyGroupedLayer::AddVertexToMachineGroup(DrDamCompletedVertexPtr vertex) +{ + int maxVertices = GetParent()->GetMaxPerMachine(); + + if (m_delayGrouping) + { + maxVertices = MAX_INT32; + } + + int nLocations = vertex->GetNumberOfLocations(); + DrResourceArrayRef machine = vertex->GetMachineArray(); + int i; /* NOTE(review): the loop header that follows is truncated in this copy of the patch; presumably it runs once per location (i below nLocations) and declares group -- restore it from the original source */ + for (i=0; iTryGetValue(machine[i], group) == false) + { + group = DrNew DrDamVertexGroup(maxVertices); + m_machineGroup->Add(machine[i], group); + } + + ConsiderSending(group, maxVertices, vertex->GetOutputSize()); /* may send the group first; the vertex is then added to the same group object */ + group->AddVertex(vertex, i); + } +} +/* Same as AddVertexToMachineGroup() but for pod groups: RemovePodDuplicates() is called on the vertex first, groups are looked up in the pod map by podPtr, and the parent's per-pod maximum is used. */ +void DrDamPartiallyGroupedLayer::AddVertexToPodGroup(DrDamCompletedVertexPtr vertex) +{ + int maxVertices = GetParent()->GetMaxPerPod(); + + if (m_delayGrouping) + { + maxVertices = MAX_INT32; + } + + vertex->RemovePodDuplicates(); + + int nLocations = vertex->GetNumberOfLocations(); + DrResourceArrayRef machine = vertex->GetMachineArray(); + int i; /* NOTE(review): the loop header and the start of the podPtr computation are truncated in this copy of the patch -- restore them from the original source */ + for (i=0; iGetParent(); + } + + DrDamVertexGroupRef group; + if (m_podGroup->TryGetValue(podPtr, group) == false) + { + group = DrNew DrDamVertexGroup(maxVertices); + m_podGroup->Add(podPtr, group); + } + + ConsiderSending(group, maxVertices, vertex->GetOutputSize()); + group->AddVertex(vertex, i); + } +} +/* Adds vertex to the single overall group, creating it on first use. The group is offered to ConsiderSending() first; the location index passed to AddVertex() is always 0. */ +void DrDamPartiallyGroupedLayer:: + AddVertexToOverallGroup(DrDamCompletedVertexPtr vertex) +{ + int maxVertices = GetParent()->GetMaxOverall(); + + if (m_delayGrouping) + { + maxVertices = MAX_INT32; + } + + if (m_overallGroup == DrNull) + { + m_overallGroup = DrNew DrDamVertexGroup(maxVertices); + } + + ConsiderSending(m_overallGroup, maxVertices, vertex->GetOutputSize()); + m_overallGroup->AddVertex(vertex, 0); +} + +// Moves one "topmost" vertex from source group to destination one. 
+bool DrDamPartiallyGroupedLayer::MoveOneVertex(MachineGroupR grpStruct) /* Moves the first non-DrNull vertex found at or after grpStruct.nextVertex from grpStruct.group to grpStruct.outputGroup. Returns true, with grpStruct.group set to DrNull, when the source group is empty afterwards; otherwise stores the slot to resume from in grpStruct.nextVertex and returns false. */ +{ + DrDamVertexGroupPtr srcGroup = grpStruct.group; + DrDamVertexGroupPtr dstGroup = grpStruct.outputGroup; + DrDamCompletedVertexListRef vertices = srcGroup->GetGroupArray(); + int numOfSlots = vertices->Size(); + + int vertexIdx; + for(vertexIdx = grpStruct.nextVertex; vertexIdx < numOfSlots; ++vertexIdx) + { + if (vertices[vertexIdx] == DrNull) + { + continue; + } + + DrDamCompletedVertexPtr v = vertices[vertexIdx]; + + int locationIdx = v->GetLocationIndex(srcGroup, vertexIdx); /* read before v is taken out of the source group below */ + if (v->GetNumberOfLocations() > 1) + { + v->RemoveDuplicateVerticesFromGroups(DrNull); /* presumably removes v from every group that lists it, including srcGroup -- confirm */ + } + else + { + srcGroup->RemoveVertex(v, vertexIdx); + } + dstGroup->AddVertex(v, locationIdx); + ++vertexIdx; /* step past the slot just moved so the saved resume position starts after it */ + break; + } + + // If source group becomes empty -- destroy it, otherwise remember next + // vertex index we need to look at. + if (srcGroup->GetGroupSize() == 0) + { + grpStruct.group = DrNull; + return true; + } + else + { + grpStruct.nextVertex = vertexIdx; + } + + return false; +} + +// Compares two machine groups to establish the ordering used when choosing the next machine +// +// For small clusters (or large datasets, when number of instances is much more +// than number of machines where those instances are distributed), we need to +// take care about amount of data allocated to particular machine. +// +// TODO: Weights for particular data distributions should be different. We need +// to figure out almost optimal algorithm, which is lightweight and suitable for +// most cases. +// +// TODO: We need to figure out a way when and how we need to switch between +// different lightweight algorithms for individual cases depending on how +// much data we have and how this data is distributed across machines. 
+int DrDamPartiallyGroupedLayer::MachineGroupCmp(MachineGroupR left, MachineGroupR right) /* Returns -1 when left orders before right, 1 when it orders after, 0 when they tie. Entries whose source group is DrNull order last, then entries whose machine is DrNull; the rest are ordered by smaller combined size of the output group, then by smaller combined size of the source group. */ +{ + UINT64 leftSize, rightSize; + UINT64 leftOutput, rightOutput; + + if (right.group == DrNull) + { + if (left.group == DrNull) + { + return 0; + } + return -1; + } + + if (left.group == DrNull) + { + return 1; + } + + // Do not try to assign anything to unknown machines + if (right.machine == DrNull) + { + return -1; + } + + if (left.machine == DrNull) + { + return 1; + } + + leftOutput = left.outputGroup->GetCombinedOutputSize(); + rightOutput = right.outputGroup->GetCombinedOutputSize(); + + if (leftOutput < rightOutput) + { + return -1; + } + + if (leftOutput > rightOutput) + { + return 1; + } + + leftSize = left.group->GetCombinedOutputSize(); + rightSize = right.group->GetCombinedOutputSize(); + + if (leftSize < rightSize) + { + return -1; + } + + if (leftSize > rightSize) + { + return 1; + } + + return 0; +} +/* Flushes the machine groups. When grouping was delayed, vertices of groups that report more than one location are first moved one at a time into fresh per-machine output groups chosen with MachineGroupCmp(). Then SendMinimum() is applied to every group with the parent's per-machine limits, the leftovers are passed to ReturnUnGrouped() with DDGL_Machine, and the machine map is replaced by an empty one. */ +void DrDamPartiallyGroupedLayer::GatherRemainingMachineGroups() +{ + /* first pull out any groups which are big enough. 
Since removing + a group can cull machines out of a subsequent group (if a + machine has multiple locations), do this full pass before + gathering up the remainders below */ + if (m_delayGrouping) + { + MachineGroupArrayRef weightedMap = DrNew MachineGroupArray(m_machineGroup->GetSize()); /* one slot per existing machine group; only the first numMachines slots get filled in below */ + + GroupMapRef newMap = DrNew GroupMap(); + + int numMachines = 0; + GroupMap::DrEnumerator m = m_machineGroup->GetDrEnumerator(); + while (m.MoveNext()) + { + if (m.GetValue()->GetMaxNumberOfLocations() > 1) + { /* presumably the group holds at least one vertex with several locations: queue it for redistribution and put a fresh, initially empty output group in the new map in its place */ + DrDamVertexGroupRef newGroup = DrNew DrDamVertexGroup(m.GetValue()->GetGroupSize()); + + weightedMap[numMachines].machine = m.GetKey(); + weightedMap[numMachines].group = m.GetValue(); + weightedMap[numMachines].outputGroup = newGroup; + weightedMap[numMachines].nextVertex = 0; + + newMap->Add(m.GetKey(), newGroup); + + ++numMachines; + } + else + { /* otherwise the existing group is carried over unchanged */ + newMap->Add(m.GetKey(), m.GetValue()); + } + } + + m_machineGroup = newMap; + + // If we have machines with multi-instance vertices, process them to produce almost flat distribution + if (numMachines > 0) + { + for(;;) + { + // Find top-most machine by MachineGroupCmp() sorting condition + MachineGroup topMachine = weightedMap[0]; /* NOTE(review): topMachine is declared by value and MachineGroup is a DRVALUECLASS; if this is a copy, the group and nextVertex updates made by MoveOneVertex() below never reach weightedMap -- confirm */ + int i; /* NOTE(review): the loop header and the MachineGroupCmp() call that follow are truncated in this copy of the patch -- restore them from the original source */ + for(i=1; i 0) + { + topMachine = weightedMap[i]; + } + } + + // If top-most machine does not have associated source group, this signals that we + // already allocated all inputs, nothing more to do. 
+ if (topMachine.group == DrNull) + { + break; + } + + // Move a vertex to target group + MoveOneVertex(topMachine); /* the return value is ignored; the loop only ends through the DrNull check above */ + } + + // Verify that we processed all machines + int unassignedMachinesCount = 0; + int i; /* NOTE(review): the loop header, the check and the start of the printf call that follow are truncated in this copy of the patch -- restore them from the original source */ + for(i=0; iGetName().GetChars()); + ++unassignedMachinesCount; + } + } + + if (unassignedMachinesCount != 0) + { + fflush(stdout); /* make the messages printed above visible before the assertion below fires */ + } + + DrAssert(unassignedMachinesCount == 0); + } + + // Report generated allocation for debugging purposes + m = m_machineGroup->GetDrEnumerator(); + numMachines = 0; /* reused here as the count of map entries with a non-DrNull group */ + while (m.MoveNext()) + { + if (m.GetValue() != DrNull) + { + printf("%s: assigned %d inputs, %I64u bytes\n", + (m.GetKey() ? m.GetKey()->GetName().GetChars() : ""), + m.GetValue()->GetGroupSize(), m.GetValue()->GetCombinedOutputSize()); + ++numMachines; + } + } + printf("TOTAL %d machines in use\n", numMachines); + fflush(stdout); + } + + GroupMap::DrEnumerator mm = m_machineGroup->GetDrEnumerator(); + while (mm.MoveNext()) + { + if (mm.GetValue() != DrNull) + { + SendMinimum(mm.GetValue(), GetParent()->GetMinPerMachine(), GetParent()->GetMaxPerMachine()); /* the returned pointer is ignored; leftovers are collected by the second pass below */ + } + } + + mm = m_machineGroup->GetDrEnumerator(); + while (mm.MoveNext()) + { + if (mm.GetValue() != DrNull) + { + ReturnUnGrouped(mm.GetValue(), DDGL_Machine); + } + } + + m_machineGroup = DrNew GroupMap(); /* start over with an empty map */ +} +/* Flushes the pod groups in two passes: SendMinimum() on every group with the parent's per-pod limits, then ReturnUnGrouped() on every group with DDGL_Pod; finally the pod map is replaced by an empty one. */ +void DrDamPartiallyGroupedLayer::GatherRemainingPodGroups() +{ + /* first pull out any groups which are big enough. 
Since removing + a group can cull machines out of a subsequent group (if a + machine has multiple locations), do this full pass before + gathering up the remainders below */ + GroupMap::DrEnumerator p = m_podGroup->GetDrEnumerator(); + while (p.MoveNext()) + { + SendMinimum(p.GetValue(), GetParent()->GetMinPerPod(), GetParent()->GetMaxPerPod()); + } + + p = m_podGroup->GetDrEnumerator(); + while (p.MoveNext()) + { + ReturnUnGrouped(p.GetValue(), DDGL_Pod); + } + + m_podGroup = DrNew GroupMap(); +} +/* Flushes the overall group, if there is one: SendMinimum() with the parent's overall limits, then ReturnUnGrouped() with DDGL_Overall (which hands every leftover vertex back to the parent manager), then the member is reset to DrNull. */ +void DrDamPartiallyGroupedLayer::GatherRemainingOverallGroups() +{ + if (m_overallGroup != DrNull) + { + SendMinimum(m_overallGroup, GetParent()->GetMinOverall(), GetParent()->GetMaxOverall()); + ReturnUnGrouped(m_overallGroup, DDGL_Overall); + m_overallGroup = DrNull; + } +} +/* Flushes all remaining groups of this layer. The order matters: leftovers from the machine pass can be added to pod groups, and leftovers from the pod pass to the overall group, before those are flushed in turn. */ +void DrDamPartiallyGroupedLayer::LastVertexHasCompleted() +{ + GatherRemainingMachineGroups(); + GatherRemainingPodGroups(); + GatherRemainingOverallGroups(); +} +/* Turns group into the input of a new internal vertex: disconnects the group from successor and compacts successor's inputs, copies the internal vertex and registers the copy with the parent manager and with this layer's stage manager, connects the group to the copy, grows successor by one input edge, and wires output 0 of the copy to that new input before starting the copy. Only valid on layers with aggregation level above 0. */ +void DrDamPartiallyGroupedLayer::MakeInternalGroup(DrDamVertexGroupPtr group, + DrVertexPtr successor) +{ + DrLogI("creating internal vertex %d level %d with %d inputs 1 output", + m_numberOfInternalCreated, m_aggregationLevel, group->GetGroupSize()); + + /* we only make new vertices in internal levels */ + DrAssert(m_aggregationLevel > 0); + + group->DisconnectFromSuccessor(successor); + successor->GetInputs()->Compact(successor); + + DrVertexRef vertex = m_internalVertex->MakeCopy(m_numberOfInternalCreated, m_stageManager); + ++m_numberOfInternalCreated; + m_parent->RegisterCreatedVertex(vertex, m_aggregationLevel); + m_stageManager->RegisterVertex(vertex); + + group->ConnectToSuccessor(vertex); + + int successorInputsAfterGroup = successor->GetInputs()->GetNumberOfEdges(); /* this is also the index of the input slot added on the next line */ + successor->GetInputs()->GrowNumberOfEdges(successorInputsAfterGroup+1); + + DrActiveVertexPtr av = (DrActiveVertexPtr)successor; /* assumes successor is an active vertex -- the cast is unchecked */ + av->GrowPendingVersion(1); + av->GetStartClique()->GrowExternalInputs(1); + + vertex->GetOutputs()->SetNumberOfEdges(1); + 
vertex->ConnectOutput(0, successor, successorInputsAfterGroup, DCT_File); /* output 0 of the new vertex feeds the input slot just added to successor */ + + vertex->InitializeForGraphExecution(); + vertex->KickStateMachine(); +} +/* Builds a manager with the default settings, no destination vertex and no grouping layers. */ +DrDynamicAggregateManager::DrDynamicAggregateManager() : DrConnectionManager(true) +{ + InitializeEmpty(); +} +/* Builds the manager for one destination vertex: default settings, parent set, and a single level-0 layer that has no internal vertex and an empty name. */ +DrDynamicAggregateManager::DrDynamicAggregateManager(DrVertexPtr dstVertex, + DrManagerBasePtr parent) : + DrConnectionManager(false) +{ + InitializeEmpty(); + m_dstVertex = dstVertex; + SetParent(parent); + + DrDamPartiallyGroupedLayerRef newLayer = + DrNew DrDamPartiallyGroupedLayer(parent->GetGraph(), parent->GetStageStatistics(), + this, DrNull, 0, DrString()); + m_grouping->Add(newLayer); +} +/* Calls Discard() on every grouping layer. */ +DrDynamicAggregateManager::~DrDynamicAggregateManager() +{ + int i; /* NOTE(review): the loop bound that follows is truncated in this copy of the patch (presumably the size of m_grouping) -- restore it from the original source */ + for (i=0; iSize(); ++i) + { + m_grouping[i]->Discard(); + } +} +/* Applies the static default limits through the setters (so their assertions run), clears the flags and counters, and creates empty upstream-stage, layer and created-vertex containers. m_dstVertex and m_internalVertex are not assigned here. */ +void DrDynamicAggregateManager::InitializeEmpty() +{ + SetMaxAggregationLevel(s_maxAggregationLevel); + SetGroupingSettings(s_minGroupSize, s_maxGroupSize); + SetDataGroupingSettings(s_minDataSize, s_maxDataSize, + s_maxDataSizeToConsider); + + m_splitAfterGrouping = false; + m_numberOfSplitCreated = 0; + m_numberOfManagersCreated = 0; + m_delayGrouping = false; + + m_upstreamStage = DrNew DrDefaultStageList(); + m_grouping = DrNew LayerList(); + m_createdMap = DrNew CreatedMap(); +} +/* Copies the grouping limits and flags from src, gives this manager its own copy of src's internal vertex (made with nameIndex) when src has one, and appends src's upstream stages. Asserts that src has no created vertices and no grouping layers. */ +void DrDynamicAggregateManager::CopySettings(DrDynamicAggregateManagerPtr src, + int nameIndex) +{ + m_maxAggregationLevel = src->m_maxAggregationLevel; + m_minPerMachine = src->m_minPerMachine; + m_maxPerMachine = src->m_maxPerMachine; + m_minPerPod = src->m_minPerPod; + m_maxPerPod = src->m_maxPerPod; + m_minOverall = src->m_minOverall; + m_maxOverall = src->m_maxOverall; + m_minDataPerGroup = src->m_minDataPerGroup; + m_maxDataPerGroup = src->m_maxDataPerGroup; + m_maxDataToConsiderGrouping = src->m_maxDataToConsiderGrouping; + + if (src->m_internalVertex == DrNull) + { + m_internalVertex = DrNull; + } + else + { + m_internalVertex = src->m_internalVertex->MakeCopy(nameIndex); + } + + 
m_splitAfterGrouping = src->m_splitAfterGrouping; + SetDelayGrouping(src->m_delayGrouping); /* goes through the setter so the flag also reaches the layers this manager already has */ + + DrAssert(src->m_createdMap->GetSize() == 0); + DrAssert(src->m_grouping->Size() == 0); + + int i; /* NOTE(review): the loop bound that follows is truncated in this copy of the patch -- restore it from the original source */ + for (i=0; im_upstreamStage->Size(); ++i) + { + m_upstreamStage->Add(src->m_upstreamStage[i]); + } +} +/* Records upstreamStage as a stage this manager expects to hear from, unless it is the stage manager of one of this manager's own layers. Asserts that fewer than two layers exist. */ +void DrDynamicAggregateManager::AddUpstreamStage(DrManagerBasePtr upstreamStage) +{ + int i; /* NOTE(review): the loop bound that follows is truncated in this copy of the patch (presumably the size of m_grouping) -- restore it from the original source */ + for (i=0; iSize(); ++i) + { + if (m_grouping[i]->GetStageManager() == upstreamStage) + { + /* we only record this if it's not one of the internal + stages we have created */ + return; + } + } + + /* nobody is supposed to be adding more stages once any groups + have been formed */ + DrAssert(m_grouping->Size() < 2); + + m_upstreamStage->Add(upstreamStage); +} +/* Returns the connection manager to attach to vertex: this manager itself when the vertex was created by splitting, otherwise a new manager for that vertex which shares this manager's parent and receives a copy of its settings. */ +DrConnectionManagerRef + DrDynamicAggregateManager::MakeManagerForVertex(DrVertexPtr vertex, + bool splitting) +{ + if (splitting) + { + /* if we've just created a new vertex by splitting, then that + new vertex gets a (newly reference counted) reference to the base + manager, which has no dstVertex */ + return this; + } + else + { + DrDynamicAggregateManagerRef newManager = + DrNew DrDynamicAggregateManager(vertex, GetParent()); + newManager->CopySettings(this, m_numberOfManagersCreated); /* the running count doubles as the name index for the copied internal vertex */ + ++m_numberOfManagersCreated; + DrConnectionManagerPtr cm = newManager; + return cm; + } +} +/* Stores the delay-grouping flag and forwards it to every existing grouping layer. */ +void DrDynamicAggregateManager::SetDelayGrouping(bool delayGrouping) +{ + m_delayGrouping = delayGrouping; + int i; /* NOTE(review): the loop bound that follows is truncated in this copy of the patch (presumably the size of m_grouping) -- restore it from the original source */ + for (i=0; iSize(); ++i) + { + m_grouping[i]->SetDelayGrouping(delayGrouping); + } +} +/* Returns the delay-grouping flag. */ +bool DrDynamicAggregateManager::GetDelayGrouping() +{ + return m_delayGrouping; +} +/* Sets the vertex that is copied to create internal aggregation vertices. While it is DrNull the effective maximum aggregation level is 0 (see AcceptCompletedGroup() and AddCompletedVertex()). */ +void DrDynamicAggregateManager::SetInternalVertex(DrVertexPtr internalVertex) +{ + m_internalVertex = internalVertex; +} +/* Sets the configured maximum aggregation level. */ +void DrDynamicAggregateManager::SetMaxAggregationLevel(int maxAggregation) +{ + m_maxAggregationLevel = maxAggregation; +} +/* Returns the configured maximum aggregation level, not the effective one. */ +int DrDynamicAggregateManager::GetMaxAggregationLevel() +{ + return m_maxAggregationLevel; +} + +void 
DrDynamicAggregateManager::SetGroupingSettings(int minGroupSize, + int maxGroupSize) /* Applies the same minimum and maximum group size to the machine, pod and overall levels. */ +{ + SetMachineGroupingSettings(minGroupSize, maxGroupSize); + SetPodGroupingSettings(minGroupSize, maxGroupSize); + SetOverallGroupingSettings(minGroupSize, maxGroupSize); +} +/* Sets the per-machine group size limits. A maximum that is not above 0 turns machine-level grouping off (see AddCompletedVertex()). Asserts that the minimum is above 1 when the level is on, and that the minimum does not exceed the maximum. */ +void DrDynamicAggregateManager::SetMachineGroupingSettings(int minPerMachine, + int maxPerMachine) +{ + m_minPerMachine = minPerMachine; + m_maxPerMachine = maxPerMachine; + if (m_maxPerMachine > 0) + { + DrAssert(m_minPerMachine > 1); + } + DrAssert(m_minPerMachine <= m_maxPerMachine); +} +/* Returns the minimum number of vertices per machine group. */ +int DrDynamicAggregateManager::GetMinPerMachine() +{ + return m_minPerMachine; +} +/* Returns the maximum number of vertices per machine group. */ +int DrDynamicAggregateManager::GetMaxPerMachine() +{ + return m_maxPerMachine; +} +/* Sets the per-pod group size limits, with the same rules and assertions as the per-machine setter. */ +void DrDynamicAggregateManager::SetPodGroupingSettings(int minPerPod, + int maxPerPod) +{ + m_minPerPod = minPerPod; + m_maxPerPod = maxPerPod; + if (m_maxPerPod > 0) + { + DrAssert(m_minPerPod > 1); + } + DrAssert(m_minPerPod <= m_maxPerPod); +} +/* Returns the minimum number of vertices per pod group. */ +int DrDynamicAggregateManager::GetMinPerPod() +{ + return m_minPerPod; +} +/* Returns the maximum number of vertices per pod group. */ +int DrDynamicAggregateManager::GetMaxPerPod() +{ + return m_maxPerPod; +} +/* Sets the overall group size limits, with the same rules and assertions as the per-machine setter. */ +void DrDynamicAggregateManager::SetOverallGroupingSettings(int minOverall, + int maxOverall) +{ + m_minOverall = minOverall; + m_maxOverall = maxOverall; + if (m_maxOverall > 0) + { + DrAssert(m_minOverall > 1); + } + DrAssert(m_minOverall <= m_maxOverall); +} +/* Returns the minimum number of vertices in the overall group. */ +int DrDynamicAggregateManager::GetMinOverall() +{ + return m_minOverall; +} +/* Returns the maximum number of vertices in the overall group. */ +int DrDynamicAggregateManager::GetMaxOverall() +{ + return m_maxOverall; +} +/* Sets the data-size thresholds: the minimum and maximum combined output size of a group, and the output size at or above which a single vertex is not considered for grouping. Asserts that the maximum is above 0 and that neither of the other two values exceeds it. */ +void DrDynamicAggregateManager::SetDataGroupingSettings(UINT64 minDataSize, + UINT64 maxDataSize, + UINT64 maxDataToConsider) +{ + m_minDataPerGroup = minDataSize; + m_maxDataPerGroup = maxDataSize; + m_maxDataToConsiderGrouping = maxDataToConsider; + + DrAssert(m_maxDataPerGroup > 0); + DrAssert(m_minDataPerGroup <= m_maxDataPerGroup); + DrAssert(m_maxDataToConsiderGrouping <= m_maxDataPerGroup); +} + +UINT64 
DrDynamicAggregateManager::GetMinDataPerGroup() /* Returns the minimum combined output size of a group. */ +{ + return m_minDataPerGroup; +} +/* Returns the maximum combined output size of a group. */ +UINT64 DrDynamicAggregateManager::GetMaxDataPerGroup() +{ + return m_maxDataPerGroup; +} +/* Returns the output size at or above which a vertex is left ungrouped. */ +UINT64 DrDynamicAggregateManager::GetMaxDataToConsiderGrouping() +{ + return m_maxDataToConsiderGrouping; +} +/* Sets whether groups at the final aggregation level, and ungroupable vertices, each get their own copy of the destination vertex. */ +void DrDynamicAggregateManager::SetSplitAfterGrouping(bool splitAfterGrouping) +{ + m_splitAfterGrouping = splitAfterGrouping; +} +/* Returns the split-after-grouping flag. */ +bool DrDynamicAggregateManager::GetSplitAfterGrouping() +{ + return m_splitAfterGrouping; +} +/* Creates a copy of the destination vertex, registers it with the parent stage manager as a split of the destination vertex, and returns it. Requires a destination vertex. */ +DrVertexRef DrDynamicAggregateManager::AddSplitVertex() +{ + DrAssert(m_dstVertex != DrNull); + + DrVertexRef vertex = m_dstVertex->MakeCopy(m_numberOfSplitCreated, GetParent()); + + GetParent()->RegisterVertexSplit(vertex, m_dstVertex, + m_numberOfSplitCreated); + + ++m_numberOfSplitCreated; + + return vertex; +} +/* Takes a group completed by the layer at aggregationLevel. At the effective maximum level the group is reconnected from the destination vertex to a new split vertex, which is started. Below it, the group becomes the input of a new internal vertex in the next layer, which is created and registered with the parent on first use. */ +void DrDynamicAggregateManager::AcceptCompletedGroup(DrDamVertexGroupPtr group, + int aggregationLevel) +{ + int maxAggregationLevel = m_maxAggregationLevel; + if (m_internalVertex == DrNull) + { + maxAggregationLevel = 0; /* without an internal vertex no further layers can be built */ + } + + if (aggregationLevel == maxAggregationLevel) + { + /* we're already at the maximum level for aggregation, so + create a new split vertex and connect this group to it */ + DrAssert(m_splitAfterGrouping); + + DrVertexRef newSplit = AddSplitVertex(); + group->DisconnectFromSuccessor(m_dstVertex); + group->ConnectToSuccessor(newSplit); + + /* the new split vertex should be ready to run, so let it go */ + newSplit->InitializeForGraphExecution(); + newSplit->KickStateMachine(); + } + else + { + DrAssert(aggregationLevel < maxAggregationLevel); + DrAssert(m_internalVertex != DrNull); + + ++aggregationLevel; /* from here on this is the level of the layer that receives the group */ + if (m_grouping->Size() == aggregationLevel) + { + DrLogI("Adding new aggregation level %d", aggregationLevel); + + DrString newName; /* each level's name is the previous level's name with one more plus sign appended, starting from the internal vertex name */ + if (aggregationLevel == 1) + { + newName.SetF("%s+", m_internalVertex->GetName().GetChars()); + } + else + { + newName.SetF("%s+", m_grouping[aggregationLevel-1]->GetName().GetChars()); 
+ } + + DrDamPartiallyGroupedLayerRef newLayer = + DrNew DrDamPartiallyGroupedLayer(GetParent()->GetGraph(), GetParent()->GetStageStatistics(), + this, m_internalVertex, + aggregationLevel, newName); + DrAssert(newLayer != DrNull); + newLayer->SetDelayGrouping(m_delayGrouping); + m_grouping->Add(newLayer); + + GetParent()-> + AddDynamicConnectionManager(newLayer->GetStageManager(), + this); /* this manager is registered as the connection manager for the new layer's stage */ + } + + DrAssert(m_grouping->Size() > aggregationLevel); + m_grouping[aggregationLevel]->MakeInternalGroup(group, m_dstVertex); + } +} +/* Handles a vertex that stays a singleton. When splitting is enabled the vertex's edge to the destination vertex is disconnected and a new split vertex is created with that vertex as its only input, then started; otherwise nothing is changed. */ +void DrDynamicAggregateManager:: + DealWithUngroupableVertex(DrDamCompletedVertexPtr vertex) +{ + if (m_splitAfterGrouping) + { + /* this is a singleton vertex that gets its own new split + vertex */ + DrVertexRef newSplit = AddSplitVertex(); + DrVertexPtr base = vertex->GetVertex(); + int basePort = vertex->GetOutputPort(); + + DrEdge edge = base->GetOutputs()->GetEdge(basePort); + DrAssert(edge.m_remoteVertex == m_dstVertex); + int successorPort = edge.m_remotePort; + m_dstVertex->DisconnectInput(successorPort, false); + + DrEdge re; /* input edge of the new split vertex, pointing back at the upstream output with the same edge type */ + re.m_remoteVertex = base; + re.m_remotePort = basePort; + re.m_type = edge.m_type; + + newSplit->GetInputs()->SetNumberOfEdges(1); + newSplit->GetInputs()->SetEdge(0, re); + newSplit->InitializeForGraphExecution(); + newSplit->KickStateMachine(); + } + else + { + /* do nothing and leave this vertex connected to the normal + successor */ + } +} + +/* this is called by a DamPartiallyGroupedLayer to return a vertex that + doesn't fit into any of its groups; aggLevel is the level of the calling layer */ +void DrDynamicAggregateManager::ReturnUnGrouped(DrDamCompletedVertexPtr vertex, + int aggLevel) +{ + ++aggLevel; + if (aggLevel == m_grouping->Size()) + { + /* none of the vertices at this level were grouped, so no next + layer has been created */ + DealWithUngroupableVertex(vertex); + } + else + { + /* pass it on to the next layer in case it fits into one of + the next set of groups */ + AddCompletedVertex(vertex, aggLevel); + } +} + +void 
DrDynamicAggregateManager::AddCompletedVertex(DrDamCompletedVertexPtr vertex, + int aggLevel) /* Offers a completed vertex to the layer at aggLevel, using the first enabled grouping level (machine, then pod, then overall), or treats it as a singleton when it is too large or grouping at this level would have no effect. */ +{ + DrAssert(m_grouping->Size() > aggLevel); + + DrDamPartiallyGroupedLayerPtr layer = m_grouping[aggLevel]; + + int maxAggregationLevel = m_maxAggregationLevel; + if (m_internalVertex == DrNull) + { + maxAggregationLevel = 0; + } + + /* if a vertex output is not smaller than m_maxDataToConsiderGrouping don't + bother to try to group it. The greedy algorithm we are using is + really only good at grouping together outputs substantially + smaller than m_maxDataPerGroup. If we're at the final + aggregation level and we're not splitting, don't bother to find + the groups, since we're not going to do anything based on + them */ + if (vertex->GetOutputSize() < m_maxDataToConsiderGrouping && + !(aggLevel == maxAggregationLevel && !m_splitAfterGrouping)) + { + /* when we add a vertex to a layer, it will eventually return + to us either as part of a group in a call to + AcceptCompletedGroup() or as a singleton in a call to + ReturnUnGrouped() */ + if (m_maxPerMachine > 0) + { + layer->AddVertexToMachineGroup(vertex); + } + else if (m_maxPerPod > 0) + { + layer->AddVertexToPodGroup(vertex); + } + else + { + DrAssert(m_maxOverall > 0); + layer->AddVertexToOverallGroup(vertex); + } + } + else + { + /* this vertex is going to remain a singleton */ + DealWithUngroupableVertex(vertex); + } +} +/* Records that vertex was created by this manager for the layer at aggLevel, so that its completion can later be routed to that layer. */ +void DrDynamicAggregateManager::RegisterCreatedVertex(DrVertexPtr vertex, + int aggLevel) +{ + m_createdMap->Add(vertex, aggLevel); +} +/* Called when an upstream vertex completes. Completions of a vertex that has already reported one are ignored. Otherwise the vertex is wrapped, together with the machine and the amount of data written on outputPort, and offered to the layer it belongs to: the level recorded for vertices created by this manager, level 0 for all others. */ +void DrDynamicAggregateManager:: + NotifyUpstreamVertexCompleted(DrActiveVertexPtr vertex, int outputPort, + int /* unused executionVersion */, + DrResourcePtr machine, + DrVertexExecutionStatisticsPtr statistics) +{ + if (vertex->GetNumberOfReportedCompletions() > 0) + { + DrLogI("Ignoring completion since vertex has previously completed %d times", + vertex->GetNumberOfReportedCompletions()); + return; + } + + int aggLevel = 0; + /* look up to see if this is an 
internal vertex created by us */ + m_createdMap->TryGetValue(vertex, aggLevel); /* the result is ignored; presumably aggLevel keeps its initial 0 when the vertex is not in the map -- confirm */ + + UINT64 outputSize = statistics->m_outputData[outputPort]->m_dataWritten; + AddCompletedVertex(DrNew DrDamCompletedVertex(vertex, machine, outputSize, outputPort), + aggLevel); +} +/* Called when an upstream storage input is ready: wraps the vertex with its affinity and output port and offers it to layer 0. */ +void DrDynamicAggregateManager:: + NotifyUpstreamInputReady(DrStorageVertexPtr vertex, int outputPort, DrAffinityPtr affinity) +{ + AddCompletedVertex(DrNew DrDamCompletedVertex(vertex, affinity, outputPort), + 0); /* the aggregation level is always 0 for + inputs */ +} +/* Final step once no more completions will arrive: compacts the destination vertex's inputs, then either retires the destination vertex (when splitting, where it must have no inputs or outputs left) or starts it. */ +void DrDynamicAggregateManager::CleanUp() +{ + /* that's the end, we won't see any more action. Clean up. */ + + /* first, get rid of any DrNull edges we left lying around that used + to belong to inputs we have now grouped */ + m_dstVertex->GetInputs()->Compact(m_dstVertex); + + if (m_splitAfterGrouping) + { + /* the original "dummy" destination vertex should have been + replaced by some number of clones, one for each group we + ended up with */ + DrAssert(m_dstVertex->GetInputs()->GetNumberOfEdges() == 0); + GetParent()->UnRegisterVertex(m_dstVertex); + DrAssert(m_dstVertex->GetOutputs()->GetNumberOfEdges() == 0); + m_dstVertex->RemoveFromGraphExecution(); + } + else + { + /* the successor should be ready to run: let it rip */ + m_dstVertex->InitializeForGraphExecution(); + m_dstVertex->KickStateMachine(); + } +} + +/* this is called by any upstream stage manager when it isn't going to + send any more vertex completion calls. This includes the internal + aggregation layer stage managers that we created */ +void DrDynamicAggregateManager:: + NotifyUpstreamLastVertexCompleted(DrManagerBasePtr upstreamStage) +{ + if (m_dstVertex == DrNull) + { + /* we are being called by a split-created vertex, and we are + the base manager that doesn't need to do anything */ + DrAssert(m_grouping->Size() == 0); + return; + } + + int layer = 0; + bool foundLayer = false; + + /* first see if it's one of the "real" upstream stages, i.e. 
not + created by us */ + int i; /* NOTE(review): the loop bound that follows is truncated in this copy of the patch (presumably the size of m_upstreamStage) -- restore it from the original source */ + for (i=0; iSize(); ++i) + { + if (m_upstreamStage[i] == upstreamStage) + { + /* remove this from the set of stages we expect to hear from */ + m_upstreamStage->RemoveAt(i); + if (m_upstreamStage->Size() > 0) + { + /* there are still vertices to come from other upstream + stages, so layer 0 is not finished */ + return; + } + foundLayer = true; + break; + } + } + + /* at this point layer = 0, which is the layer fed by the real upstream stages */ + + if (!foundLayer) + { + /* we didn't find anything in the list of registered stages so + look in the ones we created */ + for (layer=1; layer < m_grouping->Size(); ++layer) + { + if (m_grouping[layer]->GetStageManager() == upstreamStage) + { + foundLayer = true; + break; + } + } + } + + /* this is supposed to be some stage we've heard about + before */ + DrAssert(foundLayer); + DrAssert(layer < m_grouping->Size()); + + m_grouping[layer]->LastVertexHasCompleted(); /* flushes that layer's remaining machine, pod and overall groups; this can add a new layer to m_grouping */ + + if (layer+1 == m_grouping->Size()) + { + /* this means the last vertex we're ever going to see has + completed, since the last vertex in the highest stage has + completed without generating any new stages */ + CleanUp(); + } + else + { + /* this layer isn't going to add any more vertices to the next + layer down */ + m_grouping[layer+1]->GetStageManager()->SetStillAddingVertices(false); + } +} diff --git a/GraphManager/stagemanager/DrDynamicAggregateManager.h b/GraphManager/stagemanager/DrDynamicAggregateManager.h new file mode 100644 index 0000000..ec27654 --- /dev/null +++ b/GraphManager/stagemanager/DrDynamicAggregateManager.h @@ -0,0 +1,267 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +DRDECLARECLASS(DrDynamicAggregateManager); +DRREF(DrDynamicAggregateManager); + +DRDECLARECLASS(DrDamCompletedVertex); +DRREF(DrDamCompletedVertex); + +DRDECLARECLASS(DrDamVertexGroup); +DRREF(DrDamVertexGroup); + +typedef DrArray DrDamVertexGroupArray; /* NOTE(review): the template arguments appear to be missing from this typedef in this copy of the patch -- restore them from the original source */ +DRAREF(DrDamVertexGroupArray,DrDamVertexGroupRef); +/* One output of a completed upstream vertex, as tracked by the aggregation layers. Judging by the members it records the vertex, the output port, and per-location arrays of machines, of the groups holding the vertex and of its slot index in each group; the implementation is not in this part of the file. */ +DRBASECLASS(DrDamCompletedVertex) +{ +public: + DrDamCompletedVertex(DrVertexPtr vertex, DrResourcePtr machine, UINT64 outputSize, int outputPort); /* used for completed active vertices (see NotifyUpstreamVertexCompleted) */ + DrDamCompletedVertex(DrVertexPtr vertex, DrAffinityPtr affinity, int outputPort); /* used for storage inputs (see NotifyUpstreamInputReady) */ + + DrDamCompletedVertexRef MakeCopy(); + void RemovePodDuplicates(); /* called before the vertex is added to pod groups */ + + DrVertexPtr GetVertex(); + int GetOutputPort(); + UINT64 GetOutputSize(); + DrResourceArrayRef GetMachineArray(); + DrDamVertexGroupArrayRef GetGroupArray(); + DrIntArrayRef GetGroupIndexArray(); + int GetNumberOfLocations(); + + void SetGroup(int locationIndex, DrDamVertexGroupPtr group, int groupIndex); + void MoveGroupIndex(DrDamVertexGroupPtr group, int newIndex, int oldIndex); + void RemoveDuplicateVerticesFromGroups(DrDamVertexGroupPtr keepGroup); /* callers in the .cpp pass DrNull when the vertex has more than one location */ + int GetLocationIndex(DrDamVertexGroupPtr group, int groupIndex); + +private: + DrVertexRef m_vertex; + int m_outputPort; + int m_numberOfLocations; + DrAffinityRef m_affinity; + DrResourceArrayRef m_machine; + DrDamVertexGroupArrayRef m_group; + DrIntArrayRef m_groupIndex; +}; +DRREF(DrDamCompletedVertex); + +typedef DrArrayList DrDamCompletedVertexList; /* NOTE(review): template arguments appear to be missing here as well */ +DRAREF(DrDamCompletedVertexList,DrDamCompletedVertexRef); + 
+DRBASECLASS(DrDamVertexGroup) /* A set of completed vertices being accumulated for one machine, one pod or the whole layer. Judging by the members it tracks a vertex count, a combined output size and the largest location count among its vertices; the implementation is not in this part of the file. */ +{ +public: + DrDamVertexGroup(int maxVerticesInGroup); + + void AddVertex(DrDamCompletedVertexPtr v, int locationIndex); + void RemoveVertex(DrDamCompletedVertexPtr v, int groupIndex); + void ClaimVertices(); /* called by the layer before a group is sent or its vertices are returned */ + void Discard(); + void DisconnectFromSuccessor(DrVertexPtr successor); + void ConnectToSuccessor(DrVertexPtr successor); + + int GetGroupSize(); + DrDamCompletedVertexListPtr GetGroupArray(); /* slots may hold DrNull; callers skip those */ + UINT64 GetCombinedOutputSize(); + int GetMaxNumberOfLocations(); + +private: + static const int s_initialArraySize = 4; + void Compact(); + + UINT64 m_combinedOutputSize; + DrDamCompletedVertexListRef m_vertex; + int m_numberOfVertices; + int m_maxVertices; + int m_maxNumberOfLocations; +}; +/* Grouping levels, from narrowest to widest; ReturnUnGrouped() in the layer relies on this order when it falls through from one case to the next. */ +enum DrDamGroupingLevel +{ + DDGL_Machine, + DDGL_Pod, + DDGL_Overall +}; +/* Work item used while redistributing vertices across machines in GatherRemainingMachineGroups(). */ +DRVALUECLASS(MachineGroup) +{ + public: + DrResourceRef machine; /* key of the entry in the machine map */ + DrDamVertexGroupRef group; /* source group; MoveOneVertex() sets it to DrNull once it is empty */ + DrDamVertexGroupRef outputGroup; /* group that receives the moved vertices */ + int nextVertex; /* slot in the source group where the next scan resumes */ +}; +DRRREF(MachineGroup); +DRMAKEARRAY(MachineGroup); +/* One aggregation level: accumulates completed vertices into machine, pod and overall groups, sends full groups to the parent manager, and for levels above 0 owns the stage manager of the internal vertices created for that level. */ +DRBASECLASS(DrDamPartiallyGroupedLayer) +{ +public: + DrDamPartiallyGroupedLayer(DrGraphPtr graph, DrStageStatisticsPtr statistics, + DrDynamicAggregateManagerPtr parent, + DrVertexPtr internalVertex, + int aggregationLevel, DrString name); /* internalVertex must be DrNull exactly when aggregationLevel is 0 */ + + void Discard(); + + DrString GetName(); + DrDynamicAggregateManagerPtr GetParent(); + DrManagerBasePtr GetStageManager(); + + void AddVertexToMachineGroup(DrDamCompletedVertexPtr vertex); + void AddVertexToPodGroup(DrDamCompletedVertexPtr vertex); + void AddVertexToOverallGroup(DrDamCompletedVertexPtr vertex); + void GatherRemainingMachineGroups(); + void GatherRemainingPodGroups(); + void GatherRemainingOverallGroups(); + void LastVertexHasCompleted(); /* runs the three Gather calls above in that order */ + void SetDelayGrouping(bool delayGrouping); + + void MakeInternalGroup(DrDamVertexGroupPtr group, DrVertexPtr successor); + +private: + typedef DrDictionary GroupMap; /* NOTE(review): the template arguments appear to be missing from this typedef in this copy of the patch -- restore them from the original source */ + DRREF(GroupMap); + + + + void ConsiderSending(DrDamVertexGroupPtr group, int maxVertices, UINT64 
additionalSize); + DrDamVertexGroupPtr SendMinimum(DrDamVertexGroupPtr group, int minVertices, int maxVertices); + void ReturnUnGrouped(DrDamVertexGroupPtr group, DrDamGroupingLevel level); + static int MachineGroupCmp(MachineGroupR left, MachineGroupR right); + bool MoveOneVertex(MachineGroupR grpStruct); + + DrString m_name; + DrDynamicAggregateManagerPtr m_parent; + DrManagerBaseRef m_stageManager; + DrVertexRef m_internalVertex; + int m_numberOfInternalCreated; + int m_aggregationLevel; + GroupMapRef m_machineGroup; + GroupMapRef m_podGroup; + DrDamVertexGroupRef m_overallGroup; + bool m_delayGrouping; +}; +DRREF(DrDamPartiallyGroupedLayer); + + +DRCLASS(DrDynamicAggregateManager) : public DrConnectionManager +{ +public: + DrDynamicAggregateManager(); + ~DrDynamicAggregateManager(); + + static const int s_minGroupSize = 16; + static const int s_maxGroupSize = 128; + static const int s_maxAggregationLevel = 256; + static const UINT64 s_minDataSize = 896*1024*1024; + static const UINT64 s_maxDataSize = 1024*1024*1024; /* 1 GByte */ + static const UINT64 s_maxDataSizeToConsider = 512*1024*1024; + + /* if no internal vertex is set, then the maxAggregationLevel will + be taken to be zero regardless of what is set in + SetMaxAggregationLevel */ + void SetInternalVertex(DrVertexPtr internalVertex); + + void SetMaxAggregationLevel(int maxAggregation); + int GetMaxAggregationLevel(); + + void SetGroupingSettings(int minGroupSize, + int maxGroupSize); + + void SetMachineGroupingSettings(int minPerMachine, + int maxPerMachine); + int GetMinPerMachine(); + int GetMaxPerMachine(); + + void SetPodGroupingSettings(int minPerPod, + int maxPerPod); + int GetMinPerPod(); + int GetMaxPerPod(); + + void SetOverallGroupingSettings(int minOverall, + int maxOverall); + int GetMinOverall(); + int GetMaxOverall(); + + void SetDataGroupingSettings(UINT64 minDataSize, UINT64 maxDataSize, + UINT64 maxDataToConsider); + UINT64 GetMinDataPerGroup(); + UINT64 GetMaxDataPerGroup(); + UINT64 
GetMaxDataToConsiderGrouping(); + + void SetSplitAfterGrouping(bool splitAfterGrouping); + bool GetSplitAfterGrouping(); + + void SetDelayGrouping(bool delayGrouping); + bool GetDelayGrouping(); + + virtual DrConnectionManagerRef MakeManagerForVertex(DrVertexPtr vertex, bool splitting) DROVERRIDE; + virtual void AddUpstreamStage(DrManagerBasePtr upstreamStage) DROVERRIDE; + + virtual void NotifyUpstreamVertexCompleted(DrActiveVertexPtr vertex, + int outputPort, int executionVersion, + DrResourcePtr machine, + DrVertexExecutionStatisticsPtr statistics) DROVERRIDE; + virtual void NotifyUpstreamLastVertexCompleted(DrManagerBasePtr upstreamStage) DROVERRIDE; + virtual void NotifyUpstreamInputReady(DrStorageVertexPtr vertex, int outputPort, + DrAffinityPtr affinity) DROVERRIDE; + + typedef DrArrayList LayerList; + DRAREF(LayerList,DrDamPartiallyGroupedLayerRef); + typedef DrDictionary CreatedMap; + DRREF(CreatedMap); + + DrDynamicAggregateManager(DrVertexPtr dstVertex, DrManagerBasePtr parent); + + void InitializeEmpty(); + void CopySettings(DrDynamicAggregateManagerPtr src, int nameIndex); + DrVertexRef AddSplitVertex(); + void AcceptCompletedGroup(DrDamVertexGroupPtr group, int aggregationLevel); + void DealWithUngroupableVertex(DrDamCompletedVertexPtr vertex); + void ReturnUnGrouped(DrDamCompletedVertexPtr vertex, int aggLevel); + void AddCompletedVertex(DrDamCompletedVertexPtr vertex, int aggLevel); + void RegisterCreatedVertex(DrVertexPtr vertex, int aggLevel); + void CleanUp(); + + int m_maxAggregationLevel; + int m_minPerMachine; + int m_maxPerMachine; + int m_minPerPod; + int m_maxPerPod; + int m_minOverall; + int m_maxOverall; + UINT64 m_minDataPerGroup; + UINT64 m_maxDataPerGroup; + UINT64 m_maxDataToConsiderGrouping; + DrVertexRef m_internalVertex; + DrVertexRef m_dstVertex; + bool m_splitAfterGrouping; + int m_numberOfSplitCreated; + int m_numberOfManagersCreated; + + DrDefaultStageListRef m_upstreamStage; + LayerListRef m_grouping; + CreatedMapRef 
m_createdMap; + bool m_delayGrouping; +}; diff --git a/GraphManager/stagemanager/DrDynamicBroadcast.cpp b/GraphManager/stagemanager/DrDynamicBroadcast.cpp new file mode 100644 index 0000000..18f98ef --- /dev/null +++ b/GraphManager/stagemanager/DrDynamicBroadcast.cpp @@ -0,0 +1,242 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include + +#include + +DrDynamicBroadcastManager::DrDynamicBroadcastManager(DrActiveVertexPtr copyVertex) + : DrConnectionManager(false) +{ + m_copyVertex = copyVertex; + m_teeNumber = 0; +} + +void DrDynamicBroadcastManager::RegisterVertex(DrVertexPtr vertex, bool splitting) +{ + DrAssert(!splitting); + + if (m_baseTee == DrNull) + { + m_baseTee = dynamic_cast(vertex); + DrAssert(m_baseTee != DrNull); + } +} + +void DrDynamicBroadcastManager::MaybeMakeRoundRobinPodMachines() +// reorder the machines and return a list +{ + if (m_roundRobinMachines == DrNull) + { + m_roundRobinMachines = DrNew DrResourceList(); + + /* we have to look down a long chain to find who's in the cluster, but it's there somewhere... 
*/ + DrUniversePtr universe = m_copyVertex->GetStageManager()->GetGraph()->GetXCompute()->GetUniverse(); + + { + DrAutoCriticalSection acs(universe->GetResourceLock()); + + DrResourceListRef pods = universe->GetResources(DRL_Rack); + DrIntArrayListRef podIndex = DrNew DrIntArrayList(); + + int i; + for (i=0; iSize(); ++i) + { + podIndex->Add(0); + } + + int podsToFinish = pods->Size(); + DrAssert(podsToFinish > 0); + + int currentPod = 0; + do + { + int currentIndex = podIndex[currentPod]; + if (currentIndex == -1) + { + /* we have already previously exhausted this pod, so just keep going */ + } + else + { + DrResourceListRef podChildren = pods[currentPod]->GetChildren(); + if (currentIndex == podChildren->Size()) + { + /* we have used all the machines from this pod */ + DrAssert(podsToFinish > 0); + --podsToFinish; + } + else + { + m_roundRobinMachines->Add(podChildren[currentIndex]); + podIndex[currentPod] = currentIndex+1; + } + } + + ++currentPod; + if (currentPod == pods->Size()) + { + currentPod = 0; + } + } + while (podsToFinish > 0); + + DrAssert(m_roundRobinMachines->Size() == universe->GetResources(DRL_Computer)->Size()); + } + } +} + + +void DrDynamicBroadcastManager::NotifyUpstreamVertexCompleted(DrActiveVertexPtr vertex, int outputPort, + int /* unused executionVersion */, + DrResourcePtr machine, + DrVertexExecutionStatisticsPtr statistics) +// a node upstream of a tee has terminated, expand the tee into a tree +{ + DrEdge oe = vertex->GetOutputs()->GetEdge(outputPort); + DrTeeVertexPtr sourceTee = dynamic_cast((DrVertexPtr) oe.m_remoteVertex); + DrAssert(sourceTee != DrNull); + + UINT64 dataWritten = statistics->m_outputData[0]->m_dataWritten; + + ExpandTee(sourceTee, dataWritten, machine); +} + +void DrDynamicBroadcastManager::ExpandTee(DrTeeVertexPtr sourceTee, UINT64 dataWritten, DrResourcePtr machine) +{ + // how many nodes to expand this stage to + int destinations = sourceTee->GetOutputs()->GetNumberOfEdges(); + if (destinations < 
s_minConsumers) + { + return; + } + + int copies = (int)(sqrt((double)destinations)); + + // find the pods lazily + MaybeMakeRoundRobinPodMachines(); + + int machines = m_roundRobinMachines->Size(); + DrAssert(machines > 0); + + // If there is only one machine don't ExpandTee + if (machines == 1) + { + return; + } + + if (copies > machines) + { + copies = machines; + } + + int currentMachine; + for (currentMachine=0; currentMachineGetId(), copies); + + int edgesPerNode = destinations / copies; + int nodesWithExtraDestination = destinations % copies; + + DrTeeVertexRef newTee; + DrVertexListRef newVertices = DrNew DrVertexList(); + + int currentDestination = 0; + int copy; + // insert 'copies' broadcast nodes + for (copy=0; copyDrVertex::MakeCopy(m_teeNumber); + DrTeeVertexPtr tee = dynamic_cast((DrVertexPtr) t); + tee->GetStageManager()->RegisterVertex(tee); + + tee->GetInputs()->SetNumberOfEdges(1); + int edges = edgesPerNode + (copy < nodesWithExtraDestination); + tee->GetOutputs()->SetNumberOfEdges(edges); + + if (copy == 0) + { + /* the first 'copy' is just another tee without a copier, since the data is already on this machine */ + downstream = tee; + newTee = tee; + } + else + { + DrVertexRef v = m_copyVertex->DrVertex::MakeCopy(m_teeNumber); + DrActiveVertexPtr newVertex = dynamic_cast((DrVertexPtr) v); + DrAssert(newVertex != DrNull); + + newVertex->GetStageManager()->RegisterVertex(newVertex); + + newVertex->GetInputs()->SetNumberOfEdges(1); + newVertex->GetOutputs()->SetNumberOfEdges(1); + + /* make it prefer the new machine more than the one where the data lives */ + newVertex->GetAffinity()->AddLocality(m_roundRobinMachines[currentMachine]); + newVertex->GetAffinity()->SetWeight(10 * dataWritten); + + newVertex->ConnectOutput(0, tee, 0, DCT_File); + + downstream = newVertex; + } + + newVertices->Add(downstream); + + int i; + for (i=0; iGetOutputs()->GetEdge(currentDestination); + sourceTee->DisconnectOutput(currentDestination, true); + + 
tee->ConnectOutput(i, e.m_remoteVertex, e.m_remotePort, DCT_File); + } + + sourceTee->ConnectOutput(copy, downstream, 0, DCT_File); + + ++m_teeNumber; + } + + DrAssert(currentDestination == destinations); + + sourceTee->GetOutputs()->Compact(DrNull); + + /* kick all the copy vertices to start them going */ + int kick; + for (kick=0; kickSize(); ++kick) + { + newVertices[kick]->InitializeForGraphExecution(); + newVertices[kick]->KickStateMachine(); + } + + /* now recurse down with the new tee we just created. Since there are a logarithmic number of levels, + I'm not worried about exhausting the stack unless somebody builds a *really* big cluster */ + ExpandTee(newTee, dataWritten, machine); +} diff --git a/GraphManager/stagemanager/DrDynamicBroadcast.h b/GraphManager/stagemanager/DrDynamicBroadcast.h new file mode 100644 index 0000000..ecd9aca --- /dev/null +++ b/GraphManager/stagemanager/DrDynamicBroadcast.h @@ -0,0 +1,62 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +DRCLASS(DrDynamicBroadcastManager) : public DrConnectionManager +{ + /* + This connection manager should have always a Tee as a destination. + I.e., it is placed on the S >= T edge below. (T is a Tee.) + + (Ideally it should have been placed on the T >= C edge, but Tee + vertices don't emit stage events right now.) 
+ + From + + (S >= T)^k >=^k (C^n) + + it builds something like: + + (S >= T >= (copy >= T)^(sqrt(n)))^k >=^k (C ^ n) + + Where the operator >=^k is >= applied k times. + */ + +public: + DrDynamicBroadcastManager(DrActiveVertexPtr copyVertex); + + virtual void NotifyUpstreamVertexCompleted(DrActiveVertexPtr vertex, int outputPort, + int executionVersion, + DrResourcePtr machine, + DrVertexExecutionStatisticsPtr statistics) DROVERRIDE; + virtual void RegisterVertex(DrVertexPtr vertex, bool splitting) DROVERRIDE; + +private: + void MaybeMakeRoundRobinPodMachines(); + void ExpandTee(DrTeeVertexPtr sourceTee, UINT64 dataWritten, DrResourcePtr machine); + + DrTeeVertexRef m_baseTee; + DrActiveVertexRef m_copyVertex; // inserted as a layer + DrResourceListRef m_roundRobinMachines; // machines ordered to repeat pods as rarely as possible + int m_teeNumber; // used to renumber the copies + static const int s_minConsumers = 5; // do not create broadcast copies if there are fewer than this many consumers +}; +DRREF(DrDynamicBroadcastManager); \ No newline at end of file diff --git a/GraphManager/stagemanager/DrDynamicDistributor.cpp b/GraphManager/stagemanager/DrDynamicDistributor.cpp new file mode 100644 index 0000000..1056ec1 --- /dev/null +++ b/GraphManager/stagemanager/DrDynamicDistributor.cpp @@ -0,0 +1,388 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +/* dynamicdistributor.cpp: + create dynamically a distribution layer + */ + +#include + +DrDynamicDistributionManager::DrDynamicDistributionManager(DrVertexPtr internalVertex, + DrConnectionManagerPtr newConnectionManager) + : DrConnectionManager(false) +{ + SetDataPerVertex(s_dataPerVertex); + m_combinedOutputSize = 0; + m_internalVertex = internalVertex; + m_newConnectionManager = newConnectionManager; + + m_stageSet = DrNew DrStageSet(); + m_sourcesSet = DrNew SourcesSet(); +} + +void DrDynamicDistributionManager::RegisterVertex(DrVertexPtr vertex, bool splitting) +{ + if (m_dstVertex) + { + // a newly created replica: nothing to do + DrAssert(splitting); + } + else + { + m_dstVertex = vertex; + } +} + +void DrDynamicDistributionManager::SetDataPerVertex(UINT64 dataPerNode) +{ + DrAssert(dataPerNode > 0); + m_dataPerVertex = dataPerNode; +} + +UINT64 DrDynamicDistributionManager::GetDataPerVertex() +{ + return m_dataPerVertex; +} + +void DrDynamicDistributionManager::AddUpstreamStage(DrManagerBasePtr upstreamStage) +{ + m_stageSet->Add(upstreamStage); +} + +void DrDynamicDistributionManager::NotifyUpstreamLastVertexCompleted(DrManagerBasePtr upstreamStage) +{ + bool removed = m_stageSet->Remove(upstreamStage); + DrAssert(removed); + + if (m_stageSet->GetSize() > 0) + { + return; + } + + /* all our upstream stages have completed. 
*/ + + int copies = (int)((m_combinedOutputSize + m_dataPerVertex - 1) / m_dataPerVertex); + // how many nodes to expand this stage to + if (copies == 0) + { + copies = 1; + } + + DrVertexSetRef destCopies = DrNew DrVertexSet(); + DrVertexSetRef distrCopies = DrNew DrVertexSet(); + + DrLogI("Resizing stage for dynamic distribution, new size: %d\n", copies); + + // resize this stage to 'copies' copies + int copy; + for (copy=0; copyMakeCopy(copy+1); + destCopies->Add(newVertex); + + newVertex->GetInputs()->SetNumberOfEdges(m_dstVertex->GetInputs()->GetNumberOfEdges()); + + /* split from one stage to another */ + newVertex->GetStageManager()->RegisterVertexSplit(newVertex, m_dstVertex, copy); + } + + DrConnectorType edgeType = DCT_Tombstone; + int distcopy = 0; + SourcesSet::DrEnumerator e = m_sourcesSet->GetDrEnumerator(); + // disconnect the vertices in m_sourcesSet from m_dstVertex; + while (e.MoveNext()) + { + DrVertexPtr vertex = e.GetKey(); + int outputPort = e.GetValue(); + DrEdge edge = vertex->GetOutputs()->GetEdge(outputPort); + + if (edgeType == DCT_Tombstone) + { + edgeType = edge.m_type; + /* for now we only deal with non-active edges */ + DrAssert(edgeType == DCT_File || edgeType == DCT_Output); + } + else + { + DrAssert(edgeType == edge.m_type); + } + + int inputPort = edge.m_remotePort; + DrVertexPtr inThisStage = edge.m_remoteVertex; + + DrAssert(inThisStage == m_dstVertex); + inThisStage->DisconnectInput(inputPort, true); + + // create a new distributor vertex for the deleted edge + DrVertexRef distributor = m_internalVertex->MakeCopy(distcopy); + m_internalVertex->GetStageManager()->RegisterVertex(distributor); + + distributor->GetInputs()->SetNumberOfEdges(1); + + distributor->GetOutputs()->SetNumberOfEdges(copies); + + vertex->ConnectOutput(outputPort, distributor, 0, DCT_File); + distrCopies->Add(distributor); + + ++distcopy; + } + + distcopy = 0; + // connect the cross-product from distributors to destCopies + // and start the distributors + 
DrVertexSet::DrEnumerator ve = distrCopies->GetDrEnumerator(); + while (ve.MoveNext()) + { + DrVertexPtr distributor = ve.GetElement(); + + int dest=0; + DrVertexSet::DrEnumerator de = destCopies->GetDrEnumerator(); + while (de.MoveNext()) + { + DrVertexPtr newVertex = de.GetElement(); + + distributor->ConnectOutput(dest, newVertex, distcopy, edgeType); + ++dest; + } + + distributor->InitializeForGraphExecution(); + distributor->KickStateMachine(); + distcopy++; + } + + if (m_newConnectionManager) + { + m_dstVertex->GetStageManager()->AddDynamicConnectionManagerAtRuntime(m_internalVertex->GetStageManager(), + m_newConnectionManager); + } + + m_dstVertex->GetInputs()->Compact(m_dstVertex); + m_dstVertex->GetOutputs()->Compact(DrNull); + DrAssert(m_dstVertex->GetInputs()->GetNumberOfEdges() == 0); + GetParent()->UnRegisterVertex(m_dstVertex);// disconnects it + DrAssert(m_dstVertex->GetOutputs()->GetNumberOfEdges() == 0); + + m_dstVertex->RemoveFromGraphExecution(); + + m_internalVertex->GetStageManager()->SetStillAddingVertices(false); + + DrVertexSet::DrEnumerator startDest = destCopies->GetDrEnumerator(); + while (startDest.MoveNext()) + { + DrVertexPtr newVertex = startDest.GetElement(); + newVertex->InitializeForGraphExecution(); + newVertex->KickStateMachine(); + } +} + + +void DrDynamicDistributionManager::NotifyUpstreamVertexCompleted(DrActiveVertexPtr vertex, int outputPort, + int /* unused executionVersion */, + DrResourcePtr /* unused machine */, + DrVertexExecutionStatisticsPtr statistics) +{ + UINT64 outputSize = statistics->m_outputData[outputPort]->m_dataWritten; + m_combinedOutputSize += outputSize; + + m_sourcesSet->Add(vertex, outputPort); +} + +void DrDynamicDistributionManager::NotifyUpstreamInputReady(DrStorageVertexPtr vertex, int outputPort, + DrAffinityPtr affinity) +{ + m_combinedOutputSize += affinity->GetWeight(); + m_sourcesSet->Add(vertex, outputPort); +} + +/*******************************************************/ +/* Alternative version 
of hash distributor starts here */ +/* This distributor allocates extra edges, and redistributes just the edges at runtime */ + +DrDynamicHashDistributionManager::DrDynamicHashDistributionManager() + : DrConnectionManager(false) +{ + SetDataPerVertex(s_dataPerVertex); + m_combinedOutputSize = 0; + m_edgesInBundle = 0; + + m_stageSet = DrNew DrStageSet(); + m_sources = DrNew DrVertexSet(); +} + +void DrDynamicHashDistributionManager::RegisterVertex(DrVertexPtr vertex, bool splitting) +{ + if (m_dstVertex) + { + // a newly created replica: nothing to do + DrAssert(splitting); + } + else + { + m_dstVertex = vertex; + } +} + +void DrDynamicHashDistributionManager::SetDataPerVertex(UINT64 dataPerNode) +{ + DrAssert(dataPerNode); + m_dataPerVertex = dataPerNode; +} + +UINT64 DrDynamicHashDistributionManager::GetDataPerVertex() +{ + return m_dataPerVertex; +} + +void DrDynamicHashDistributionManager::AddUpstreamStage(DrManagerBasePtr upstreamStage) +{ + m_stageSet->Add(upstreamStage); +} + + +void DrDynamicHashDistributionManager::NotifyUpstreamLastVertexCompleted(DrManagerBasePtr upstreamStage) +{ + bool removed = m_stageSet->Remove(upstreamStage); + DrAssert(removed); + + if (m_stageSet->GetSize() > 0) + { + return; + } + + /* all our upstream stages have completed. 
*/ + + int copies = (int)((m_combinedOutputSize + m_dataPerVertex - 1) / m_dataPerVertex); + + // how many nodes to expand this stage to + if (copies == 0) + { + copies = 1; + } + + DrAssert(m_dstVertex); + DrAssert(m_edgesInBundle); + + int numberOfInputs = m_dstVertex->GetInputs()->GetNumberOfEdges(); + DrAssert(numberOfInputs == m_edgesInBundle * m_sources->GetSize()); + + // can't replicate more than the number of inputs + if (m_edgesInBundle < copies) + { + copies = m_edgesInBundle; + } + + DrLogI("Resizing stage for dynamic hash distribution, new size: %d\n", copies); + + int inputsPerVertex = m_edgesInBundle / copies; + int verticesWithExtraInputs = m_edgesInBundle % copies; + + DrVertexListRef destCopies = DrNew DrVertexList(); + + // resize this stage to 'copies' copies + DrStageManagerPtr parent = m_dstVertex->GetStageManager(); + int copy; + for (copy=0; copyMakeCopy(copy); + destCopies->Add(newVertex); + + int numberOfInputs = inputsPerVertex + (copy < verticesWithExtraInputs); + numberOfInputs = numberOfInputs * m_sources->GetSize(); + newVertex->GetInputs()->SetNumberOfEdges(numberOfInputs); + parent->RegisterVertexSplit(newVertex, m_dstVertex, copy); + } + + // disconnect the vertices in m_sources from m_dstVertex; + // and connect them to the proper replica + int sourceNo = 0; + DrVertexSet::DrEnumerator ve = m_sources->GetDrEnumerator(); + while (ve.MoveNext()) + { + // see if this upstream vertex has a correspondent in the new stage + DrVertexPtr vertex = ve.GetElement(); + DrAssert(m_edgesInBundle == vertex->GetOutputs()->GetNumberOfEdges()); + + int outputPort; + for (outputPort=0; outputPortGetOutputs()->GetEdge(outputPort); + + DrVertexPtr inThisStage = edge.m_remoteVertex; + DrAssert(inThisStage == m_dstVertex); + + vertex->DisconnectOutput(outputPort, true); + + // connect to the proper replica + int copyNo = outputPort % copies; + int inputsPerSource = m_edgesInBundle / copies + (copyNo < verticesWithExtraInputs); + int portNo = outputPort 
/ copies + sourceNo * inputsPerSource; + + DrVertexPtr copy = destCopies[copyNo]; + + vertex->ConnectOutput(outputPort, copy, portNo, edge.m_type); + } + + ++sourceNo; + } + + m_dstVertex->GetInputs()->Compact(m_dstVertex); + DrAssert(m_dstVertex->GetInputs()->GetNumberOfEdges() == 0); + + m_dstVertex->GetOutputs()->Compact(DrNull); + GetParent()->UnRegisterVertex(m_dstVertex); // disconnects it + DrAssert(m_dstVertex->GetOutputs()->GetNumberOfEdges() == 0); + + m_dstVertex->RemoveFromGraphExecution(); + + for (copy=0; copyInitializeForGraphExecution(); + destCopies[copy]->KickStateMachine(); + } +} + + +void DrDynamicHashDistributionManager::NotifyUpstreamVertexCompleted(DrActiveVertexPtr vertex, int outputPort, + int /* unused executionVersion */, + DrResourcePtr /* unused machine */, + DrVertexExecutionStatisticsPtr statistics) +{ + UINT64 outputSize = statistics->m_outputData[outputPort]->m_dataWritten; + + m_combinedOutputSize += outputSize; + m_sources->Add(vertex); + + if (m_edgesInBundle) + { + DrAssert(m_edgesInBundle == vertex->GetOutputs()->GetNumberOfEdges()); + } + else + { + m_edgesInBundle = vertex->GetOutputs()->GetNumberOfEdges(); + } +} + + +void DrDynamicHashDistributionManager::NotifyUpstreamInputReady(DrStorageVertexPtr /* ununsed vertex */, + int /* unused outputPort */, + DrAffinityPtr affinity) +{ + m_combinedOutputSize += affinity->GetWeight(); +} diff --git a/GraphManager/stagemanager/DrDynamicDistributor.h b/GraphManager/stagemanager/DrDynamicDistributor.h new file mode 100644 index 0000000..c509a05 --- /dev/null +++ b/GraphManager/stagemanager/DrDynamicDistributor.h @@ -0,0 +1,124 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +DRDECLARECLASS(DrDynamicDistributionManager); +DRREF(DrDynamicDistributionManager); + +DRCLASS(DrDynamicDistributionManager) : public DrConnectionManager +{ + /* + If internal vertex is A, the source vertices are B, and the + contents of the parent stage is C, the graph is changed as + follows when all B's complete: + + From: + (B,B,B) >> C + + To: + ((B>=A),(B>=A),(B>=A)) >> (C,C) + */ + +public: + // the internal vertex is the distributor + // the newConnectionManager is the manager that will handle the C layer after expansion + // (replacing this) + DrDynamicDistributionManager(DrVertexPtr internalVertex, + DrConnectionManagerPtr newConnectionManager); + + void SetDataPerVertex(UINT64 dataPerVertex); + UINT64 GetDataPerVertex(); + + virtual void AddUpstreamStage(DrManagerBasePtr upstreamStage) DROVERRIDE; + virtual void NotifyUpstreamLastVertexCompleted(DrManagerBasePtr upstreamStage) DROVERRIDE; + virtual void NotifyUpstreamVertexCompleted(DrActiveVertexPtr vertex, int outputPort, + int executionVersion, + DrResourcePtr machine, + DrVertexExecutionStatisticsPtr statistics) DROVERRIDE; + virtual void NotifyUpstreamInputReady(DrStorageVertexPtr vertex, int outputPort, + DrAffinityPtr affinity) DROVERRIDE; + virtual void RegisterVertex(DrVertexPtr vertex, bool splitting) DROVERRIDE; + + private: + typedef DrDictionary SourcesSet; + DRREF(SourcesSet); + + static const UINT64 s_dataPerVertex = 1024 * 1024 * 1024; /* create one new vertex downstream + for each 1G by default */ + 
DrStageSetRef m_stageSet; + SourcesSetRef m_sourcesSet; + DrVertexRef m_dstVertex; // the vertex managed by this connection manager + UINT64 m_dataPerVertex; + UINT64 m_combinedOutputSize; + DrVertexRef m_internalVertex; // actual distributor vertex + DrConnectionManagerRef m_newConnectionManager; +}; + + +DRDECLARECLASS(DrDynamicHashDistributionManager); +DRREF(DrDynamicHashDistributionManager); + +DRCLASS(DrDynamicHashDistributionManager) : public DrConnectionManager +{ + /* + The Source vertices are B, and the contents of the parent stage + is C, the graph is changed as follows when all B's complete: + + From: + (B,B,B) >=^n C + (n parallel connections from B to C) + + To: + (B,B,B) >=^(n/2) (C,C) + + I.e., C is replicated, and its n inputs are redistributed among + the copies round-robin. + + Each B vertex must have the same number of connections to C. + There must be no outputs of B going to some other vertex than C. + */ + +public: + DrDynamicHashDistributionManager(); + + void SetDataPerVertex(UINT64 dataPerVertex); + UINT64 GetDataPerVertex(); + + virtual void AddUpstreamStage(DrManagerBasePtr upstreamStage) DROVERRIDE; + virtual void NotifyUpstreamLastVertexCompleted(DrManagerBasePtr upstreamStage) DROVERRIDE; + virtual void NotifyUpstreamVertexCompleted(DrActiveVertexPtr vertex, int outputPort, + int executionVersion, + DrResourcePtr machine, + DrVertexExecutionStatisticsPtr statistics) DROVERRIDE; + virtual void NotifyUpstreamInputReady(DrStorageVertexPtr vertex, int outputPort, + DrAffinityPtr affinity) DROVERRIDE; + virtual void RegisterVertex(DrVertexPtr vertex, bool splitting) DROVERRIDE; + +private: + static const UINT64 s_dataPerVertex = 1024 * 1024 * 1024; /* create one new vertex downstream + for each 1G by default */ + DrStageSetRef m_stageSet; + DrVertexSetRef m_sources; + DrVertexRef m_dstVertex; // the vertex managed by this connection manager + UINT64 m_dataPerVertex; + UINT64 m_combinedOutputSize; + int m_edgesInBundle; // edges coming from 
each input +}; diff --git a/GraphManager/stagemanager/DrDynamicRangeDistributor.cpp b/GraphManager/stagemanager/DrDynamicRangeDistributor.cpp new file mode 100644 index 0000000..17fb494 --- /dev/null +++ b/GraphManager/stagemanager/DrDynamicRangeDistributor.cpp @@ -0,0 +1,156 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include + +DrDynamicRangeDistributionManager::DrDynamicRangeDistributionManager(DrStageManagerPtr dataConsumer, + double samplingRate) + : DrConnectionManager(false) +{ + SetDataPerVertex(s_dataPerVertex); + m_combinedOutputSize = 0; + m_dataConsumer = dataConsumer; + DrAssert(m_dataConsumer != DrNull); + + m_samplingRate = samplingRate; + DrAssert(samplingRate > 0 && samplingRate <= 1); + + m_stageSet = DrNew DrStageSet(); +} + +void DrDynamicRangeDistributionManager::SetDataPerVertex(UINT64 dataPerNode) +{ + DrAssert(dataPerNode); + m_dataPerVertex = dataPerNode; +} + +UINT64 DrDynamicRangeDistributionManager::GetDataPerVertex() +{ + return m_dataPerVertex; +} + +void DrDynamicRangeDistributionManager::AddUpstreamStage(DrManagerBasePtr upstreamStage) +{ + m_stageSet->Add(upstreamStage); +} + +void DrDynamicRangeDistributionManager::NotifyUpstreamLastVertexCompleted(DrManagerBasePtr upstreamStage) +{ + bool removed = m_stageSet->Remove(upstreamStage); + DrAssert(removed); + + if (m_stageSet->GetSize() > 0) + { + return; + } + + 
/* all our upstream stages have completed. */ + + int copies = (int)((m_combinedOutputSize/m_samplingRate + m_dataPerVertex - 1) / m_dataPerVertex); + // how many nodes to expand the M stage to + + DrLogI("Resizing stage for dynamic range distribution, new size: %d\n", copies); + + if (copies > 1) + { + DrVertexListRef consumers = m_dataConsumer->GetVertexVector(); + DrAssert(consumers->Size() == 1); + + DrVertexRef dataConsumerVertex = consumers[0]; + + int ins = dataConsumerVertex->GetInputs()->GetNumberOfEdges(); + + DrVertexListRef distributors = DrNew DrVertexList(); // keep track of the sources here + + int i; + for (i=0; iRemoteInputVertex(i); + + distributor->GetOutputs()->GrowNumberOfEdges(copies); + distributors->Add(distributor); + + dataConsumerVertex->DisconnectInput(i, true); + } + + int copy; + for (copy=0; copyMakeCopy(copy); + + // connect all of the inputs of the dataconsumer to the copy + newVertex->GetInputs()->SetNumberOfEdges(ins); + + for (i=0; iConnectOutput(copy, newVertex, i, DCT_File); + } + + /* this connects up the outputs */ + m_dataConsumer->RegisterVertexSplit(newVertex, dataConsumerVertex, copy); + + newVertex->InitializeForGraphExecution(); + newVertex->KickStateMachine(); + } + + dataConsumerVertex->GetInputs()->Compact(dataConsumerVertex); + DrAssert(dataConsumerVertex->GetInputs()->GetNumberOfEdges() == 0); + + m_dataConsumer->UnRegisterVertex(dataConsumerVertex); + DrAssert(dataConsumerVertex->GetOutputs()->GetNumberOfEdges() == 0); + + dataConsumerVertex->RemoveFromGraphExecution(); + } + + DrString arg; + arg.SetF("%d", copies); + m_bucketizer->AddArgumentInternal(arg); +} + +void DrDynamicRangeDistributionManager::RegisterVertex(DrVertexPtr vertex, bool splitting) +{ + DrAssert(!splitting); + + // there should be only one vertex in this stage + DrAssert(m_bucketizer == DrNull); + + DrActiveVertexPtr activeVertex = dynamic_cast(vertex); + DrAssert(activeVertex != DrNull); + + m_bucketizer = activeVertex; +} + +void 
DrDynamicRangeDistributionManager::NotifyUpstreamVertexCompleted(DrActiveVertexPtr /* unused vertex */, + int outputPort, + int /* unused executionVersion */, + DrResourcePtr /* unused machine */, + DrVertexExecutionStatisticsPtr statistics) +{ + UINT64 outputSize = statistics->m_outputData[outputPort]->m_dataWritten; + m_combinedOutputSize += outputSize; +} + +void DrDynamicRangeDistributionManager::NotifyUpstreamInputReady(DrStorageVertexPtr /* unused vertex */, + int /* unused outputPort */, + DrAffinityPtr affinity) +{ + m_combinedOutputSize += affinity->GetWeight(); +} diff --git a/GraphManager/stagemanager/DrDynamicRangeDistributor.h b/GraphManager/stagemanager/DrDynamicRangeDistributor.h new file mode 100644 index 0000000..9f9fdf9 --- /dev/null +++ b/GraphManager/stagemanager/DrDynamicRangeDistributor.h @@ -0,0 +1,79 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +DRCLASS(DrDynamicRangeDistributionManager) : public DrConnectionManager +{ + /* + The graph evolves as follows when all S's complete: + + Initial graph: (S,S,S) >= B >= Tee >= (D,D,D) >> M || (S,S,S) >= (D,D,D) + + Final graph: (S,S,S) >= B >= Tee >= (D,D,D) >> (M,M,M,M) || (S,S,S) >= (D,D,D) + + (i.e., M's are expanded) + + S = data source; each has 2 outputs: one for B and one for D. + The B output contains a sample of the values from the D output. 
+ B = computes the buckets for distribution based on the + samples from all S vertices + D = distributor; input 0 reads the bucket boundaries from B, + input 1 reads the data to distribute to the outputs + M = actual data consumer. The dynamic range distributor will replicate + M to the correct number of instances, and will pass this information + as a command-line argument to B + + The DynamicRangeDistributionManager is placed on the edges S>=B. + */ + +public: + // The connection manager will only see the sampled data. To + // correctly estimate the data at the distributors, it needs to + // know the sampling rate as well. The sampling rate is the + // fraction of data that goes through the bucketizer node + // (0 < samplingrate <= 1) + DrDynamicRangeDistributionManager(DrStageManagerPtr dataConsumer /* M */, + double samplingRate); + + void SetDataPerVertex(UINT64 dataPerVertex); + UINT64 GetDataPerVertex(); + + virtual void AddUpstreamStage(DrManagerBasePtr upstreamStage) DROVERRIDE; + virtual void NotifyUpstreamLastVertexCompleted(DrManagerBasePtr upstreamStage) DROVERRIDE; + virtual void NotifyUpstreamVertexCompleted(DrActiveVertexPtr vertex, int outputPort, + int executionVersion, + DrResourcePtr machine, + DrVertexExecutionStatisticsPtr statistics) DROVERRIDE; + virtual void NotifyUpstreamInputReady(DrStorageVertexPtr vertex, + int outputPort, DrAffinityPtr affinity) DROVERRIDE; + virtual void RegisterVertex(DrVertexPtr vertex, bool splitting) DROVERRIDE; + +private: + static const UINT64 s_dataPerVertex = 1024 * 1024 * 1024; /* create one new vertex downstream + for each 1G by default */ + DrStageSetRef m_stageSet; + double m_samplingRate; + UINT64 m_dataPerVertex; + UINT64 m_combinedOutputSize; + DrStageManagerRef m_dataConsumer; + DrActiveVertexRef m_bucketizer; // stage B, on which the current manager is placed +}; +DRREF(DrDynamicRangeDistributionManager); \ No newline at end of file diff --git a/GraphManager/stagemanager/DrPipelineSplitManager.cpp 
b/GraphManager/stagemanager/DrPipelineSplitManager.cpp new file mode 100644 index 0000000..bdb7a48 --- /dev/null +++ b/GraphManager/stagemanager/DrPipelineSplitManager.cpp @@ -0,0 +1,339 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include + +DrPipelineSplitManager::DrPipelineSplitManager() : DrConnectionManager(false) +{ + m_stageSet = DrNew DrStageSet(); + m_splitMap = DrNew DrVertexVListMap(); +} + +void DrPipelineSplitManager::AddUpstreamStage(DrManagerBasePtr stage) +{ + m_stageSet->Add(stage); +} + +void DrPipelineSplitManager:: + NotifyUpstreamSplit(DrVertexPtr upstreamVertex, + DrVertexPtr baseNewVertexSplitFrom, + int outputPortOfSplitBase, + int upstreamSplitIndex) +{ + /* find the vertex in our stage that is connected to the upstream + base that is splitting */ + DrEdge e = baseNewVertexSplitFrom->GetOutputs()->GetEdge(outputPortOfSplitBase); + + DrVertexPtr localBaseVertex = e.m_remoteVertex; + int localBasePort = e.m_remotePort; + DrConnectorType originalType = e.m_type; + + /* see if the local vertex has split before */ + DrVertexListRef splitList; + if (m_splitMap->TryGetValue(localBaseVertex, splitList) == false) + { + splitList = DrNew DrVertexList(); + m_splitMap->Add(localBaseVertex, splitList); + } + + /* this is going to be the new vertex in our stage */ + DrVertexPtr newVertex; + + /* either some other upstream stage has 
already split with this + index number, or it's the next index to split */ + DrAssert(upstreamSplitIndex <= splitList->Size()); + if (upstreamSplitIndex == splitList->Size()) + { + /* it's the next index to split, so we actually have to create + a new vertex */ + newVertex = localBaseVertex->MakeCopy(upstreamSplitIndex); + newVertex->GetInputs()->SetNumberOfEdges(localBaseVertex->GetInputs()->GetNumberOfEdges()); + + /* register this new vertex with ourselves: this will attach + up its output edges based on the policies of the vertices + upstream of localBaseVertex, perhaps propagating the + pipeline split forwards */ + GetParent()->RegisterVertexSplit(newVertex, localBaseVertex, + upstreamSplitIndex); + + splitList->Add(newVertex); + } + else + { + /* some other upstream vertex already triggered the creation + of this vertex */ + newVertex = splitList[upstreamSplitIndex]; + } + + DrAssert(newVertex != DrNull); + + if (upstreamVertex->GetOutputs()->GetNumberOfEdges() <= outputPortOfSplitBase) + { + upstreamVertex->GetOutputs()->GrowNumberOfEdges(outputPortOfSplitBase+1); + } + + int nInputs = newVertex->GetInputs()->GetNumberOfEdges(); + DrAssert(localBasePort < nInputs); + upstreamVertex->ConnectOutput(outputPortOfSplitBase, + newVertex, localBasePort, + originalType); + + int i; + for (i=0; iGetInputs()->GetEdge(i).m_type == DCT_Tombstone) + { + /* this edge isn't connected to anyone yet */ + return; + } + } + + /* All the input edges are connected, so this vertex is ready to run */ + newVertex->InitializeForGraphExecution(); + newVertex->KickStateMachine(); + + /* remove it from the split vector so we won't check it again in + NotifyUpstreamLastVertexCompleted below */ + splitList[upstreamSplitIndex] = DrNull; +} + +void DrPipelineSplitManager:: + NotifyUpstreamVertexRemoval(DrVertexPtr upstreamVertex, + int outputPortOfRemovedVertex) +{ + /* find the vertex in our stage that is connected to the upstream + base that is splitting */ + DrVertexRef localBaseVertex 
= + upstreamVertex->RemoteOutputVertex(outputPortOfRemovedVertex); + + upstreamVertex->DisconnectOutput(outputPortOfRemovedVertex, true); + + /* shrink the local edge list to get rid of the empty slot we just + left. The upstream vertex will be dealt with by its own stage + manager */ + localBaseVertex->GetInputs()->Compact(localBaseVertex); + + /* if all the upstream edges have been removed, propagate the + deletion forwards then remove the vertex */ + if (localBaseVertex->GetInputs()->GetNumberOfEdges() == 0) + { + GetParent()->UnRegisterVertex(localBaseVertex); + DrAssert(localBaseVertex->GetOutputs()->GetNumberOfEdges() == 0); + localBaseVertex->RemoveFromGraphExecution(); + } +} + +void DrPipelineSplitManager::NotifyUpstreamLastVertexCompleted(DrManagerBasePtr upstreamStage) +{ + bool removed = m_stageSet->Remove(upstreamStage); + DrAssert(removed == 1); + + if (m_stageSet->GetSize() == 0) + { + /* all our upstream stages have completed. Now go through and + dispatch any split vertices that were still waiting for an + edge */ + DrVertexVListMap::DrEnumerator e = m_splitMap->GetDrEnumerator(); + while (e.MoveNext()) + { + DrVertexListRef list = e.GetValue(); + int i; + for (i=0; iSize(); ++i) + { + DrVertexPtr vertex = list[i]; + if (vertex != DrNull) + { + int nInputs = vertex->GetInputs()->GetNumberOfEdges(); + vertex->GetInputs()->Compact(vertex); + DrAssert(vertex->GetInputs()->GetNumberOfEdges() < nInputs); + vertex->InitializeForGraphExecution(); + vertex->KickStateMachine(); + list[i] = DrNull; + } + } + } + } +} + + + +DrSemiPipelineSplitManager::DrSemiPipelineSplitManager() : DrConnectionManager(false) +{ + m_stageSet = DrNew DrStageSet(); + m_splitMap = DrNew DrVertexVListMap(); +} + +void DrSemiPipelineSplitManager::AddUpstreamStage(DrManagerBasePtr stage) +{ + m_stageSet->Add(stage); +} + +void DrSemiPipelineSplitManager:: + NotifyUpstreamSplit(DrVertexPtr upstreamVertex, + DrVertexPtr baseNewVertexSplitFrom, + int outputPortOfSplitBase, + int 
upstreamSplitIndex) +{ + /* find the vertex in our stage that is connected to the upstream + base that is splitting */ + DrEdge e = baseNewVertexSplitFrom->GetOutputs()->GetEdge(outputPortOfSplitBase); + + DrVertexPtr localBaseVertex = e.m_remoteVertex; + int localBasePort = e.m_remotePort; + DrConnectorType originalType = e.m_type; + + /* see if the local vertex has split before */ + DrVertexListRef splitList; + if (m_splitMap->TryGetValue(localBaseVertex, splitList) == false) + { + splitList = DrNew DrVertexList(); + m_splitMap->Add(localBaseVertex, splitList); + } + + /* this is going to be the new vertex in our stage */ + DrVertexPtr newVertex; + + /* either some other upstream stage has already split with this + index number, or it's the next index to split */ + DrAssert(upstreamSplitIndex <= splitList->Size()); + if (upstreamSplitIndex == splitList->Size()) + { + /* it's the next index to split, so we actually have to create + a new vertex */ + newVertex = localBaseVertex->MakeCopy(upstreamSplitIndex); + newVertex->GetInputs()->SetNumberOfEdges(localBaseVertex->GetInputs()->GetNumberOfEdges()); + + /* register this new vertex with ourselves: this will attach + up its output edges based on the policies of the vertices + upstream of localBaseVertex, perhaps propagating the + pipeline split forwards */ + GetParent()->RegisterVertexSplit(newVertex, localBaseVertex, + upstreamSplitIndex); + + splitList->Add(newVertex); + } + else + { + /* some other upstream vertex already triggered the creation + of this vertex */ + newVertex = splitList[upstreamSplitIndex]; + } + + DrAssert(newVertex != DrNull); + + if (upstreamVertex->GetOutputs()->GetNumberOfEdges() <= outputPortOfSplitBase) + { + upstreamVertex->GetOutputs()->GrowNumberOfEdges(outputPortOfSplitBase+1); + } + + int nInputs = newVertex->GetInputs()->GetNumberOfEdges(); + DrAssert(localBasePort < nInputs); + upstreamVertex->ConnectOutput(outputPortOfSplitBase, + newVertex, localBasePort, + originalType); + + int 
i; + for (i=0; iGetInputs()->GetEdge(i); + DrVertexPtr source = e.m_remoteVertex; + int outs = source->GetOutputs()->GetNumberOfEdges(); + source->GetOutputs()->GrowNumberOfEdges(outs + 1); + source->ConnectOutput(outs, newVertex, i, originalType); + } + + /* All the input edges are connected, so this vertex is ready to run */ + newVertex->InitializeForGraphExecution(); + newVertex->KickStateMachine(); + + /* remove it from the split vector so we won't check it again in + NotifyUpstreamLastVertexCompleted below */ + splitList[upstreamSplitIndex] = DrNull; +} + +void DrSemiPipelineSplitManager:: + NotifyUpstreamVertexRemoval(DrVertexPtr upstreamVertex, + int outputPortOfRemovedVertex) +{ + /* disconnect ALL the inputs to this vertex */ + DrVertexRef localBaseVertex = upstreamVertex->RemoteOutputVertex(outputPortOfRemovedVertex); + int nInputs = localBaseVertex->GetInputs()->GetNumberOfEdges(); + + int i; + for (i=0; iGetInputs()->GetEdge(i); + DrVertexPtr source = e.m_remoteVertex; + source->DisconnectOutput(e.m_remotePort, true); + + if (source != upstreamVertex) + { + source->GetOutputs()->Compact(DrNull); + } + } + + /* shrink the local edge list to get rid of the empty slot we just + left. The upstream vertex will be dealt with by its own stage + manager */ + localBaseVertex->GetInputs()->Compact(localBaseVertex); + + DrAssert(localBaseVertex->GetInputs()->GetNumberOfEdges() == 0); + + GetParent()->UnRegisterVertex(localBaseVertex); + + DrAssert(localBaseVertex->GetOutputs()->GetNumberOfEdges() == 0); + + localBaseVertex->RemoveFromGraphExecution(); +} + +void DrSemiPipelineSplitManager:: + NotifyUpstreamLastVertexCompleted(DrManagerBasePtr upstreamStage) +{ + bool removed = m_stageSet->Remove(upstreamStage); + DrAssert(removed == 1); + + if (m_stageSet->GetSize() == 0) + { + /* all our upstream stages have completed. 
Now go through and + dispatch any split vertices that were still waiting for an + edge */ + DrVertexVListMap::DrEnumerator e = m_splitMap->GetDrEnumerator(); + while (e.MoveNext()) + { + DrVertexListRef list = e.GetValue(); + int i; + for (i=0; iSize(); ++i) + { + DrVertexPtr vertex = list[i]; + if (vertex != DrNull) + { + int nInputs = vertex->GetInputs()->GetNumberOfEdges(); + vertex->GetInputs()->Compact(vertex); + DrAssert(vertex->GetInputs()->GetNumberOfEdges() < nInputs); + vertex->InitializeForGraphExecution(); + vertex->KickStateMachine(); + list[i] = DrNull; + } + } + } + } +} diff --git a/GraphManager/stagemanager/DrPipelineSplitManager.h b/GraphManager/stagemanager/DrPipelineSplitManager.h new file mode 100644 index 0000000..d39cb70 --- /dev/null +++ b/GraphManager/stagemanager/DrPipelineSplitManager.h @@ -0,0 +1,84 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +DRCLASS(DrPipelineSplitManager) : public DrConnectionManager +{ +public: + DrPipelineSplitManager(); + + virtual void AddUpstreamStage(DrManagerBasePtr upstreamStage) DROVERRIDE; + virtual void NotifyUpstreamSplit(DrVertexPtr upstreamVertex, + DrVertexPtr baseNewVertexSplitFrom, + int outputPortOfSplitBase, + int upstreamSplitIndex) DROVERRIDE; + virtual void NotifyUpstreamVertexRemoval(DrVertexPtr upstreamVertex, + int outputPortOfRemovedVertex) DROVERRIDE; + virtual void NotifyUpstreamLastVertexCompleted(DrManagerBasePtr upstreamStage) DROVERRIDE; + +private: + DrStageSetRef m_stageSet; + DrVertexVListMapRef m_splitMap; +}; +DRREF(DrPipelineSplitManager); + + +/* + semipipelinesplitter.h + + A connection manager which looks almost like a pipelinesplitter. + However, the vertices may have inputs coming from other stages as + well. These inputs are replicated instead of creating additional + copies. + + I.e. + + (A >= B) || (C >= B) || (A >= C) + + The (A >= C) edge is required so that C does not yet execute when B is being rewritten. + With a semipipelinesplitter on the edge (A >= B) the following occurs: + + When A is expanded to (A,A,A) the graph becomes: + + (A,A,A) >= (B,B,B) || C >= (B,B,B) || (A,A,A) >= C + + The simple pipelinesplitter does not handle the C >= (B,B,B) connection. 
+*/ + +DRCLASS(DrSemiPipelineSplitManager) : public DrConnectionManager +{ +public: + DrSemiPipelineSplitManager(); + + virtual void AddUpstreamStage(DrManagerBasePtr upstreamStage) DROVERRIDE; + virtual void NotifyUpstreamSplit(DrVertexPtr upstreamVertex, + DrVertexPtr baseNewVertexSplitFrom, + int outputPortOfSplitBase, + int upstreamSplitIndex) DROVERRIDE; + virtual void NotifyUpstreamVertexRemoval(DrVertexPtr upstreamVertex, + int outputPortOfRemovedVertex) DROVERRIDE; + virtual void NotifyUpstreamLastVertexCompleted(DrManagerBasePtr upstreamStage) DROVERRIDE; + +private: + DrStageSetRef m_stageSet; + DrVertexVListMapRef m_splitMap; +}; +DRREF(DrSemiPipelineSplitManager); \ No newline at end of file diff --git a/GraphManager/stagemanager/DrStageHeaders.h b/GraphManager/stagemanager/DrStageHeaders.h new file mode 100644 index 0000000..8e3ab75 --- /dev/null +++ b/GraphManager/stagemanager/DrStageHeaders.h @@ -0,0 +1,32 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#include + +#include +#include + +#include +#include +#include +#include +#include diff --git a/GraphManager/stagemanager/DrStageStatistics.cpp b/GraphManager/stagemanager/DrStageStatistics.cpp new file mode 100644 index 0000000..c096622 --- /dev/null +++ b/GraphManager/stagemanager/DrStageStatistics.cpp @@ -0,0 +1,652 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include + +static const double s_conservativeOutlierFraction = 0.5; +static const double s_outlierFraction = 0.20; +static const double s_outlierThresholdInSigmas = 3.0; +static const int s_numberOfTrials = 10; +static const double s_conditionThreshold = 10000.0; + +static const int s_firstEstimationPercentage = 50; +static const int s_reEstimationPercentage = 5; + +void DrStageStatistics::Measurement::Initialize(DrResourcePtr machine, + DrVertexExecutionStatisticsPtr statistics) +{ + m_machine = machine; + m_elapsed = (double) (statistics->m_completionTime - statistics->m_runningTime); + m_dataSize = (double) (statistics->m_totalInputData->m_dataRead); + m_deviation = 0.0; +} + +DrStageStatistics::DrStageStatistics() +{ + m_numberOfStarted = 0; + m_measurement = DrNew MeasurementList(); + + m_sampleSize = 0; + m_nextReEstimationPercentage = s_firstEstimationPercentage; + + m_startup = 0.0; + m_dataMultiplier = 0.0; + m_stdDeviation = 0.0; + m_relativeStdDev = 1.0; + m_numberOfOutliers = 0; + + m_gotEstimate = false; + m_nonParametricOutlierEstimate = DrTimeInterval_Infinite; + m_reportedFinalStatistics = false; + m_dumpedRawStatisticsData = false; + + m_name = "(No name)"; +} + +void DrStageStatistics::SetName(DrString name) +{ + m_name = name; +} + +DrString DrStageStatistics::GetName() +{ + return m_name; +} + +void DrStageStatistics::ComputeNextEstimationThreshold() +{ + int nextReEstimation; + + do { + nextReEstimation = (m_sampleSize * m_nextReEstimationPercentage) / 100; + if (nextReEstimation < 2) + { + nextReEstimation = 2; + } + if (nextReEstimation < m_measurement->Size()+1) + { + /* note this can go over 100% because there can be more + executions than m_sampleSize due to re-executions of + the same vertex */ + m_nextReEstimationPercentage += s_reEstimationPercentage; + } + } while (nextReEstimation < m_measurement->Size()+1); +} + +DrTimeInterval DrStageStatistics::GetOutlierThreshold(DrGraphParametersPtr params) +{ + if (m_sampleSize <= 
params->m_duplicateEverythingThreshold) + { + DrTimeInterval defaultThreshold = params->m_defaultOutlierThreshold; + if (m_nonParametricOutlierEstimate < defaultThreshold) + { + return m_nonParametricOutlierEstimate; + } + else + { + return defaultThreshold; + } + } + else + { + return m_nonParametricOutlierEstimate; + } +} + +void DrStageStatistics::SetSampleSize(int sampleSize) +{ + m_sampleSize = sampleSize; + + /* reset the member (not a shadowing local) and recompute the threshold for the next re-estimation */ + m_nextReEstimationPercentage = s_firstEstimationPercentage; + ComputeNextEstimationThreshold(); + + if (m_nextReEstimationPercentage == s_firstEstimationPercentage) + { + /* we don't have enough datapoints for a good estimate yet */ + m_gotEstimate = false; + m_nonParametricOutlierEstimate = DrTimeInterval_Infinite; + } +} + +void DrStageStatistics::IncrementSampleSize() +{ + SetSampleSize(m_sampleSize+1); +} + +void DrStageStatistics::DecrementSampleSize() +{ + DrAssert(m_sampleSize > 0); + SetSampleSize(m_sampleSize-1); +} + +void DrStageStatistics::IncrementStartedCount() +{ + ++m_numberOfStarted; +} + +void DrStageStatistics::DecrementStartedCount() +{ + DrAssert(m_numberOfStarted > 0); + --m_numberOfStarted; +} + +void DrStageStatistics::AddMeasurement(DrGraphParametersPtr params, DrResourcePtr machine, + DrVertexExecutionStatisticsPtr statistics) +{ + MeasurementRef m = DrNew Measurement(); + m->Initialize(machine, statistics); + + m_measurement->Add(m); + + int nextReEstimation = (m_sampleSize * m_nextReEstimationPercentage) / 100; + if (nextReEstimation < 2) + { + nextReEstimation = 2; + } + + if (m_measurement->Size() == nextReEstimation) + { + DrLogI("Re-estimating stage statistics Stage %s", m_name.GetChars()); + + ReEstimate(params); + + DrTimeInterval tiStartup = (DrTimeInterval) m_startup; + DrTimeInterval tiMultiplier = (DrTimeInterval) m_dataMultiplier; + tiMultiplier = tiMultiplier / (1024*1024); /* NOTE(review): logged below as "/MB", but this divides rather than multiplies by 1MB - confirm units of m_dataMultiplier */ + DrTimeInterval tiStdDev = (DrTimeInterval) m_stdDeviation; + + DrLogI("Got new 
stage model estimate Stage=%s startup=%lf multiplier=%lf/MB std.dev=%lf " + "relative std.dev=%lf number of outliers=%d", + m_name.GetChars(), + (double) tiStartup / (double) DrTimeInterval_Second, + (double) tiMultiplier / (double) DrTimeInterval_Second, + (double) tiStdDev / (double) DrTimeInterval_Second, + m_relativeStdDev, m_numberOfOutliers); + + ComputeNextEstimationThreshold(); + } + else + { + DrAssert(m_measurement->Size() < nextReEstimation); + } +} + +DRCLASS(DrSignedDeviationComparer) : public DrComparer +{ +public: + virtual int Compare(DrStageStatistics::MeasurementRef aM, DrStageStatistics::MeasurementRef bM) DROVERRIDE + { + double a = aM->m_deviation; + double b = bM->m_deviation; + + return ((a < b) ? (-1) : ((a > b) ? (1) : (0))); + } +}; +DRREF(DrSignedDeviationComparer); + +void DrStageStatistics::SortSignedDeviations() +{ + DrSignedDeviationComparerRef comparer = DrNew DrSignedDeviationComparer(); + m_measurement->Sort(comparer); +} + +#include + +DRCLASS(DrUnSignedDeviationComparer) : public DrComparer +{ +public: + virtual int Compare(DrStageStatistics::MeasurementRef aM, DrStageStatistics::MeasurementRef bM) DROVERRIDE + { + double a = fabs(aM->m_deviation); + double b = fabs(bM->m_deviation); + + return ((a < b) ? (-1) : ((a > b) ? (1) : (0))); + } +}; +DRREF(DrUnSignedDeviationComparer); + +void DrStageStatistics::SortUnsignedDeviations() +{ + DrUnSignedDeviationComparerRef comparer = DrNew DrUnSignedDeviationComparer(); + m_measurement->Sort(comparer); +} + +DRCLASS(DrElapsedComparer) : public DrComparer +{ +public: + virtual int Compare(DrStageStatistics::MeasurementRef aM, DrStageStatistics::MeasurementRef bM) DROVERRIDE + { + double a = aM->m_elapsed; + double b = bM->m_elapsed; + + return ((a < b) ? (-1) : ((a > b) ? 
(1) : (0))); + } +}; +DRREF(DrElapsedComparer); + +void DrStageStatistics::SortElapsed() +{ + DrElapsedComparerRef comparer = DrNew DrElapsedComparer(); + m_measurement->Sort(comparer); +} + +void DrStageStatistics::ComputeDeviations(double startup, double dataMultiplier, double stdDeviation) +{ + double outlierThreshold = stdDeviation * s_outlierThresholdInSigmas; + + dataMultiplier *= m_dataSizeScaler; + + m_numberOfOutliers = 0; + int i; + for (i=0; iSize(); ++i) + { + double expectedTime = startup + m_measurement[i]->m_dataSize * dataMultiplier; + double deviation = m_measurement[i]->m_elapsed - expectedTime; + + m_measurement[i]->m_deviation = deviation; + + if (deviation > outlierThreshold) + { + ++m_numberOfOutliers; + } + } +} + +double DrStageStatistics::StdDevFromMeasurementPrefix(int prefixLength) +{ + DrAssert(prefixLength <= m_measurement->Size()); + + double totalSquared = 0.0; + int i; + for (i=0; im_deviation * m_measurement[i]->m_deviation; + } + + double meanSquared = totalSquared / (double) prefixLength; + + return sqrt(meanSquared); +} + +void DrStageStatistics::ReflectDeviationPrefix(int prefixLength) +{ + DrAssert(prefixLength <= m_measurement->Size()); + + int suffixLength = m_measurement->Size() - prefixLength; + /* the prefix must be at least half the total number of measurements */ + DrAssert(suffixLength <= prefixLength); + + int i; + for (i=0; iSize()-1 - i]->m_deviation = m_measurement[i]->m_deviation; + } +} + +void DrStageStatistics::LinearRegression(MeasurementListRef measurementArray, + int prefixLength, double dataSizeScaler, + double& pStartup /* OUT */, double& pDataMultiplier /* OUT */) +{ + /* we'll estimate the best least-squares approximation for x + satisfying Ax = b where: + + A^T = ( 1 1 ... 1 ) + ( dataSize_0 dataSize_1 ... dataSize_prefixLength-1 ) + + b^T = ( elapsed_0 elapsed_1 ... 
elapsed_prefixLength-1 ) + + we do this via linear regression so \hat{x} = (A^T A)^-1 A^T b + where: + + \hat{x}^T = ( startup dataMultiplier ) + + (A^T A) = ( unitSum dataSizeSum ) + ( dataSizeSum dataSizeSquaredSum ) + + (A^T A)^-1 = 1/determinant ( dataSizeSquaredSum -dataSizeSum ) + ( -dataSizeSum unitSum ) + + (A^T B)^T = ( elapsedSum dataSizeElapsedProductSum ) + */ + + double unitSum = (double) prefixLength; + double dataSizeSum = 0.0; + double dataSizeSquaredSum = 0.0; + double elapsedSum = 0.0; + double dataSizeElapsedProductSum = 0.0; + + int i; + for (i=0; im_dataSize * dataSizeScaler; + + dataSizeSum += (scaledDataSize); + dataSizeSquaredSum += (scaledDataSize * scaledDataSize); + elapsedSum += (measurementArray[i]->m_elapsed); + dataSizeElapsedProductSum += (scaledDataSize * measurementArray[i]->m_elapsed); + } + + /* if the data sizes are almost equal then this calculation will + be ill-conditioned and we will just set the data size + multiplier to be zero and absorb everything into the startup + estimate. Let C = (A^T A), i.e. the matrix we need to + invert. Then the condition number of C is |C^-1| |C| where we + use the l_infinity norm for |.| so + + | (a b) | = max(|a|,|b|,|c|,|d|) + | (c d) | + + In this case since C is symmetric we denote it + + (a b) + (b d) + + and we are looking for + + |(a b)| |1/det(d -b)| + |(b d)| | (-b a)| + + = z^2/|det| where z = max(|a|,|b|,|d|) + + The system is therefore ill-conditioned if + + z^2 > conditionThreshold |det| + */ + + double absa = fabs(unitSum); + double absb = fabs(dataSizeSum); + double absd = fabs(dataSizeSquaredSum); + double z = (absa > absb) ? absa : absb; + z = (z > absd) ? 
z : absd; + + double determinant = unitSum*dataSizeSquaredSum - dataSizeSum*dataSizeSum; + + if (z*z > (s_conditionThreshold * fabs(determinant))) + { + /* this means the data sizes are close to equal, so we'll just + estimate everything as being startup time */ + pStartup = elapsedSum / unitSum; + pDataMultiplier = 0.0; + } + else + { + double unnormalizedStartup = dataSizeSquaredSum * elapsedSum + -dataSizeSum * dataSizeElapsedProductSum; + double unnormalizedDataSize = -dataSizeSum * elapsedSum + unitSum * dataSizeElapsedProductSum; + + pStartup = unnormalizedStartup / determinant; + pDataMultiplier = unnormalizedDataSize / determinant; + } +} + +void DrStageStatistics::RandomlySample(int robustEstimatePrefix) +{ + DrAssert(m_measurement->Size() > 1); + + double minStdDev = 0.0; + + int i; + for (i=0; iSize(); + int p2; + do { + p2 = rand() % m_measurement->Size(); + } while (p1 == p2); + + MeasurementListRef m = DrNew MeasurementList(); + + m->Add(m_measurement[p1]); + m->Add(m_measurement[p2]); + + double startup, dataMultiplier; + LinearRegression(m, 2, m_dataSizeScaler, startup, dataMultiplier); + + ComputeDeviations(startup, dataMultiplier, 0.0); + SortUnsignedDeviations(); + double trialStdDev = StdDevFromMeasurementPrefix(robustEstimatePrefix); + + if (i == 0 || trialStdDev < minStdDev) + { + minStdDev = trialStdDev; + m_startup = startup; + m_dataMultiplier = dataMultiplier; + } + } +} + +void DrStageStatistics::ComputeRelativeStandardDeviation() +{ + /* get the median elapsed time */ + int medianMeasurement = m_measurement->Size()/2; + double medianElapsed = m_measurement[medianMeasurement]->m_elapsed; + + /* now compute the std. 
deviation relative to the median elapsed
+       time: this is a proxy for how well we are estimating: the
+       smaller the better */
+    if (medianElapsed == 0.0)
+    {
+        /* just conceivably everything could be finishing so fast that
+           we don't witness any running time, in which case we'll just
+           punt on the relative error */
+        m_relativeStdDev = 1.0;
+    }
+    else
+    {
+        m_relativeStdDev = m_stdDeviation / medianElapsed;
+    }
+}
+
+/* Sets m_dataSizeScaler to 1/max(m_dataSize) over all measurements, or to
+   1.0 when every data size is zero. Reads m_measurement[0] unconditionally,
+   so the caller must guarantee at least one measurement (ReEstimate asserts
+   Size() > 1 before calling). */
+void DrStageStatistics::GetDataSizeScaler()
+{
+    /* for stability in the linear regression, it's good to have the
+       data sizes come out looking close to 1, so we'll find the max
+       data size and scale everything by it */
+    double maxDataSize = m_measurement[0]->m_dataSize;
+
+    int i;
+    for (i=1; i<m_measurement->Size(); ++i)
+    {
+        if (m_measurement[i]->m_dataSize > maxDataSize)
+        {
+            maxDataSize = m_measurement[i]->m_dataSize;
+        }
+    }
+
+    if (maxDataSize == 0.0)
+    {
+        m_dataSizeScaler = 1.0;
+    }
+    else
+    {
+        m_dataSizeScaler = 1.0 / maxDataSize;
+    }
+}
+
+/* Recomputes m_startup, m_dataMultiplier, m_stdDeviation and
+   m_relativeStdDev from the current measurements and sets m_gotEstimate.
+   When params is non-null the non-parametric outlier estimate is refreshed
+   as well. Requires at least two measurements. Note that the sort calls
+   reorder m_measurement as a side effect. */
+void DrStageStatistics::ReEstimate(DrGraphParametersPtr params)
+{
+    DrAssert(m_measurement->Size() > 1);
+
+    int conservativeOutlierCount = (int) ((double) m_measurement->Size() * s_conservativeOutlierFraction);
+    int robustEstimatePrefix = m_measurement->Size() - conservativeOutlierCount;
+
+    int outlierCount = (int) ((double) m_measurement->Size() * s_outlierFraction);
+    int reflectionPrefix = m_measurement->Size() - outlierCount;
+
+    GetDataSizeScaler();
+
+    /* first try a bunch of candidate pairs of points to get a good
+       robust rough estimate of the startup and multiplier */
+    RandomlySample(robustEstimatePrefix);
+
+    /* get the deviations based on the best rough estimates we tried
+       during random sampling */
+    ComputeDeviations(m_startup, m_dataMultiplier, 0.0);
+
+    /* this replaces m_startup and m_dataMultiplier with a
+       regression based on the prefix of measurements with smallest
+       deviation */
+    SortUnsignedDeviations();
+    double startup, dataMultiplier;
+    LinearRegression(m_measurement, robustEstimatePrefix,
+                     m_dataSizeScaler, startup, dataMultiplier);
+    m_startup = startup;
+    m_dataMultiplier = dataMultiplier;
+
+    /* now we have the best estimate we're going to get of the startup
+       and multiplier, let's try to get a decent estimate of the
+       noise. */
+
+    /* First recompute the deviations again based on our good
+       parameter estimates */
+    ComputeDeviations(m_startup, m_dataMultiplier, 0.0);
+
+    /* replace the largest deviations by reflection. The assumption
+       here is that we may have picked up some outlier measurements,
+       but that all outliers are biased to be slow not fast. Therefore
+       if we throw away the slowest ones we get a better estimate, but
+       then it would be biased. To remove the bias, we replace them by
+       reflection with the matching fastest ones, which we believe are
+       not outliers and therefore fair draws from the real Gaussian
+       distribution */
+    SortSignedDeviations();
+    ReflectDeviationPrefix(reflectionPrefix);
+
+    /* after the reflection, we use all the available deviations when
+       estimating the standard deviation */
+    m_stdDeviation = StdDevFromMeasurementPrefix(m_measurement->Size());
+
+    /* compute the deviations one last time since a side-effect of
+       this is that we count the number of outliers */
+    ComputeDeviations(m_startup, m_dataMultiplier, m_stdDeviation);
+
+    /* now let's see how good our estimate is. First sort by elapsed
+       time */
+    SortElapsed();
+
+    /* compute the std. deviation relative to the median elapsed time:
+       this is a proxy for how well we are estimating: the smaller the
+       better */
+    ComputeRelativeStandardDeviation();
+
+    /* rescale the data multiplier to its true value */
+    m_dataMultiplier *= m_dataSizeScaler;
+
+    if (params != DrNull)
+    {
+        int thresholdIndex = (int) ((double) m_numberOfStarted * params->m_nonParametricThresholdFraction);
+
+        if (thresholdIndex < m_measurement->Size())
+        {
+            DrTimeInterval minThreshold = params->m_minOutlierThreshold;
+
+            m_nonParametricOutlierEstimate =
+                (DrTimeInterval) m_measurement[thresholdIndex]->m_elapsed;
+
+            DrLogI("Computed new non-parametric estimate %lf",
+                   (double) m_nonParametricOutlierEstimate / (double) DrTimeInterval_Second);
+
+            if (m_nonParametricOutlierEstimate < minThreshold)
+            {
+                m_nonParametricOutlierEstimate = minThreshold;
+
+                DrLogI("Reset non-parametric estimate to minimum %lf",
+                       (double) m_nonParametricOutlierEstimate / (double) DrTimeInterval_Second);
+            }
+        }
+        else
+        {
+            /* don't update the estimate. We may already have an estimate
+               from a previous call to ReEstimate though */
+            DrLogI("Not recomputing new non-parametric estimate. "
+                   "Number started=%d Fraction=%lf Completed=%d Estimate=%lf",
+                   m_numberOfStarted, params->m_nonParametricThresholdFraction,
+                   m_measurement->Size(),
+                   (double) m_nonParametricOutlierEstimate / (double) DrTimeInterval_Second);
+        }
+    }
+
+    m_gotEstimate = true;
+}
+
+/* Writes the model summary for this stage to f, at most once per object.
+   With fewer than two measurements only an "unavailable" line is written;
+   otherwise the model is re-estimated first (without touching the
+   non-parametric estimate, since DrNull is passed for the parameters). */
+void DrStageStatistics::ReportFinalStatistics(FILE* f)
+{
+    if (m_reportedFinalStatistics)
+    {
+        /* this class may be attached to more than one stage manager,
+           and each one will tell it to report but we only want to do
+           it once */
+        return;
+    }
+
+    m_reportedFinalStatistics = true;
+
+    if (m_measurement->Size() < 2)
+    {
+        fprintf(f, "Final statistics for stage %s unavailable: %s collected\n\n", m_name.GetChars(),
+                (m_measurement->Size() == 0) ? "no measurements" : "only 1 measurement");
+        return;
+    }
+
+    ReEstimate(DrNull);
+
+    fprintf(f, "Final statistics for stage %s: %d measurements, %d vertices\n",
+            m_name.GetChars(), m_measurement->Size(), m_sampleSize);
+
+    DrTimeInterval tiStartup = (DrTimeInterval) m_startup;
+    DrTimeInterval tiMultiplier = (DrTimeInterval) (m_dataMultiplier * (1024.0*1024.0));
+    DrTimeInterval tiStdDev = (DrTimeInterval) m_stdDeviation;
+
+    fprintf(f, "Model estimate:\n"
+            "startup=%lf multiplier=%lf/MB std.dev=%lf\n"
+            "relative std.dev=%lf number of outliers=%d\n\n",
+            (double) tiStartup / (double) DrTimeInterval_Second,
+            (double) tiMultiplier / (double) DrTimeInterval_Second,
+            (double) tiStdDev / (double) DrTimeInterval_Second,
+            m_relativeStdDev, m_numberOfOutliers);
+}
+
+/* Writes every raw measurement to f as one comma-separated record of
+   machine name, data size and elapsed time per measurement, at most once
+   per object. */
+void DrStageStatistics::DumpRawStatisticsData(FILE* f)
+{
+    if (m_dumpedRawStatisticsData)
+    {
+        /* this class may be attached to more than one stage manager,
+           and each one will tell it to report but we only want to do
+           it once */
+        return;
+    }
+
+    m_dumpedRawStatisticsData = true;
+
+    fprintf(f,
+            "Raw statistics for stage %s: %d measurements, %d vertices\n",
+            m_name.GetChars(), m_measurement->Size(), m_sampleSize);
+
+    int i;
+    for (i=0; i<m_measurement->Size(); ++i)
+    {
+        fprintf(f, "%s%s,%I64u,%I64u",
+                (i == 0) ? "" : ",",
+                m_measurement[i]->m_machine->GetName().GetChars(),
+                (UINT64) m_measurement[i]->m_dataSize,
+                (UINT64) m_measurement[i]->m_elapsed);
+    }
+    fprintf(f, "\n\n");
+}
diff --git a/GraphManager/stagemanager/DrStageStatistics.h b/GraphManager/stagemanager/DrStageStatistics.h
new file mode 100644
index 0000000..8cc3dab
--- /dev/null
+++ b/GraphManager/stagemanager/DrStageStatistics.h
@@ -0,0 +1,155 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+#pragma once
+
+/* Collects per-vertex execution measurements for one group of vertices
+   and fits the linear running-time model described on the m_startup /
+   m_dataMultiplier / m_stdDeviation members below. */
+DRBASECLASS(DrStageStatistics)
+{
+public:
+    DrStageStatistics();
+
+    /* the name is used as a text string to identify this class of
+       statistics */
+    void SetName(DrString name);
+    DrString GetName();
+
+    /* tell the class how many vertices there are in the sample
+       set. This number can be reset up or down even after
+       measurements have started coming in. If this number goes up,
+       then the number of measurements may fall below 50% of the
+       sample size in which case GotEstimate will start returning
+       false even if it returned true before */
+    void SetSampleSize(int sampleSize);
+    void IncrementSampleSize();
+    void DecrementSampleSize();
+    void IncrementStartedCount();
+    void DecrementStartedCount();
+
+    /* add a new measurement. The client can add more measurements
+       than sampleSize e.g. because of re-executions */
+    void AddMeasurement(DrGraphParametersPtr params, DrResourcePtr machine,
+                        DrVertexExecutionStatisticsPtr statistics);
+
+    /* this returns the class' current best estimate for an outlier
+       threshold. This may be DrTimeInterval_Infinite if insufficient
+       data has been gathered so far to figure it out. The graph
+       parameters are passed in to read thresholds out of. */
+    DrTimeInterval GetOutlierThreshold(DrGraphParametersPtr params);
+
+    /* each of these writes its report to f at most once per object */
+    void ReportFinalStatistics(FILE* f);
+    void DumpRawStatisticsData(FILE* f);
+
+    /* a single observation: where a vertex ran, how long it took and
+       how much data it processed. m_deviation is scratch space filled
+       in by ComputeDeviations */
+    DRINTERNALBASECLASS(Measurement)
+    {
+    public:
+        void Initialize(DrResourcePtr machine,
+                        DrVertexExecutionStatisticsPtr statistics);
+
+        DrResourceRef m_machine;
+        double m_elapsed;
+        double m_dataSize;
+        double m_deviation;
+    };
+    DRREF(Measurement);
+
+private:
+    typedef DrArrayList<MeasurementRef> MeasurementList;
+    DRAREF(MeasurementList,MeasurementRef);
+
+    void SortSignedDeviations();
+    void SortUnsignedDeviations();
+    void SortElapsed();
+    void GetDataSizeScaler();
+    void ComputeNextEstimationThreshold();
+    void ComputeDeviations(double startup, double dataMultiplier, double stdDeviation);
+    double StdDevFromMeasurementPrefix(int prefixLength);
+    void ReflectDeviationPrefix(int prefixLength);
+    void LinearRegression(MeasurementListRef measurementArray, int prefixLength, double dataSizeScaler,
+                          double& pStartup /* OUT */, double& pDataMultiplier /* OUT */);
+    void RandomlySample(int robustEstimatePrefix);
+    void ComputeRelativeStandardDeviation();
+    void ReEstimate(DrGraphParametersPtr params);
+
+    /* a string to print out to identify these statistics */
+    DrString m_name;
+
+    /* this is the total number of vertices in the group of vertices
+       that we are considering together for the purposes of gathering
+       statistics. The number of measurements we see may be greater
+       than m_sampleSize because vertices can be executed more than
+       once. */
+    int m_sampleSize;
+
+    /* this is a list of measurements we have seen */
+    MeasurementListRef m_measurement;
+
+    /* this is the number of vertices that have started (and possibly
+       completed). It can go down if a vertex fails. */
+    int m_numberOfStarted;
+
+    /* this is the next percentage of completed executions at which we
+       are going to re-estimate the model. This starts at 50% then
+       increases in 5% increments */
+    int m_nextReEstimationPercentage;
+
+    /* this is true if we have an estimate of the model: in practice
+       due to the behavior of m_nextReEstimationPercentage this is
+       true whenever the number of measurements received is over 50%
+       of the sample size */
+    bool m_gotEstimate;
+
+    /* this is the current non-parametric estimate for the outlier
+       threshold, or DrTimeInterval_Infinite if we haven't seen enough
+       data to say yet */
+    DrTimeInterval m_nonParametricOutlierEstimate;
+
+    /* this is our estimate of the model parameters for execution
+       time. The model says that a vertex with x bytes of input data
+       will take time
+
+       m_startup + m_dataMultiplier*x + \nu * m_stdDeviation
+
+       where \nu is iid unit Gaussian-distributed noise
+
+    */
+    double m_startup;
+    double m_dataMultiplier;
+    double m_stdDeviation;
+
+    /* for numerical stability we rescale the data sizes to be close to
+       1 during internal computations: this is the scaler we use */
+    double m_dataSizeScaler;
+
+    /* this is the standard deviation relative to the median elapsed
+       time, which is a general-purpose estimate for how well we are
+       modeling the data */
+    double m_relativeStdDev;
+
+    /* this is the number of measurements more than 3 sigmas away from
+       the model prediction */
+    int m_numberOfOutliers;
+
+    /* this class may be attached to more than one stage manager, and
+       each one will tell it to report but we only want to do it
+       once. These flags record whether we've dumped yet. */
+    bool m_reportedFinalStatistics;
+    bool m_dumpedRawStatisticsData;
+};
+DRREF(DrStageStatistics);
diff --git a/GraphManager/vertex/DrClique.cpp b/GraphManager/vertex/DrClique.cpp
new file mode 100644
index 0000000..c046f52
--- /dev/null
+++ b/GraphManager/vertex/DrClique.cpp
@@ -0,0 +1,180 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+/* NOTE(review): the header name was lost when this patch was extracted;
+   DrVertexHeaders.h is presumed from the directory name -- confirm against
+   the repository */
+#include <DrVertexHeaders.h>
+
+/* A start clique begins life holding a single vertex, with no version
+   being prepared */
+DrStartClique::DrStartClique(DrActiveVertexPtr initialMember)
+{
+    m_list = DrNew DrActiveVertexList();
+    m_list->Add(initialMember);
+
+    m_version = 0;
+    m_externalInputsRemaining = 0;
+}
+
+void DrStartClique::Discard()
+{
+    m_list = DrNull;
+
+    if (m_gang != DrNull)
+    {
+        m_gang->Discard();
+    }
+    m_gang = DrNull;
+}
+
+/* Returns the number of input edges of vertex that are not start-clique
+   edges, i.e. the inputs that come from outside the clique */
+int DrStartClique::CountExternalInputs(DrActiveVertexPtr vertex)
+{
+    int numberOfExternalInputs = 0;
+    int i;
+    for (i=0; i<vertex->GetInputs()->GetNumberOfEdges(); ++i)
+    {
+        DrEdge e = vertex->GetInputs()->GetEdge(i);
+        if (e.IsStartCliqueEdge() == false)
+        {
+            ++numberOfExternalInputs;
+        }
+    }
+
+    return numberOfExternalInputs;
+}
+
+void DrStartClique::SetGang(DrGangPtr gang)
+{
+    m_gang = gang;
+}
+
+DrGangPtr DrStartClique::GetGang()
+{
+    return m_gang;
+}
+
+DrActiveVertexListPtr DrStartClique::GetMembers()
+{
+    return m_list;
+}
+
+/* Moves every member of other into this clique, detaches other from its
+   gang and then merges the two gangs */
+void DrStartClique::AssimilateOther(DrStartCliquePtr other)
+{
+    DrActiveVertexListRef otherList = other->GetMembers();
+    int i;
+    for (i=0; i<otherList->Size(); ++i)
+    {
+        DrActiveVertexPtr otherMember = otherList[i];
+        DrAssert(otherMember->GetStartClique() == other);
+        otherMember->SetStartClique(this);
+        m_list->Add(otherMember);
+    }
+
+    DrGangPtr otherGang = other->GetGang();
+    otherGang->RemoveStartClique(other);
+
+    DrGang::Merge(otherGang, m_gang);
+
+}
+
+/* Merges two cliques; the one with more members absorbs the other.
+   Merging a clique with itself is a no-op */
+void DrStartClique::Merge(DrStartCliqueRef s1, DrStartCliqueRef s2)
+{
+    if (s1 == s2)
+    {
+        return;
+    }
+
+    if (s1->GetMembers()->Size() > s2->GetMembers()->Size())
+    {
+        s1->AssimilateOther(s2);
+    }
+    else
+    {
+        s2->AssimilateOther(s1);
+    }
+}
+
+/* Starts preparing version: recounts the external inputs the clique is
+   waiting for, then instantiates the version on every member. No other
+   version may be in preparation */
+void DrStartClique::InstantiateVersion(int version)
+{
+    DrAssert(m_version == 0);
+    m_version = version;
+    m_externalInputsRemaining = 0;
+
+    int i;
+    for (i=0; i<m_list->Size(); ++i)
+    {
+        m_externalInputsRemaining += CountExternalInputs(m_list[i]);
+    }
+
+    for (i=0; i<m_list->Size(); ++i)
+    {
+        m_list[i]->InstantiateVersion(version);
+    }
+}
+
+void DrStartClique::NotifyExternalInputsReady(int version, int numberOfInputs)
+{
+    DrAssert(m_version == version);
+    DrAssert(m_externalInputsRemaining >= numberOfInputs);
+
+    m_externalInputsRemaining -= numberOfInputs;
+}
+
+/* Only has an effect while a version is being prepared */
+void DrStartClique::GrowExternalInputs(int numberOfInputs)
+{
+    if (m_version != 0)
+    {
+        m_externalInputsRemaining += numberOfInputs;
+    }
+}
+
+/* Returns false if external inputs are still outstanding; otherwise
+   clears the pending version, starts a process for every member and
+   returns true */
+bool DrStartClique::StartVersionIfReady(int version)
+{
+    DrAssert(m_version == version);
+
+    if (m_externalInputsRemaining > 0)
+    {
+        /* there are still inputs needed before we can start */
+        return false;
+    }
+
+    m_version = 0;
+
+    /* ok all the inputs are ready: now make sure there's a process running for everyone
+       in the clique and move their records from pending to running */
+    int i;
+    for (i=0; i<m_list->Size(); ++i)
+    {
+        m_list[i]->StartProcess(version);
+    }
+
+    /* then make sure they start connecting to each other if there was already a process
+       running in the cohort */
+    for (i=0; i<m_list->Size(); ++i)
+    {
+        m_list[i]->CheckForProcessAlreadyStarted(version);
+    }
+
+    return true;
+}
+
+void DrStartClique::NotifyVersionRevoked(int version)
+{
+    if (m_version == version)
+    {
+        m_version = 0;
+    }
+}
diff --git a/GraphManager/vertex/DrClique.h b/GraphManager/vertex/DrClique.h
new file mode 100644
index 0000000..98199c6
--- /dev/null
+++ b/GraphManager/vertex/DrClique.h
@@ -0,0 +1,64 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+#pragma once
+
+DRDECLARECLASS(DrGraph);
+DRREF(DrGraph);
+
+DRDECLARECLASS(DrGang);
+DRREF(DrGang);
+
+/* A set of active vertices whose start is coordinated: a version is
+   started only once none of its external inputs remain outstanding
+   (see StartVersionIfReady) */
+DRBASECLASS(DrStartClique)
+{
+public:
+    DrStartClique(DrActiveVertexPtr initialMember);
+    void Discard();
+
+    void SetGang(DrGangPtr gang);
+    DrGangPtr GetGang();
+
+    int CountExternalInputs(DrActiveVertexPtr vertex);
+
+    DrActiveVertexListPtr GetMembers();
+
+    static void Merge(DrStartCliqueRef s1, DrStartCliqueRef s2);
+
+    void InstantiateVersion(int version);
+    void NotifyExternalInputsReady(int version, int numberOfInputs);
+    void GrowExternalInputs(int numberOfInputs);
+    bool StartVersionIfReady(int version);
+    void NotifyVersionRevoked(int version);
+
+private:
+    void AssimilateOther(DrStartCliquePtr other);
+
+    DrActiveVertexListRef m_list;
+    DrGangRef m_gang;
+
+    /* If we are currently preparing a new version to start then m_version is non-zero and
+       m_externalInputsRemaining gives the number of external inputs for that version that we
+       are waiting for */
+    int m_version;
+    int m_externalInputsRemaining;
+};
+
+typedef DrArrayList<DrStartCliqueRef> DrStartCliqueList;
+DRAREF(DrStartCliqueList,DrStartCliqueRef);
diff --git a/GraphManager/vertex/DrCohort.cpp b/GraphManager/vertex/DrCohort.cpp
new file mode 100644
index 0000000..90b9fc8
--- /dev/null
+++ b/GraphManager/vertex/DrCohort.cpp
@@ -0,0 +1,1001 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+/* NOTE(review): the header name was lost when this patch was extracted;
+   DrVertexHeaders.h is presumed from the directory name -- confirm against
+   the repository */
+#include <DrVertexHeaders.h>
+
+DrCohortProcess::DrCohortProcess(DrGraphPtr graph, DrCohortPtr parent, int version,
+                                 int numberOfVertices, DrTimeInterval timeout)
+    : DrSharedCritSec(graph)
+{
+    m_receivedProcess = false;
+    m_messagePump = graph->GetXCompute()->GetMessagePump();
+    m_parent = parent;
+    m_version = version;
+    m_numberOfVerticesLeftToComplete = numberOfVertices;
+    m_timeout = timeout;
+}
+
+void DrCohortProcess::DiscardParent()
+{
+    m_parent = DrNull;
+}
+
+int DrCohortProcess::GetVersion()
+{
+    return m_version;
+}
+
+DrLockBox<DrProcess> DrCohortProcess::GetProcess()
+{
+    return m_process;
+}
+
+/* Handles process state messages for this cohort. The first message is
+   always the synthetic one sent by DrCohortStarter: it either carries the
+   DrProcess that was scheduled, or carries no process and an error status
+   when scheduling was impossible. Later messages come from the process
+   itself. */
+void DrCohortProcess::ReceiveMessage(DrProcessInfoRef message)
+{
+    if (m_receivedProcess == false)
+    {
+        /* when the CohortStarter helper below actually scheduled the process, it
+           wasn't holding our lock. So the first thing it does is send us a
+           message containing our DrProcess. Because of the ordering guarantees
+           of the message queue, that is the first message we will receive. There
+           are two cases: either it sent the process to be scheduled, in which
+           case we will hear from a subsequent message whether it succeeded or not,
+           or there was an unsatisfiable hard constraint in which case we hear with
+           an error right now. */
+
+        m_receivedProcess = true;
+
+        if (message->m_process.IsNull())
+        {
+            /* there was an error scheduling. The Cohort Starter reports an
+               unsatisfiable hard constraint with state DPS_Failed and no
+               process, so we must not insist on DPS_NotStarted here */
+            DrAssert(message->m_state->m_state == DPS_Failed ||
+                     message->m_state->m_state == DPS_NotStarted);
+            DrAssert(message->m_state->m_status != DrNull);
+
+            DrString msg = DrError::ToShortText(message->m_state->m_status);
+            DrLogI("Cohort %s v.%d got scheduling error message on cohort startup %s",
+                   m_parent->GetDescription().GetChars(), m_version, msg.GetChars());
+
+            m_parent->GetGang()->CancelVersion(m_version, message->m_state->m_status);
+        }
+        else
+        {
+            /* this can only be the fake start message from the Cohort Starter */
+            DrAssert(message->m_state->m_state == DPS_NotStarted);
+            DrAssert(message->m_state->m_status == DrNull);
+            m_process = message->m_process;
+            DrLogI("Cohort %s v.%d got startup message",
+                   m_parent->GetDescription().GetChars(), m_version);
+        }
+
+        return;
+    }
+
+    if (m_process.IsNull())
+    {
+        /* we have already finished so do nothing */
+        return;
+    }
+
+    DrLogI("Cohort %s v.%d got message state %d",
+           m_parent->GetDescription().GetChars(), m_version,
+           message->m_state->m_state);
+
+    DrProcessStateRecordPtr state = message->m_state;
+    if (state->m_state > DPS_Running)
+    {
+        /* in the normal course of affairs, we should have already seen the process
+           start running, in which case the vertices have all initiated their own
+           message sends to the process which will also return informing them that
+           it has finished, at which point we will be notified cleanly via NotifyVertexCompletion.
+           The DrProcess machinery is supposed to have delayed the message we are now
+           receiving in order to give the vertex messages a chance to arrive. So if we
+           ever get here, something has gone wrong: either the process never started or
+           the vertex messages didn't get sent.
+
+           We are going to call Cancel below on the gang, which will result in all our vertices
+           calling NotifyVertexCompletion and eventually us cleaning up once they all report.
+        */
+
+        DrErrorRef error;
+
+        if (state->m_state == DPS_Completed)
+        {
+            DrString reason;
+            if (m_processHandle == DrNull)
+            {
+                if (state->m_status == DrNull)
+                {
+                    reason.SetF("Process completed with no error without starting");
+                    error = DrNew DrError(DrError_VertexError, "DrCohortProcess", reason);
+                }
+                else
+                {
+                    reason.SetF("Process completed with code %s without starting",
+                                DRERRORSTRING(state->m_status->m_code));
+                    error = DrNew DrError(state->m_status->m_code, "DrCohortProcess", reason);
+                    error->AddProvenance(state->m_status);
+                }
+            }
+            else
+            {
+                if (state->m_status == DrNull)
+                {
+                    reason.SetF("Process completed with no error but vertex message was never delivered");
+                    error = DrNew DrError(DrError_VertexError, "DrCohortProcess", reason);
+                }
+                else
+                {
+                    reason.SetF("Process completed with code %s but vertex message was never delivered",
+                                DRERRORSTRING(state->m_status->m_code));
+                    error = DrNew DrError(state->m_status->m_code, "DrCohortProcess", reason);
+                    error->AddProvenance(state->m_status);
+                }
+            }
+        }
+        else
+        {
+            if (state->m_status == DrNull)
+            {
+                DrLogW("Empty status delivered with process info state %u", state->m_state);
+                DrString reason;
+                reason.SetF("Empty status with failed process state %u", state->m_state);
+                error = DrNew DrError(DrError_Unexpected, "DrCohortProcess", reason);
+            }
+            else
+            {
+                DrString reason;
+                reason.SetF("%s process code %s",
+                            (state->m_state == DPS_Failed) ? "Failed" : "Zombie",
+                            DRERRORSTRING(state->m_status->m_code));
+                error = DrNew DrError(state->m_status->m_code, "DrCohortProcess", reason);
+                error->AddProvenance(state->m_status);
+            }
+        }
+
+        DrString eString = DrError::ToShortText(error);
+        DrLogI("Cohort %s v.%d cancelling gang %s %s",
+               m_parent->GetDescription().GetChars(), m_version, eString.GetChars(),
+               error->m_explanation.GetChars());
+
+        m_parent->GetGang()->CancelVersion(m_version, error);
+    }
+    else if (m_processHandle == DrNull && state->m_state == DPS_Running)
+    {
+        /* the process has started so tell everyone about it */
+        m_processHandle = state->m_process;
+        DrAssert(m_processHandle != DrNull);
+
+        DrLogI("Cohort %s v.%d starting process",
+               m_parent->GetDescription().GetChars(), m_version);
+
+        m_parent->NotifyProcessHasStarted(m_version);
+    }
+}
+
+bool DrCohortProcess::ProcessHasStarted()
+{
+    return (m_processHandle != DrNull);
+}
+
+/* Called once per member vertex. When the last vertex has reported, the
+   process is terminated (immediately if it has already exited, otherwise
+   via a delayed DPS_Failed state message), the parent cohort is notified
+   and this object is discarded. */
+void DrCohortProcess::NotifyVertexCompletion()
+{
+    DrLogI("Enter with m_numberOfVerticesLeftToComplete %d", m_numberOfVerticesLeftToComplete);
+    DrAssert(m_numberOfVerticesLeftToComplete > 0);
+    --m_numberOfVerticesLeftToComplete;
+
+    if (m_numberOfVerticesLeftToComplete > 0)
+    {
+        DrLogI("Still have %d vertices to complete", m_numberOfVerticesLeftToComplete);
+        return;
+    }
+
+    bool scheduleTermination = false;
+    if (m_process.IsEmpty() == false)
+    {
+        DrLockBoxKey<DrProcess> process(m_process);
+
+        if (process->GetInfo()->m_state->m_state <= DPS_Running)
+        {
+            /* the process hasn't exited yet even though all the vertices are done. It will
+               probably exit soon of its own accord, but in case it doesn't we'll schedule
+               a message that will terminate it after a while */
+            DrLogI("Process %s has not exited, scheduling termination", process->GetName().GetChars());
+            scheduleTermination = true;
+        }
+        else
+        {
+            /* The process completed and we're done with it, call Terminate to clean up */
+            DrLogI("Process %s completed, calling Terminate to clean up", process->GetName().GetChars());
+            DrAssert(process->GetInfo()->m_state->m_process != DrNull);
+            process->Terminate();
+        }
+    }
+
+    if (scheduleTermination)
+    {
+        /* The listener for the state message (DrProcess::ReceiveMessage(DrProcessState message))
+           will call Terminate to clean up in response to this message. */
+        DrPStateMessageRef message = DrNew DrPStateMessage(m_process, DPS_Failed);
+        m_messagePump->EnQueueDelayed(m_timeout, message);
+    }
+
+    DrLogI("Notifying cohort of vertex completion");
+    m_parent->NotifyProcessComplete(m_version);
+
+    DrLogI("Discarding cohort process");
+    Discard();
+}
+
+void DrCohortProcess::Discard()
+{
+    m_parent = DrNull;
+
+    if (m_process.IsEmpty() == false)
+    {
+        {
+            DrLockBoxKey<DrProcess> process(m_process);
+
+            process->CancelListener(this);
+
+        }
+
+        m_process.Set(DrNull);
+    }
+
+    /* make sure that if we were cancelled just as we were starting up, we
+       ignore the startup message containing the process when it arrives
+       instead of getting confused and setting m_process to be non-null again */
+    m_receivedProcess = true;
+}
+
+
+/* Everything DrCohortStarter needs in order to schedule one cohort process */
+DRBASECLASS(DrCohortStartInfo)
+{
+public:
+    DrCohortStartInfo(DrXComputeRef xcompute, DrCohortProcessRef cohort,
+                      DrString processName, DrString commandLine,
+                      DrProcessTemplateRef processTemplate,
+                      DrAffinityListRef affinityList)
+    {
+        m_xcompute = xcompute;
+        m_cohort = cohort;
+        m_processName = processName;
+        m_commandLine = commandLine;
+        m_processTemplate = processTemplate;
+        m_affinityList = affinityList;
+    }
+
+    DrXComputeRef m_xcompute;
+    DrCohortProcessRef m_cohort;
+    DrString m_processName;
+    DrString m_commandLine;
+    DrProcessTemplateRef m_processTemplate;
+    DrAffinityListRef m_affinityList;
+};
+DRREF(DrCohortStartInfo);
+
+typedef DrListener<DrCohortStartInfoRef> DrCohortStartListener;
+typedef DrMessage<DrCohortStartInfoRef> DrCohortStartMessage;
+DRREF(DrCohortStartMessage);
+
+/* Work-queue helper that merges affinities and schedules the process for a
+   cohort without holding the cohort's lock. It always sends the cohort
+   exactly one synthetic first message: a start notification carrying the
+   process, or a failure notification carrying an error and no process. */
+DRCLASS(DrCohortStarter) : public DrCritSec, public DrCohortStartListener
+{
+public:
+    /* implements the DrCohortStartListener interface */
+    virtual void ReceiveMessage(DrCohortStartInfoRef message)
+    {
+        DrAffinityRef hardConstraint = DrAffinityIntersector::IntersectHardConstraints(DrNull,
+                                                                                       message->m_affinityList);
+
+        DrAffinityListRef affinityList;
+        if (hardConstraint == DrNull)
+        {
+            DrAffinityMergerRef merger = DrNew DrAffinityMerger();
+            merger->AccumulateWeights(message->m_affinityList);
+            affinityList = merger->GetMergedAffinities(message->m_processTemplate->GetAffinityLevelThresholds());
+        }
+        else
+        {
+            if (hardConstraint->GetLocalityArray()->Size() == 0)
+            {
+                DrString reason = "Unsatisfiable hard constraint for starting vertex";
+                DrErrorRef error = DrNew DrError(DrError_HardConstraintCannotBeMet, "DrCohortStarter", reason);
+
+                DrProcessInfoRef failureNotification = DrNew DrProcessInfo();
+                failureNotification->m_state = DrNew DrProcessStateRecord();
+                failureNotification->m_state->m_state = DPS_Failed;
+                failureNotification->m_state->m_status = error;
+
+                DrProcessMessageRef failureMessage = DrNew DrProcessMessage(message->m_cohort,
+                                                                            failureNotification);
+
+                message->m_xcompute->GetMessagePump()->EnQueue(failureMessage);
+
+                return;
+            }
+
+            affinityList = DrNew DrAffinityList();
+            affinityList->Add(hardConstraint);
+        }
+
+        DrProcessRef process = DrNew DrProcess(message->m_xcompute, message->m_processName,
+                                               message->m_commandLine, message->m_processTemplate);
+
+        /* make a message to the cohort that will get delivered before the first message from its
+           process, including the process as payload. We rely on the ordering of the message queue
+           to ensure this message will arrive before anything from the process. We do it this way
+           so we can avoid acquiring the cohort's lock during scheduling, since the cohort shares
+           its lock with the global graph lock and we want to be able to start scheduling the
+           processes in parallel with graph state machine actions, particularly when a large stage
+           is starting up. */
+        DrProcessInfoRef startNotification = DrNew DrProcessInfo();
+        startNotification->m_process = process;
+        startNotification->m_state = DrNew DrProcessStateRecord();
+        startNotification->m_state->m_state = DPS_NotStarted;
+        DrProcessMessageRef startMessage = DrNew DrProcessMessage(message->m_cohort,
+                                                                  startNotification);
+        message->m_xcompute->GetMessagePump()->EnQueue(startMessage);
+
+        /* now actually schedule the process */
+        process->SetAffinityList(affinityList);
+        process->AddListener(message->m_cohort);
+
+        process->Schedule();
+    }
+};
+DRREF(DrCohortStarter);
+
+
+DrCohort::DrCohort(DrProcessTemplatePtr processTemplate, DrActiveVertexPtr initialMember)
+{
+    m_processTemplate = processTemplate;
+
+    m_list = DrNew DrActiveVertexList();
+    m_list->Add(initialMember);
+
+    m_versionList = DrNew VPList();
+
+    PrepareDescription();
+}
+
+void DrCohort::Discard()
+{
+    m_list = DrNull;
+
+    if (m_gang != DrNull)
+    {
+        m_gang->Discard();
+    }
+    m_gang = DrNull;
+
+    m_versionList = DrNull;
+}
+
+DrProcessTemplatePtr DrCohort::GetProcessTemplate()
+{
+    return m_processTemplate;
+}
+
+void DrCohort::SetGang(DrGangPtr gang)
+{
+    m_gang = gang;
+}
+
+DrGangPtr DrCohort::GetGang()
+{
+    return m_gang;
+}
+
+DrString DrCohort::GetDescription()
+{
+    return m_description;
+}
+
+DrActiveVertexListPtr DrCohort::GetMembers()
+{
+    return m_list;
+}
+
+/* Returns the process recorded for version, or DrNull if there is none */
+DrCohortProcessPtr DrCohort::GetProcessForVersion(int version)
+{
+    int i;
+    for (i=0; i<m_versionList->Size(); ++i)
+    {
+        if (m_versionList[i].m_version == version)
+        {
+            return m_versionList[i].m_process;
+        }
+    }
+    return DrNull;
+}
+
+DrCohortProcessPtr DrCohort::EnsureProcess(DrGraphPtr graph, int version) +{ + DrCohortProcessPtr process = GetProcessForVersion(version); + + if (process == DrNull) + { + m_gang->StartVersion(graph, version); + process = GetProcessForVersion(version); + } + + DrAssert(process != DrNull); + return process; +} + +void DrCohort::PrepareDescription() +{ + if (m_list->Size() > 1) + { + m_description = "vertices"; + } + else + { + m_description = "vertex"; + } + + int i; + for (i=0; iSize(); ++i) + { + m_description = m_description.AppendF(" %s", m_list[i]->GetDescription().GetChars()); + } +} + +void DrCohort::StartProcess(DrGraphPtr graph, int version) +{ + int i; + for (i=0; iSize(); ++i) + { + DrAssert(m_versionList[i].m_version != version); + } + + DrString processName; + processName.SetF("%s v.%d", m_description.GetChars(), version); + + DrString commandLine; + commandLine.SetF("%s --vertex --startfrompn %d", + m_processTemplate->GetCommandLineBase().GetChars(), m_list->Size()); + + for (i=0; iSize(); ++i) + { + commandLine = commandLine.AppendF(" %d %d", m_list[i]->GetId(), version); + } + + DrCohortProcessRef process = + DrNew DrCohortProcess(graph, this, version, m_list->Size(), + m_processTemplate->GetTimeOutBetweenProcessEndAndVertexNotification()); + + VersionProcess vp; + vp.m_version = version; + vp.m_process = process; + m_versionList->Add(vp); + + DrAffinityListRef affinity = DrNew DrAffinityList(); + for (i=0; iSize(); ++i) + { + m_list[i]->AddCurrentAffinitiesToList(version, affinity); + } + + graph->IncrementInFlightProcesses(); + + /* hand off the computation to merge the affinities (which can be slow) and the actual call to + start the process onto the work queue */ + DrCohortStartInfoRef info = DrNew DrCohortStartInfo(graph->GetXCompute(), process, + processName, commandLine, + m_processTemplate, affinity); + DrCohortStarterRef starter = DrNew DrCohortStarter(); + DrCohortStartMessageRef message = DrNew DrCohortStartMessage(starter, info); + 
graph->GetXCompute()->GetMessagePump()->EnQueue(message); +} + +void DrCohort::NotifyProcessHasStarted(int version) +{ + int i; + for (i=0; iSize(); ++i) + { + if (m_versionList[i].m_version == version) + { + break; + } + } + DrAssert(i < m_versionList->Size()); + DrLockBox process = m_versionList[i].m_process->GetProcess(); + + for (i=0; iSize(); ++i) + { + m_list[i]->ReactToStartedProcess(version, process); + } +} + +void DrCohort::NotifyProcessComplete(int version) +{ + int i; + for (i=0; iSize(); ++i) + { + if (m_versionList[i].m_version == version) + { + break; + } + } + DrAssert(i < m_versionList->Size()); + m_versionList->RemoveAt(i); + + GetGraph()->DecrementInFlightProcesses(); + + /* we don't need to tell the member vertices: the only reason we got here + was that they all called NotifyVertexCompletion on the DrCohortProcess + that is now calling us */ +} + +DrGraphPtr DrCohort::GetGraph() +{ + DrAssert(m_list != DrNull && m_list->Size() > 0); + return m_list[0]->GetStageManager()->GetGraph(); +} + +void DrCohort::CancelVertices(int version, DrErrorPtr error) +{ + DrLogI("Cancelling cohort vertices for version %d", version); + DrCohortProcessRef cohortProcess; + + int i; + for (i=0; iSize(); ++i) + { + if (m_versionList[i].m_version == version) + { + cohortProcess = m_versionList[i].m_process; + break; + } + } + + for (i=0; iSize(); ++i) + { + DrLogI("Cancelling cohort vertex %d.%d", m_list[i]->GetId(), version); + m_list[i]->CancelVersion(version, error, cohortProcess); + } + + /* if there was an entry in m_versionList above (so cohortProcess is non-NULL) then the vertices + will all have terminateed and told the cohortProcess so, and the cohortProcess will have + called us back on NotifyVersionComplete, and so it will have been removed from the list */ + for (i=0; iSize(); ++i) + { + DrAssert(m_versionList[i].m_version != version); + } +} + +void DrCohort::AssimilateOther(DrCohortPtr other) +{ + DrAssert(m_processTemplate == other->m_processTemplate); + 
+ DrActiveVertexListRef otherList = other->GetMembers(); + int i; + for (i=0; iSize(); ++i) + { + DrActiveVertexPtr otherMember = otherList[i]; + DrAssert(otherMember->GetCohort() == other); + otherMember->SetCohort(this); + m_list->Add(otherMember); + } + + PrepareDescription(); + + DrGangPtr otherGang = other->GetGang(); + otherGang->RemoveCohort(other); + + DrGang::Merge(otherGang, m_gang); +} + +/* Merge two cohorts: the one with strictly more members assimilates the other; on a tie c2 assimilates c1. */ void DrCohort::Merge(DrCohortRef c1, DrCohortRef c2) +{ + if (c1->GetMembers()->Size() > c2->GetMembers()->Size()) + { + c1->AssimilateOther(c2); + } + else + { + c2->AssimilateOther(c1); + } +} + + +/* Build a gang holding one cohort and one start clique, with no pending, running or completed version, no unready vertices, and version numbers handed out starting at 1. */ DrGang::DrGang(DrCohortPtr initialCohort, DrStartCliquePtr initialStartClique) +{ + m_cohort = DrNew DrCohortList(); + m_cohort->Add(initialCohort); + + m_clique = DrNew DrStartCliqueList(); + m_clique->Add(initialStartClique); + + m_pendingVersion = 0; + m_runningVersion = DrNew DrRunningGangList(); + m_completedVersion = 0; + m_unreadyVertexCount = 0; /* previously left unset here although Increment/DecrementUnreadyVertices and VerticesAreReady read it */ + m_nextVersion = 1; +} + +/* Drop the references to the cohort and start-clique lists. */ void DrGang::Discard() +{ + m_cohort = DrNull; + m_clique = DrNull; +} + +void DrGang::IncrementUnreadyVertices() +{ + ++m_unreadyVertexCount; +} + +/* Lower the unready count (asserting it was positive); once it reaches zero, call EnsurePendingVersion(0). */ void DrGang::DecrementUnreadyVertices() +{ + DrAssert(m_unreadyVertexCount > 0); + --m_unreadyVertexCount; + + if (m_unreadyVertexCount == 0) + { + EnsurePendingVersion(0); + } +} + +bool DrGang::VerticesAreReady() +{ + return (m_unreadyVertexCount == 0); +} + +DrCohortListPtr DrGang::GetCohorts() +{ + return m_cohort; +} + +DrStartCliqueListPtr DrGang::GetStartCliques() +{ + return m_clique; +} + +/* Remove cohort from this gang; asserts that it was a member. */ void DrGang::RemoveCohort(DrCohortPtr cohort) +{ + bool removed = m_cohort->Remove(cohort); + DrAssert(removed); +} + +/* Remove startClique from this gang; asserts that it was a member. */ void DrGang::RemoveStartClique(DrStartCliquePtr startClique) +{ + bool removed = m_clique->Remove(startClique); + DrAssert(removed); +} + +void DrGang::Merge(DrGangRef g1, DrGangRef g2) +{ + if (g1 == g2) + { + return; + } + + if (g1->GetCohorts()->Size() + g1->GetStartCliques()->Size() > + g2->GetCohorts()->Size() + g2->GetStartCliques()->Size()) + { + 
g1->AssimilateOther(g2); + } + else + { + g2->AssimilateOther(g1); + } +} + +void DrGang::AssimilateOther(DrGangPtr other) +{ + int i; + + DrCohortListRef otherCohort = other->GetCohorts(); + for (i=0; iSize(); ++i) + { + DrCohortRef c = otherCohort[i]; + c->SetGang(this); + m_cohort->Add(c); + } + + DrStartCliqueListRef otherClique = other->GetStartCliques(); + for (i=0; iSize(); ++i) + { + DrStartCliqueRef c = otherClique[i]; + c->SetGang(this); + m_clique->Add(c); + } +} + +void DrGang::StartVersion(DrGraphPtr graph, int version) +{ + DrAssert(version < m_nextVersion); + + DrAssert(version == m_pendingVersion); + m_pendingVersion = 0; + + DrRunningGang rv; + rv.m_version = version; + rv.m_verticesLeftToComplete = 0; + + /* count the number of vertices in the gang */ + int i; + for (i=0; iSize(); ++i) + { + rv.m_verticesLeftToComplete += m_cohort[i]->GetMembers()->Size(); + } + + m_runningVersion->Add(rv); + + /* if we had a gang-scheduling interface to XCompute we would be calling it here */ + for (i=0; iSize(); ++i) + { + m_cohort[i]->StartProcess(graph, version); + } +} + +void DrGang::CancelAllVersions(DrErrorPtr error) +{ + DrLogI("Canceling all versions for gang"); + if (m_pendingVersion > 0) + { + DrLogI("Canceling pending version %d", m_pendingVersion); + CancelVersion(m_pendingVersion, error); + } + + while (m_runningVersion->Size() > 0) + { + DrLogI("Canceling running version %d", m_runningVersion[0].m_version); + CancelVersion(m_runningVersion[0].m_version, error); + } + + /* make sure we didn't try to schedule another version. 
The + graph should be shutting down if this was called, which + should be blocking new pending versions from being made + since GetGraph()->IsRunning() should be false */ + DrAssert(m_pendingVersion == 0); +} + +void DrGang::CancelVersion(int version, DrErrorPtr error) +{ + DrLogI("Canceling version %d for gang", version); + DrAssert(version < m_nextVersion); + + int i; + for (i=0; iSize(); ++i) + { + DrLogI("Canceling m_cohort[%d] version %d", i, version); + m_cohort[i]->CancelVertices(version, error); + } + + for (i=0; iSize(); ++i) + { + m_clique[i]->NotifyVersionRevoked(version); + } + + if (m_pendingVersion == version) + { + m_pendingVersion = 0; + + /* check our invariant holds */ + for (i=0; iSize(); ++i) + { + int j; + for (j=0; j < m_cohort[i]->GetMembers()->Size(); ++j) + { +#ifdef _MANAGED + DrAssert(m_cohort[i]->GetMembers()[j]->HasPendingVersion() == false); +#else + DrAssert(m_cohort[i]->GetMembers()->Get(j)->HasPendingVersion() == false); +#endif + } + } + } + + for (i=0; iSize(); ++i) + { + if (m_runningVersion[i].m_version == version) + { + m_runningVersion->RemoveAt(i); + + /* check our invariant holds */ + int j; + for (j=0; jSize(); ++j) + { + int k; + for (k=0; k < m_cohort[j]->GetMembers()->Size(); ++k) + { +#ifdef _MANAGED + DrAssert(m_cohort[j]->GetMembers()[k]->HasRunningVersion(version) == false); +#else + DrAssert(m_cohort[j]->GetMembers()->Get(k)->HasRunningVersion(version) == false); +#endif + } + } + + break; + } + } + + if (m_completedVersion == version) + { + m_completedVersion = 0; + + /* check our invariant holds */ + for (i=0; iSize(); ++i) + { + int j; + for (j=0; j < m_cohort[i]->GetMembers()->Size(); ++j) + { +#ifdef _MANAGED + DrAssert((m_cohort[i]->GetMembers()[j])->HasCompletedVersion(version) == false); +#else + DrAssert((m_cohort[i]->GetMembers()->Get(j))->HasCompletedVersion(version) == false); +#endif + } + } + } + + EnsurePendingVersion(0); +} + +void DrGang::EnsurePendingVersion(int duplicateVersion) +{ + int i; + + if 
(GetGraph()->IsRunning() == false) + { + /* we are shutting down so don't do any more */ + return; + } + + if (VerticesAreReady() == false) + { + /* some vertex is blocked by its stage from starting, so wait until it + unblocks */ + return; + } + + if (m_completedVersion != 0) + { + /* there is already a consistent completed version held by every vertex in + the gang, so no need to start a new one */ + return; + } + + if (m_pendingVersion != 0) + { + /* there is already a version pending, i.e. waiting for inputs, so no need + to start a new one */ + return; + } + + if (m_runningVersion->Size() > 0) + { + /* there is a version running. Only start a new one if we've been told to */ + if (duplicateVersion == 0) + { + /* we aren't been told to duplicate, so don't */ + return; + } + + if (m_runningVersion->Size() > 2) + { + /* don't have more than three copies running at a given time: maybe this vertex + is just slow... */ + return; + } + + for (i=0; iSize(); ++i) + { + if (m_runningVersion[i].m_version >= duplicateVersion) + { + /* a duplicate has already been started for this version or a subsequent one, + so no need to do so again */ + return; + } + } + } + + /* OK, we should set up a new pending version in every vertex in the gang */ + int newVersion = m_nextVersion; + ++m_nextVersion; + m_pendingVersion = newVersion; + + for (i=0; iSize(); ++i) + { + m_clique[i]->InstantiateVersion(newVersion); + } + + /* and then start anyone that is ready */ + for (i=0; iSize(); ++i) + { + /* StartVersionIfReady may actually kick off a process if all the inputs + are ready, and that would reset m_pendingVersion to 0 and add the version + to m_runningVersion, so it's important to use the local variable newVersion + in this call instead of just passing in m_pendingVersion */ + m_clique[i]->StartVersionIfReady(newVersion); + } +} + +void DrGang::ReactToCompletedVertex(int version) +{ + bool completed = false; + bool foundMatch = false; + + int i; + for (i=0; iSize(); ++i) + { + if 
(m_runningVersion[i].m_version == version) + { + DrRunningGang rv = m_runningVersion[i]; + DrAssert(rv.m_verticesLeftToComplete > 0); + --rv.m_verticesLeftToComplete; + m_runningVersion[i] = rv; + + if (rv.m_verticesLeftToComplete == 0) + { + m_runningVersion->RemoveAt(i); + + DrAssert(m_completedVersion == 0); + m_completedVersion = version; + + completed = true; + } + + foundMatch = true; + break; + } + } + + if (!foundMatch) + { + DrLogE("Failed to find match for running version %d", version); + DrLogE("m_runningVersion->Size() = %d", m_runningVersion->Size()); + DrLogE("m_completedVersion %d", m_completedVersion); + DrLogE("m_pendingVersion %d", m_pendingVersion); + DrLogE("m_nextVersion %d", m_nextVersion); + DrLogE("m_unreadyVertexCount %d", m_unreadyVertexCount); + } + DrAssert(foundMatch); + + if (completed) + { + if (m_pendingVersion != 0) + { + DrString reason = "Pending cohort being cancelled because a duplicate gang completed"; + DrErrorRef error = DrNew DrError(DrError_CohortShutdown, "DrGang", reason); + + DrLogI("Canceling pending version %d", m_pendingVersion); + CancelVersion(m_pendingVersion, error); + } + + while (m_runningVersion->Size() > 0) + { + DrString reason = "Running cohort being cancelled because a duplicate gang completed"; + DrErrorRef error = DrNew DrError(DrError_CohortShutdown, "DrGang", reason); + + /* CancelVersion removes the relevant record from the running version list, + hence using a while loop here instead of a for loop */ + DrLogI("Canceling running version %d", m_runningVersion[0].m_version); + CancelVersion(m_runningVersion[0].m_version, error); + } + } +} + +DrGraphPtr DrGang::GetGraph() +{ + DrAssert(m_cohort != DrNull && m_cohort->Size() > 0); + return m_cohort[0]->GetGraph(); +} \ No newline at end of file diff --git a/GraphManager/vertex/DrCohort.h b/GraphManager/vertex/DrCohort.h new file mode 100644 index 0000000..6581eaa --- /dev/null +++ b/GraphManager/vertex/DrCohort.h @@ -0,0 +1,170 @@ +/* +Copyright (c) 
Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +DRCLASS(DrCohortProcess) : public DrSharedCritSec, public DrProcessListener +{ +public: + DrCohortProcess(DrGraphPtr graph, DrCohortPtr cohort, + int version, int numberOfVertices, DrTimeInterval timeout); + + void Discard(); + void DiscardParent(); + + int GetVersion(); + + DrLockBox GetProcess(); + bool ProcessHasStarted(); + + void NotifyVertexCompletion(); + + /* DrProcessListener implementation */ + virtual void ReceiveMessage(DrProcessInfoRef message); + +private: + bool m_receivedProcess; + DrMessagePumpRef m_messagePump; + DrCohortRef m_parent; + int m_version; + DrLockBox m_process; + int m_numberOfVerticesLeftToComplete; + DrProcessHandleRef m_processHandle; + DrTimeInterval m_timeout; +}; +DRREF(DrCohortProcess); + + +DRVALUECLASS(VersionProcess) +{ +public: + int m_version; + DrCohortProcessRef m_process; +}; +typedef DrArrayList VPList; +DRAREF(VPList,VersionProcess); + + +DRBASECLASS(DrCohort) +{ +public: + DrCohort(DrProcessTemplatePtr processTemplate, DrActiveVertexPtr initialMember); + void Discard(); + + DrProcessTemplatePtr GetProcessTemplate(); + + void SetGang(DrGangPtr gang); + DrGangPtr GetGang(); + + DrString GetDescription(); + + DrActiveVertexListPtr GetMembers(); + + DrCohortProcessPtr GetProcessForVersion(int version); + DrCohortProcessPtr EnsureProcess(DrGraphPtr graph, int 
version); + void StartProcess(DrGraphPtr graph, int version); + void NotifyProcessHasStarted(int version); + void CancelVertices(int version, DrErrorPtr originalReason); + void NotifyProcessComplete(int version); + + static void Merge(DrCohortRef c1, DrCohortRef c2); + + DrGraphPtr GetGraph(); + +private: + void AssimilateOther(DrCohortPtr other); + void PrepareDescription(); + + DrProcessTemplateRef m_processTemplate; + DrActiveVertexListRef m_list; + DrGangRef m_gang; + DrString m_description; + + VPListRef m_versionList; +}; + +typedef DrArrayList DrCohortList; +DRAREF(DrCohortList,DrCohortRef); + +DRINTERNALVALUECLASS(DrRunningGang) +{ +public: + int m_version; + int m_verticesLeftToComplete; +}; + +typedef DrArrayList DrRunningGangList; +DRAREF(DrRunningGangList, DrRunningGang); + + +DRBASECLASS(DrGang) +{ +public: + DrGang(DrCohortPtr initialCohort, DrStartCliquePtr initialStartClique); + void Discard(); + + void IncrementUnreadyVertices(); + void DecrementUnreadyVertices(); + bool VerticesAreReady(); + + void StartVersion(DrGraphPtr graph, int version); + void CancelVersion(int version, DrErrorPtr error); + void CancelAllVersions(DrErrorPtr error); + void EnsurePendingVersion(int duplicateVersion); + void ReactToCompletedVertex(int version); + + DrCohortListPtr GetCohorts(); + DrStartCliqueListPtr GetStartCliques(); + + void RemoveCohort(DrCohortPtr cohort); + void RemoveStartClique(DrStartCliquePtr runClique); + + static void Merge(DrGangRef g1, DrGangRef g2); + +private: + void AssimilateOther(DrGangPtr other); + DrGraphPtr GetGraph(); + + DrCohortListRef m_cohort; + DrStartCliqueListRef m_clique; + + /* If non-zero, this version number is complete and present in every vertex in the gang. + If this is non-zero then there are no pending versions or running versions of any + vertex in the gang, since we cancel any duplicate executions as soon as there is a + consistent completed version available. 
If this is zero then some vertices in the + gang may still have one or more completed versions, however there is no consistent + version that is complete in every vertex. */ + int m_completedVersion; + + /* If non-zero, this version number is pending in every vertex in the gang. If there is + already a pending version (i.e. if this is non-zero) then we will not create any more + pending version---in other words we will not 'queue' multiple duplicates to be + started. */ + int m_pendingVersion; + + /* This list contains an entry for every version that has been started for this gang, + for which not all vertices have completed (or failed). There can be more than one + entry in the list in the case that there are duplicates running. */ + DrRunningGangListRef m_runningVersion; + + /* This is the version number that will be handed out to the next pending version. */ + int m_nextVersion; + int m_unreadyVertexCount; +}; diff --git a/GraphManager/vertex/DrGraph.cpp b/GraphManager/vertex/DrGraph.cpp new file mode 100644 index 0000000..fdc9e41 --- /dev/null +++ b/GraphManager/vertex/DrGraph.cpp @@ -0,0 +1,463 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include + +DrFailureInfo::DrFailureInfo() +{ + m_numberOfFailures = 0; +} + +/* Build a graph in the DGS_NotStarted state with empty stage, partition-generator and failure collections and all counters at zero; also copies the intermediate compression mode from parameters into DrActiveVertexOutputGenerator. */ DrGraph::DrGraph(DrXComputePtr xcompute, DrGraphParametersPtr parameters) + : DrErrorNotifier(xcompute->GetMessagePump()) +{ + m_xcompute = xcompute; + m_parameters = parameters; + + m_dictionary = DrNew DrFailureDictionary(); + m_stageList = DrNew DrStageList(); + m_partitionGeneratorList = DrNew DrPartitionGeneratorList(); + + m_state = DGS_NotStarted; + m_activeVertexCount = 0; + m_activeVertexCompleteCount = 0; + m_inFlightProcessCount = 0; /* previously left unset although Increment/DecrementInFlightProcesses and the shutdown path read it */ + DrActiveVertexOutputGenerator::s_intermediateCompressionMode = parameters->m_intermediateCompressionMode; +} + +/* Shut down the XCompute object, discard every stage, and drop the references this graph holds. */ void DrGraph::Discard() +{ + m_xcompute->Shutdown(); + m_xcompute = DrNull; + m_dictionary = DrNull; + + int i; + for (i=0; iSize(); ++i) + { + m_stageList[i]->Discard(); + } + m_stageList = DrNull; + + m_partitionGeneratorList = DrNull; +} + +void DrGraph::AddStage(DrStageManagerPtr stage) +{ + m_stageList->Add(stage); +} + +DrStageListPtr DrGraph::GetStages() +{ + return m_stageList; +} + +void DrGraph::AddPartitionGenerator(DrIOutputPartitionGeneratorPtr partitionGenerator) +{ + m_partitionGeneratorList->Add(partitionGenerator); +} + +bool DrGraph::IsRunning() +{ + return (m_state == DGS_Running); +} + +void DrGraph::StartRunning() +{ + DrAssert(m_state == DGS_NotStarted); + + m_state = DGS_Running; + m_activeVertexCount = 0; + m_activeVertexCompleteCount = 0; + + // Send a delayed message to renew the temporary output stream lease + // before it expires + DrLeaseMessageRef leaseMessage = DrNew DrLeaseMessage(this, true); + m_xcompute->GetMessagePump()->EnQueueDelayed(23 * DrTimeInterval_Hour, leaseMessage); + + // Send a delayed message to check for duplicate vertices + DrDuplicateMessageRef duplicateMessage = DrNew DrDuplicateMessage(this, 0); + m_xcompute->GetMessagePump()->EnQueueDelayed(DrTimeInterval_Second, duplicateMessage); + + int i; + for (i=0; iSize(); ++i) + { + m_stageList[i]->InitializeForGraphExecution(); + } + + 
m_xcompute->IncrementTotalSteps(false); // Add a step for initialization + m_xcompute->IncrementProgress("initialization complete"); + + for (i=0; iSize(); ++i) + { + m_stageList[i]->KickStateMachine(); + } +} + +void DrGraph::TriggerShutdown(DrErrorRef status) +{ + HRESULT exitCode = 0; + + DrLogI("Triggering shutdown in state %d", m_state); + if (m_state == DGS_Running) + { + m_exitStatus = status; + + /* Write error, if any, to stderr so it will appear in HPC console */ + if (status != DrNull && status->m_code != 0) + { + exitCode = status->m_code; + + /* Write the proximate error */ + if (status->m_explanation.GetChars() != DrNull) + { + fprintf(stderr, status->m_explanation.GetChars()); + fprintf(stderr, "\n"); + } + else + { + fprintf(stderr, "Error (0x%08X): %s\n", status->m_code, DRERRORSTRING(status->m_code)); + } + + /* Look at the provenance and write any previous errors, for additional detail */ + if (status->m_errorProvenance != DrNull) + { + for (int i = 0; i < status->m_errorProvenance->Size(); ++i) + { +#ifdef _MANAGED + DrErrorRef previousError = status->m_errorProvenance[i]; +#else + DrErrorRef previousError = status->m_errorProvenance->Get(i); +#endif + if (previousError != DrNull) + { + if (previousError->m_explanation.GetChars() != DrNull) + { + fprintf(stderr, previousError->m_explanation.GetChars()); + fprintf(stderr, "\n"); + } + else + { + fprintf(stderr, "Previous error (0x%08X): %s\n", previousError->m_code, DRERRORSTRING(previousError->m_code)); + } + } + } + /* OMC: Old Managed code. 
+ array ^previousErrors = status->m_errorProvenance->ToArray(); + for (int i = 0; i < previousErrors->Length; i++) + { + if (previousErrors[i] != DrNull) + { + if (previousErrors[i]->m_explanation.GetChars() != DrNull) + { + fprintf(stderr, previousErrors[i]->m_explanation.GetChars()); + fprintf(stderr, "\n"); + } + else + { + fprintf(stderr, "Previous error (0x%08X): %s\n", previousErrors[i]->m_code, DRERRORSTRING(previousErrors[i]->m_code)); + } + } + } + */ + } + fflush(stderr); + } + + /* Send ourself a message to do the actual shutdown */ + DrShutdownMessageRef message = DrNew DrShutdownMessage(this, exitCode); + m_xcompute->GetMessagePump()->EnQueue(message); + } +} + +void DrGraph::ReceiveMessage(DrErrorRef /* unused abortError*/) +{ + if (m_state == DGS_Stopping) + { + DrLogI("Received process shutdown timeout"); + + FinalizeGraph(); + } +} + +void DrGraph::FinalizeGraph() +{ + DrAssert(m_state == DGS_Stopping); + + m_state = DGS_Stopped; + + if (m_exitStatus == DrNull || SUCCEEDED(m_exitStatus->m_code)) + { + HRESULT err = S_OK; + int i; + for (i=0; SUCCEEDED(err) && iSize(); ++i) + { + err = m_partitionGeneratorList[i]->FinalizeSuccessfulPartitions(); + if (!SUCCEEDED(err)) + { + DrString reason = "Failed to finalize outputs"; + m_exitStatus = DrNew DrError(err, "DrGraph", reason); + } + } + } + + if (m_parameters->m_topologyReporter != DrNull) + { + int i; + for (i=0; iSize(); ++i) + { + DrVertexListRef vList = m_stageList[i]->GetVertexVector(); + int j; + for (j=0; jSize(); ++j) + { + vList[j]->ReportFinalTopology(m_parameters->m_topologyReporter); + } + } + } + + DeliverNotification(m_exitStatus); +} + +void DrGraph::ReceiveMessage(DrLeaseExtender /* unused leaseMessage */) +{ + int i; + for (i=0; iSize(); ++i) + { + m_partitionGeneratorList[i]->ExtendLease(DrTimeInterval_Day); + } + DrLeaseMessageRef message = DrNew DrLeaseMessage(this, true); + m_xcompute->GetMessagePump()->EnQueueDelayed(23 * DrTimeInterval_Hour, message); +} + +void 
DrGraph::ReceiveMessage(DrDuplicateChecker /* unused checkDuplicate */) +{ + int i; + for (i=0; iSize(); ++i) + { + m_stageList[i]->CheckForDuplicates(); + } + + DrDuplicateMessageRef message = DrNew DrDuplicateMessage(this, 0); + m_xcompute->GetMessagePump()->EnQueueDelayed(DrTimeInterval_Second, message); +} + +void DrGraph::ReceiveMessage(DrExitStatus /* unused exitStatus */) +{ + DrLogI("Receiving shutdown message in state %d", m_state); + if (m_state == DGS_Running) + { + m_state = DGS_Stopping; + + if (m_inFlightProcessCount == 0) + { + DrLogI("No processes in flight, finalizing graph"); + FinalizeGraph(); + } + else + { + /* tell all the outstanding versions to cancel themselves */ + int i; + for (i=0; iSize(); ++i) + { + DrLogI("Cancelling all vertices in stage %d", i); + DrString abortReason = "Job is being canceled"; + DrErrorRef error = + DrNew DrError(DrError_CohortShutdown, "DrGraph", abortReason); + m_stageList[i]->CancelAllVertices(error); + } + + /* now we will wait until the final outstanding process calls DecrementInFlightProcesses + at which point FinalizeGraph will be called */ + + if (m_parameters->m_processAbortTimeOut < DrTimeInterval_Infinite) + { + /* set a timeout in case some processes don't abort; if this fires then ReceiveMessage + will get the error */ + DrString reason = "Process abort timed out"; + DrErrorRef error = + DrNew DrError(HRESULT_FROM_WIN32(ERROR_TIMEOUT), "DrGraph", reason); + DrErrorMessageRef message = DrNew DrErrorMessage(this, error); + m_xcompute->GetMessagePump()->EnQueueDelayed(m_parameters->m_processAbortTimeOut, message); + } + } + } +} + +void DrGraph::IncrementInFlightProcesses() +{ + DrAssert(m_state == DGS_Running); + ++m_inFlightProcessCount; +} + +void DrGraph::DecrementInFlightProcesses() +{ + DrAssert(m_inFlightProcessCount > 0); + --m_inFlightProcessCount; + + if (m_state == DGS_Stopping) + { + DrLogI("Stopping: waiting for %d processes", m_inFlightProcessCount); + + if (m_inFlightProcessCount == 0) + { 
+ FinalizeGraph(); + } + } +} + +void DrGraph::IncrementActiveVertexCount() +{ + ++m_activeVertexCount; + m_xcompute->IncrementTotalSteps(false); +} + +void DrGraph::DecrementActiveVertexCount() +{ + --m_activeVertexCount; + m_xcompute->DecrementTotalSteps(false); +} + +void DrGraph::NotifyActiveVertexComplete() +{ + DrAssert(m_activeVertexCompleteCount < m_activeVertexCount); + ++m_activeVertexCompleteCount; + + DrLogI("Got %d/%d complete active vertices", m_activeVertexCompleteCount, m_activeVertexCount); + + if (m_activeVertexCompleteCount == m_activeVertexCount) + { + DrLogI("Triggering shutdown"); + TriggerShutdown(DrNull); + } +} + +void DrGraph::NotifyActiveVertexRevoked() +{ + DrAssert(m_activeVertexCompleteCount > 0); + --m_activeVertexCompleteCount; + + DrLogI("Revoking: got %d/%d complete active vertices", m_activeVertexCompleteCount, m_activeVertexCount); + + if (m_activeVertexCompleteCount == m_activeVertexCount) + { + TriggerShutdown(DrNull); + } +} + +DrXComputePtr DrGraph::GetXCompute() +{ + return m_xcompute; +} + +DrGraphParametersPtr DrGraph::GetParameters() +{ + return m_parameters; +} + +int DrGraph::ReportFailure(DrActiveVertexPtr vertex, int version, + DrVertexProcessStatusPtr status, DrErrorPtr error) +{ + /* TODO much more sophisticated here */ + + if (status != DrNull) + { + DrInputChannelArrayRef inputs = status->GetInputChannels(); + int i; + for (i=0; iAllocated(); ++i) + { + DrChannelDescriptionPtr c = inputs[i]; + HRESULT err = c->GetChannelState(); + if ((err != S_OK) && + (err != DrError_EndOfStream) && + (err != DrError_ProcessingInterrupted)) + { + DrLogI("Reporting read error %s for vertex %d.%d input channel %d %s", + DRERRORSTRING(err), vertex->GetId(), version, i, c->GetChannelURI().GetChars()); + return i; + } + } + } + + if (error->m_code == DrError_CohortShutdown) + { + /* not our fault */ + return -1; + } + + DrFailureInfoRef info; + if (m_dictionary->TryGetValue(vertex, info) == false) + { + info = DrNew DrFailureInfo(); 
+ m_dictionary->Add(vertex, info); + } + ++(info->m_numberOfFailures); + if (info->m_numberOfFailures == m_parameters->m_maxActiveFailureCount) + { + DrLogI("Triggering graph abort because vertex %d failed %d times", vertex->GetId(), info->m_numberOfFailures); + + DrString reason; + DrMTagStringPtr vertexErrorString = DrNull; + + if (status != DrNull && + status->GetVertexMetaData() != DrNull && + (vertexErrorString = dynamic_cast(status->GetVertexMetaData()->LookUp(DrProp_ErrorString))) != DrNull && + vertexErrorString->GetValue().GetString() != DrNull) + { + reason.SetF("Graph abort because vertex failed %d times: vertex %d in stage %s\n\nVERTEX FAILURE DETAILS:\n%s", info->m_numberOfFailures, + vertex->GetId(), + vertex->GetStageManager()->GetStageName().GetChars(), + vertexErrorString->GetValue().GetChars()); + } + else + { + reason.SetF("Graph abort because vertex failed %d times: vertex %d in stage %s\n", info->m_numberOfFailures, + vertex->GetId(), + vertex->GetStageManager()->GetStageName().GetChars()); + } + + DrErrorRef newError = DrNew DrError(DrError_VertexError, "DrGraph", reason); + newError->AddProvenance(error); + TriggerShutdown(newError); + } + + return -1; +} + +void DrGraph::ReportStorageFailure(DrStorageVertexPtr vertex, DrErrorPtr originalError) +{ + DrFailureInfoRef info; + if (m_dictionary->TryGetValue(vertex, info) == false) + { + info = DrNew DrFailureInfo(); + m_dictionary->Add(vertex, info); + } + ++(info->m_numberOfFailures); + if (info->m_numberOfFailures == m_parameters->m_maxActiveFailureCount) + { + DrLogI("Triggering abort because input read failed %d times vertex %d", + info->m_numberOfFailures, vertex->GetId()); + + DrString reason; + reason.SetF("Graph abort because input read failed %d times: vertex %d in stage %s", + info->m_numberOfFailures, vertex->GetId(), vertex->GetStageManager()->GetStageName().GetChars()); + DrErrorRef error = DrNew DrError(DrError_InputUnavailable, "DrGraph", reason); + 
error->AddProvenance(originalError); + TriggerShutdown(error); + } +} diff --git a/GraphManager/vertex/DrGraph.h b/GraphManager/vertex/DrGraph.h new file mode 100644 index 0000000..97dfb27 --- /dev/null +++ b/GraphManager/vertex/DrGraph.h @@ -0,0 +1,125 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +DRBASECLASS(DrFailureInfo) +{ +public: + DrFailureInfo(); + + int m_numberOfFailures; +}; +DRREF(DrFailureInfo); + +typedef DrDictionary DrFailureDictionary; +DRREF(DrFailureDictionary); + +DRBASECLASS(DrGraphParameters) +{ +public: + DrTimeInterval m_processAbortTimeOut; + int m_maxActiveFailureCount; + + int m_duplicateEverythingThreshold; + DrTimeInterval m_defaultOutlierThreshold; + DrTimeInterval m_minOutlierThreshold; + double m_nonParametricThresholdFraction; + + int m_intermediateCompressionMode; + + DrProcessTemplateRef m_defaultProcessTemplate; + DrVertexTemplateRef m_defaultVertexTemplate; + + DrVertexTopologyReporterIRef m_topologyReporter; +}; +DRREF(DrGraphParameters); + +/* Duplicate check timer message */ +typedef int DrDuplicateChecker; + +typedef DrListener DrDuplicateListener; +DRIREF(DrDuplicateListener); + +typedef DrMessage DrDuplicateMessage; +DRREF(DrDuplicateMessage); + +DRENUM(DrGraphState) +{ + DGS_NotStarted, + DGS_Running, + DGS_Stopping, + DGS_Stopped +}; + +DRCLASS(DrGraph) : public DrErrorNotifier, public 
DrErrorListener, public DrLeaseListener, public DrDuplicateListener, public DrShutdownListener +{ +public: + DrGraph(DrXComputePtr xcompute, DrGraphParametersPtr parameters); + void Discard(); + + DrXComputePtr GetXCompute(); + DrGraphParametersPtr GetParameters(); + + void AddStage(DrStageManagerPtr stage); + DrStageListPtr GetStages(); + + void AddPartitionGenerator(DrIOutputPartitionGeneratorPtr partitionGenerator); + + bool IsRunning(); + void StartRunning(); + void TriggerShutdown(DrErrorRef status); + /* implements the DrErrorListener interface */ + virtual void ReceiveMessage(DrErrorRef abortError); + /* implements the DrLeaseListener interface */ + virtual void ReceiveMessage(DrLeaseExtender leaseMessage); + /* implements the DrDuplicateListener interface */ + virtual void ReceiveMessage(DrDuplicateChecker duplicateCheck); + /* implements the DrShutdownListener interface */ + virtual void ReceiveMessage(DrExitStatus exitStatus); + + void IncrementActiveVertexCount(); + void DecrementActiveVertexCount(); + void NotifyActiveVertexComplete(); + void NotifyActiveVertexRevoked(); + + void IncrementInFlightProcesses(); + void DecrementInFlightProcesses(); + + int ReportFailure(DrActiveVertexPtr vertex, int version, DrVertexProcessStatusPtr status, DrErrorPtr error); + void ReportStorageFailure(DrStorageVertexPtr vertex, DrErrorPtr error); + +private: + void FinalizeGraph(); + + DrGraphState m_state; + DrErrorRef m_exitStatus; + + DrXComputeRef m_xcompute; + DrGraphParametersRef m_parameters; + + DrStageListRef m_stageList; + DrPartitionGeneratorListRef m_partitionGeneratorList; + DrFailureDictionaryRef m_dictionary; + int m_activeVertexCount; + int m_activeVertexCompleteCount; + int m_inFlightProcessCount; +}; +DRREF(DrGraph); diff --git a/GraphManager/vertex/DrOutputGenerator.cpp b/GraphManager/vertex/DrOutputGenerator.cpp new file mode 100644 index 0000000..21b066a --- /dev/null +++ b/GraphManager/vertex/DrOutputGenerator.cpp @@ -0,0 +1,349 @@ +/* +Copyright 
(c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include + +bool DrEdge::IsStartCliqueEdge() +{ + return (m_type == DCT_Pipe || m_type == DCT_Fifo); +} + +bool DrEdge::IsGangEdge() +{ + return (m_type == DCT_Pipe || m_type == DCT_Fifo || m_type == DCT_FifoNonBlocking); +} + +bool DrEdge::operator==(DrEdgeR other) +{ + return + (m_remoteVertex == other.m_remoteVertex && + m_remotePort == other.m_remotePort && + m_type == other.m_type); +} + +DrEdgeHolder::DrEdgeHolder(bool edgeHolderIsInput) +{ + m_inputEdges = edgeHolderIsInput; + SetNumberOfEdges(0); +} + +void DrEdgeHolder::SetNumberOfEdges(int numberOfEdges) +{ + m_edge = DrNew DrEdgeList(); + int i; + for (i=0; iAdd(e); + } +} + +void DrEdgeHolder::GrowNumberOfEdges(int newNumberOfEdges) +{ + DrAssert(newNumberOfEdges >= m_edge->Size()); + + int newCount = newNumberOfEdges - m_edge->Size(); + int i; + for (i=0; iAdd(e); + } +} + +int DrEdgeHolder::GetNumberOfEdges() +{ + return m_edge->Size(); +} + +void DrEdgeHolder::SetEdge(int edgeIndex, DrEdge edge) +{ + if (edgeIndex < m_edge->Size()) + { + m_edge[edgeIndex] = edge; + } + else + { + DrAssert(edgeIndex == m_edge->Size()); + m_edge->Add(edge); + } +} + +void DrEdgeHolder::Compact(DrVertexPtr thisVertex) +{ + DrEdgeListRef newList = DrNew DrEdgeList(); + + int remoteCount = 0; + int i; + + for (i=0; iSize(); ++i) + { + if (m_edge[i].m_type != DCT_Tombstone) + 
{ + newList->Add(m_edge[i]); + + int remotePort = m_edge[i].m_remotePort; + + DrVertexPtr remoteVertex = m_edge[i].m_remoteVertex; + DrEdgeHolderRef remoteHolder; + if (m_inputEdges) + { + remoteHolder = remoteVertex->GetOutputs(); + } + else + { + remoteHolder = remoteVertex->GetInputs(); + } + + DrEdge remoteEdge = remoteHolder->GetEdge(remotePort); + if (remoteEdge.m_type != DCT_Tombstone) + { + DrAssert(remoteEdge.m_remotePort == i); + remoteEdge.m_remotePort = remoteCount; + remoteHolder->SetEdge(remotePort, remoteEdge); + } + + ++remoteCount; // Increment for each one added + } + } + + if (thisVertex && m_inputEdges) + { + thisVertex->CompactPendingVersion(this, remoteCount); + } + m_edge = newList; + +} + +DrEdge DrEdgeHolder::GetEdge(int edgeIndex) +{ + return m_edge[edgeIndex]; +} + + +void DrActiveVertexOutputGenerator::StoreOutputLengths(DrVertexProcessStatusPtr status, DrTimeInterval runningTime) +{ + DrOutputChannelArrayRef outputs = status->GetOutputChannels(); + + m_lengthArray = DrNew DrUINT64Array(outputs->Allocated()); + + int i; + for (i=0; iAllocated(); ++i) + { + m_lengthArray[i] = outputs[i]->GetChannelProcessedLength(); + } + + m_runningTime = runningTime; +} + +DrTimeInterval DrActiveVertexOutputGenerator::GetRunningTime() +{ + return m_runningTime; +} + +#ifndef _MANAGED +int DrActiveVertexOutputGenerator::s_intermediateCompressionMode = 0; +#endif + +void DrActiveVertexOutputGenerator::SetProcess(DrProcessHandlePtr process, + int vertexId, int version) +{ + m_vertexId = vertexId; + m_version = version; + /* There are failure cases where SetProcess is called with process == DrNull, + so check for that */ + if (process != DrNull) + { + /* Cache some state, because process gets closed in process termination, + before GetURIForRead is called by downstream vertices */ + + /* Base URI, used later in GetURIForRead */ + m_uriBase = process->GetFileURIBase(); + + /* Assigned node */ + m_assignedNode = process->GetAssignedNode(); + } +} + +int 
DrActiveVertexOutputGenerator::GetVersion() +{ + return m_version; +} + +DrAffinityRef DrActiveVertexOutputGenerator::GetOutputAffinity(int output) +{ + DrAffinityRef a = DrNew DrAffinity(); + if (m_lengthArray != DrNull) + { + a->SetWeight(m_lengthArray[output]); + } + a->AddLocality(m_assignedNode); + return a; +} + +DrString DrActiveVertexOutputGenerator::GetURIForWrite(DrEdgeHolderPtr outputEdges, + DrResourcePtr runningResource, + int id, int version, + int output, DrConnectorType type, + DrMetaDataRef metaData) +{ + DrString uri; + + DrEdge e; + + switch (type) + { + case DCT_File: + uri.SetF("file://%d_%d_%d.tmp?c=%d", id, output, version, DrActiveVertexOutputGenerator::s_intermediateCompressionMode); + break; + + case DCT_Output: + e = outputEdges->GetEdge(output); + uri = e.m_remoteVertex->GetURIForWrite(e.m_remotePort, id, version, output, runningResource, metaData); + break; + + case DCT_Pipe: + /* TODO implement pipes */ + DrLogA("Pipes not implemented"); + break; + + case DCT_Fifo: + uri.SetF("fifo://%u/%d_%d_%d", 32, id, output, version); + break; + + case DCT_FifoNonBlocking: + uri.SetF("fifo://%u/%d_%d_%d", (UINT32) -1, id, output, version); + break; + } + + return uri; +} + +DrString DrActiveVertexOutputGenerator::GetURIForRead(int output, DrConnectorType type, + DrResourcePtr /* unused runningResource */) +{ + DrString uri; + + switch (type) + { + case DCT_File: + if (m_uriBase.GetCharsLength() > 0) + { + uri.SetF("%s\\%d_%d_%d.tmp?c=%d", m_uriBase.GetChars(), m_vertexId, output, m_version, DrActiveVertexOutputGenerator::s_intermediateCompressionMode); + } + else + { + /* This should never happen - but just in case it does, let's assert so we can debug */ + DrLogA("Active vertex output generator was asked for a read URI when no base URI is available vertex %d.%d", m_vertexId, m_version); + } + break; + + case DCT_Output: + /* can't be reading from an edge that leads to an output vertex */ + DrLogA("Active vertex output generator was asked for a 
read URI on output edge type %d", output); + break; + + case DCT_Pipe: + /* TODO implement pipes */ + DrLogA("Pipes not implemented"); + break; + + case DCT_Fifo: + uri.SetF("fifo://%u/%d_%d_%d", 32, m_vertexId, output, m_version); + break; + + case DCT_FifoNonBlocking: + uri.SetF("fifo://%u/%d_%d_%d", (UINT32) -1, m_vertexId, output, m_version); + break; + } + + return uri; +} + +DrResourcePtr DrActiveVertexOutputGenerator::GetResource() +{ + return m_assignedNode; +} + + +DrStorageVertexOutputGenerator::DrStorageVertexOutputGenerator(int partitionIndex, + DrIInputPartitionReaderPtr reader) +{ + m_partitionIndex = partitionIndex; + m_reader = reader; +} + +DrResourcePtr DrStorageVertexOutputGenerator::GetResource() +{ + return DrNull; +} + +int DrStorageVertexOutputGenerator::GetVersion() +{ + return 0; +} + +DrAffinityRef DrStorageVertexOutputGenerator::GetOutputAffinity(int /* unused output */) +{ + return m_reader->GetAffinity(m_partitionIndex); +} + +DrString DrStorageVertexOutputGenerator::GetURIForRead(int /* unused output */, + DrConnectorType type, + DrResourcePtr runningResource) +{ + DrAssert(type == DCT_File); + return m_reader->GetURIForRead(m_partitionIndex, runningResource); +} + + +DrTeeVertexOutputGenerator::DrTeeVertexOutputGenerator(DrVertexOutputGeneratorPtr wrappedGenerator) +{ + m_wrappedGenerator = wrappedGenerator; +} + +DrVertexOutputGeneratorPtr DrTeeVertexOutputGenerator::GetWrappedGenerator() +{ + return m_wrappedGenerator; +} + +DrResourcePtr DrTeeVertexOutputGenerator::GetResource() +{ + return m_wrappedGenerator->GetResource(); +} + +int DrTeeVertexOutputGenerator::GetVersion() +{ + return m_wrappedGenerator->GetVersion(); +} + +DrAffinityRef DrTeeVertexOutputGenerator::GetOutputAffinity(int /* unused output */) +{ + return m_wrappedGenerator->GetOutputAffinity(0); +} + +DrString DrTeeVertexOutputGenerator::GetURIForRead(int /* unused output */, + DrConnectorType type, + DrResourcePtr runningResource) +{ + return 
m_wrappedGenerator->GetURIForRead(0, type, runningResource); +} diff --git a/GraphManager/vertex/DrOutputGenerator.h b/GraphManager/vertex/DrOutputGenerator.h new file mode 100644 index 0000000..b7ba10b --- /dev/null +++ b/GraphManager/vertex/DrOutputGenerator.h @@ -0,0 +1,157 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +DRPUBLICENUM(DrConnectorType) +{ + DCT_File, + DCT_Output, + DCT_Pipe, + DCT_Fifo, + DCT_FifoNonBlocking, + DCT_Tombstone +}; + +DRDECLARECLASS(DrVertex); +DRREF(DrVertex); + +DRDECLAREVALUECLASS(DrEdge); +DRRREF(DrEdge); + +DRVALUECLASS(DrEdge) +{ +public: + bool IsStartCliqueEdge(); + bool IsGangEdge(); + + bool operator==(DrEdgeR other); + + DrConnectorType m_type; + DrVertexRef m_remoteVertex; + int m_remotePort; +}; + +DRMAKEARRAYLIST(DrEdge); + +DRBASECLASS(DrEdgeHolder) +{ +public: + DrEdgeHolder(bool edgeHolderIsInput); + + void SetNumberOfEdges(int numberOfEdges); + void GrowNumberOfEdges(int newNumberOfEdges); + + int GetNumberOfEdges(); + + void Compact(DrVertexPtr thisVertex); + + void SetEdge(int edgeIndex, DrEdge edge); + DrEdge GetEdge(int edgeIndex); + +private: + DrEdgeListRef m_edge; + bool m_inputEdges; +}; +DRREF(DrEdgeHolder); + +DRBASECLASS(DrVertexOutputGenerator abstract) +{ +public: + virtual DrResourcePtr GetResource() DRABSTRACT; + virtual int GetVersion() DRABSTRACT; + virtual 
DrAffinityRef GetOutputAffinity(int output) DRABSTRACT; + virtual DrString GetURIForRead(int output, DrConnectorType type, DrResourcePtr runningResource) DRABSTRACT; +}; +DRREF(DrVertexOutputGenerator); + +typedef DrArray DrGeneratorArray; +DRAREF(DrGeneratorArray,DrVertexOutputGeneratorRef); + +DRCLASS(DrActiveVertexOutputGenerator) : public DrVertexOutputGenerator +{ +public: + void StoreOutputLengths(DrVertexProcessStatusPtr status, DrTimeInterval runningTime); + void SetProcess(DrProcessHandlePtr process, int vertexId, int version); + + virtual DrResourcePtr GetResource() DROVERRIDE; + virtual int GetVersion() DROVERRIDE; + virtual DrAffinityRef GetOutputAffinity(int output) DROVERRIDE; + virtual DrString GetURIForRead(int output, DrConnectorType type, DrResourcePtr runningResource) DROVERRIDE; + static DrString GetURIForWrite(DrEdgeHolderPtr outputEdges, DrResourcePtr runningResource, + int id, int version, + int output, DrConnectorType type, + DrMetaDataRef metaData); + + DrTimeInterval GetRunningTime(); + + static int s_intermediateCompressionMode; + +private: + int m_vertexId; + int m_version; + DrUINT64ArrayRef m_lengthArray; + DrTimeInterval m_runningTime; + DrString m_uriBase; + DrResourcePtr m_assignedNode; + int m_compression; +}; +DRREF(DrActiveVertexOutputGenerator); + +DRINTERFACE(DrIInputPartitionReader) +{ +public: + virtual DrAffinityRef GetAffinity(int partitionIndex) DRABSTRACT; + virtual DrString GetURIForRead(int partitionIndex, DrResourcePtr runningResource) DRABSTRACT; +}; +DRIREF(DrIInputPartitionReader); + +DRCLASS(DrStorageVertexOutputGenerator) : public DrVertexOutputGenerator +{ +public: + DrStorageVertexOutputGenerator(int partitionIndex, DrIInputPartitionReaderPtr reader); + + virtual DrResourcePtr GetResource() DROVERRIDE; + virtual int GetVersion() DROVERRIDE; + virtual DrAffinityRef GetOutputAffinity(int output) DROVERRIDE; + virtual DrString GetURIForRead(int output, DrConnectorType type, DrResourcePtr runningResource) 
DROVERRIDE; + +private: + int m_partitionIndex; + DrIInputPartitionReaderIRef m_reader; +}; +DRREF(DrStorageVertexOutputGenerator); + +DRCLASS(DrTeeVertexOutputGenerator) : public DrVertexOutputGenerator +{ +public: + DrTeeVertexOutputGenerator(DrVertexOutputGeneratorPtr wrappedGenerator); + + DrVertexOutputGeneratorPtr GetWrappedGenerator(); + + virtual DrResourcePtr GetResource() DROVERRIDE; + virtual int GetVersion() DROVERRIDE; + virtual DrAffinityRef GetOutputAffinity(int output) DROVERRIDE; + virtual DrString GetURIForRead(int output, DrConnectorType type, DrResourcePtr runningResource) DROVERRIDE; + +private: + DrVertexOutputGeneratorRef m_wrappedGenerator; +}; +DRREF(DrTeeVertexOutputGenerator); diff --git a/GraphManager/vertex/DrStageManager.h b/GraphManager/vertex/DrStageManager.h new file mode 100644 index 0000000..bef5ed4 --- /dev/null +++ b/GraphManager/vertex/DrStageManager.h @@ -0,0 +1,176 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +DRDECLARECLASS(DrStageManager); +DRREF(DrStageManager); + +DRDECLARECLASS(DrConnectionManager); +DRREF(DrConnectionManager); + +DRDECLARECLASS(DrGraph); +DRREF(DrGraph); + +DRCLASS(DrStageManager abstract) : public DrSharedCritSec +{ +public: + DrStageManager(DrGraphPtr graph); + virtual ~DrStageManager(); + + virtual void Discard() = 0; + + virtual DrGraphPtr GetGraph() = 0; + + virtual DrString GetStageName() = 0; + + virtual void InitializeForGraphExecution() = 0; + virtual void KickStateMachine() = 0; + + /* the stage will only be included in job monitoring summaries if + includeInJobStageList is true. By convention, stages that are + not active (e.g. input or output streams) are not included in + monitoring, since there's not much to say about them */ + virtual bool GetIncludeInJobStageList() = 0; + virtual void SetIncludeInJobStageList(bool includeInJobStageList) = 0; + + /* assign connector to manage dynamic modifications to the + subgraph edges connecting this stage from upstreamStage, for + example to manage a dynamic merge tree */ + virtual void AddDynamicConnectionManager(DrStageManagerPtr upStreamStage, + DrConnectionManagerPtr connector) = 0; + + /* similar to the above, but do it not at graph-build time, but during runtime. + The difference is that the vertices have to be registered */ + virtual void AddDynamicConnectionManagerAtRuntime(DrStageManagerPtr upstreamStage, + DrConnectionManagerPtr connector) = 0; + + /* RegisterVertex should be called once for each vertex that is + added to the stage. RegisterVertexDerived is a virtual method + that is called automatically after other actions in + RegisterVertex so that derived classes can keep track of what + is happening, and the base class implementation does nothing. A + client wishing to let the stage know that a vertex has been + added should call RegisterVertex. RegisterVertex should not be + called if RegisterVertexSplit is also called on the new + vertex. 
*/ + virtual void RegisterVertex(DrVertexPtr vertex) = 0; + + /* Some dynamic graph modifications increase the size of a stage + by "splitting" new vertices off from a base vertex, and they + should call RegisterVertexSplit which will notify any relevant + downstream vertex stage managers about the split. In this case + RegisterVertex will automatically be called, and the client + should not call it as well. Different downstream stage managers + will generally want to deal differently with a split vertex, + for example they may also choose to split. Consequently the + split vertex should not be connected to any downstream vertices + before this method is called, and the downstream manager will + add edges as appropriate. By default it will add a new edge + between the new vertex and every downstream vertex that the + baseToSplitFrom is currently connected + to. RegisterVertexSplitDerived is a virtual method that is + called after other actions in RegisterVertexSplit so that + derived classes can keep track of what is happening, and the + base class implementation does nothing. A client wishing to let + the stage know that a vertex has been added should call + RegisterVertexSplit. */ + virtual void RegisterVertexSplit(DrVertexPtr vertex, DrVertexPtr baseToSplitFrom, + int splitIndex) = 0; + + /* Some dynamic graph modifications remove vertices from + stages. UnRegisterVertex should be called before a vertex is + removed, and it will automatically call connected downstream + managers to notify them that the vertex is being + removed. UnRegisterVertexDerived is a virtual method that is + called automatically after other actions in UnRegisterVertex so + that derived classes can keep track of what is happening, and + the base class implementation does nothing. A client wishing to + let the stage know that a vertex is about to be removed should call + UnRegisterVertex. 
*/ + virtual void UnRegisterVertex(DrVertexPtr vertex) = 0; + + /* VertexIsReady is called by the job manager before it attempts + to run any vertex. If VertexIsReady returns false then the job + manager will not start the vertex, otherwise it will proceed as + normal and run the vertex when it sees fit. VertexIsReady may + be called many times for a given vertex. If it ever returns + false, then the application must subsequently call + vertex->NotifyVertexIsReady() once the vertex is ready + to run, otherwise the job may never make progress. The default + implementation of this method always returns true but it can be + overridden. */ + virtual bool VertexIsReady(DrActiveVertexPtr vertex) = 0; + + /* NotifyVertexStatus is called every time the job manager + receives an update on the vertex, which happens periodically + while the vertex is running, and once when it completes. If + completionStatus is DryadError_VertexRunning the vertex has not + yet completed. If completionStatus is + DryadError_VertexCompleted the vertex has successfully + completed and NotifyVertexStatus will not be called again for + this version of the vertex. Otherwise the vertex has exited + with an error. + + DVertexProcessStatus is defined in + dryad/system/common/include/dvertexcommand.h and it includes + the version of the vertex (with GetVertexInstanceVersion()), + metadata including any error information (with + GetVertexMetaData()) and information about all of its input and + output channels. + + The default implementation does nothing. 
+ */ + virtual void NotifyVertexStatus(DrActiveVertexPtr vertex, + HRESULT completionStatus, + DrVertexProcessStatusPtr status) = 0; + + virtual void NotifyVertexRunning(DrActiveVertexPtr vertex, + int executionVersion, + DrResourcePtr machine, + DrVertexExecutionStatisticsPtr statistics) = 0; + virtual void NotifyVertexCompleted(DrActiveVertexPtr vertex, + int executionVersion, + DrResourcePtr machine, + DrVertexExecutionStatisticsPtr statistics) = 0; + virtual void NotifyVertexFailed(DrActiveVertexPtr vertex, int executionVersion, + DrResourcePtr machine, DrVertexExecutionStatisticsPtr statistics) = 0; + + virtual void CheckForDuplicates() = 0; + + virtual void NotifyInputReady(DrStorageVertexPtr vertex, DrAffinityPtr affinity) = 0; + + virtual void SetStillAddingVertices(bool stillAddingVertices) = 0; + + /* this returns a set containing all the vertices that have been + registered with this stage */ + virtual DrVertexListPtr GetVertexVector() = 0; + + /* this tells all pending and running versions of all vertices to abort */ + virtual void CancelAllVertices(DrErrorPtr reason) = 0; + + /* this adds self's monitoring information to stats, which is a + container for the statistics of all the stages in the job. */ + //void FillInStageStatistics(CsJobExecutionStatistics* stats); +}; +DRREF(DrStageManager); + +typedef DrArrayList DrStageList; +DRAREF(DrStageList,DrStageManagerRef); diff --git a/GraphManager/vertex/DrVertex.cpp b/GraphManager/vertex/DrVertex.cpp new file mode 100644 index 0000000..094f2ff --- /dev/null +++ b/GraphManager/vertex/DrVertex.cpp @@ -0,0 +1,1744 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include + +#ifndef _MANAGED +int DrVertexIdSource::s_nextId = 0; +#endif + +int DrVertexIdSource::GetNextId() +{ + ++s_nextId; + return s_nextId; +} + +DrVertex::DrVertex(DrStageManagerPtr stage) : DrSharedCritSec(stage) +{ + m_stage = stage; + + m_id = DrVertexIdSource::GetNextId(); + m_name = stage->GetStageName(); + m_inputEdges = DrNew DrEdgeHolder(true); + m_outputEdges = DrNew DrEdgeHolder(false); + + m_numberOfSubgraphInputs = 1; + m_numberOfSubgraphOutputs = 1; +} + +void DrVertex::InitializeFromOther(DrVertexPtr other, int suffix) +{ + if (suffix >= 0 && other->m_name.GetString() != DrNull) + { + m_name.SetF("%s[%d]", other->m_name.GetChars(), suffix); + } + else + { + m_name = other->m_name; + } + + m_inputEdges->SetNumberOfEdges(other->m_inputEdges->GetNumberOfEdges()); + m_outputEdges->SetNumberOfEdges(other->m_outputEdges->GetNumberOfEdges()); + + m_numberOfSubgraphInputs = other->m_numberOfSubgraphInputs; + m_numberOfSubgraphOutputs = other->m_numberOfSubgraphOutputs; +} + +void DrVertex::Discard() +{ + DiscardDerived(); + + m_stage = DrNull; + m_inputEdges = DrNull; + m_outputEdges = DrNull; +} + +DrVertexRef DrVertex::MakeCopy(int suffix) +{ + return MakeCopy(suffix, m_stage); +} + +DrStageManagerPtr DrVertex::GetStageManager() +{ + return m_stage; +} + +int DrVertex::GetId() +{ + return m_id; +} + +void DrVertex::SetName(DrString name) +{ + m_name = name; +} + +DrString DrVertex::GetName() +{ + return m_name; +} + +void DrVertex::SetNumberOfSubgraphInputs(int numberOfInputs) +{ + 
m_numberOfSubgraphInputs = numberOfInputs; +} + +int DrVertex::GetNumberOfSubgraphInputs() +{ + return m_numberOfSubgraphInputs; +} + +void DrVertex::SetNumberOfSubgraphOutputs(int numberOfOutputs) +{ + m_numberOfSubgraphOutputs = numberOfOutputs; +} + +int DrVertex::GetNumberOfSubgraphOutputs() +{ + return m_numberOfSubgraphOutputs; +} + +DrEdgeHolderPtr DrVertex::GetInputs() +{ + return m_inputEdges; +} + +DrEdgeHolderPtr DrVertex::GetOutputs() +{ + return m_outputEdges; +} + +DrVertexPtr DrVertex::RemoteInputVertex(int localPort) +{ + return m_inputEdges->GetEdge(localPort).m_remoteVertex; +} + +int DrVertex::RemoteInputPort(int localPort) +{ + return m_inputEdges->GetEdge(localPort).m_remotePort; +} + +DrVertexPtr DrVertex::RemoteOutputVertex(int localPort) +{ + return m_outputEdges->GetEdge(localPort).m_remoteVertex; +} + +int DrVertex::RemoteOutputPort(int localPort) +{ + return m_outputEdges->GetEdge(localPort).m_remotePort; +} + +void DrVertex::ConnectOutput(int localPort, DrVertexPtr remoteVertex, int remotePort, DrConnectorType type) +{ + DrEdge e; + e.m_remoteVertex = remoteVertex; + e.m_remotePort = remotePort; + e.m_type = type; + m_outputEdges->SetEdge(localPort, e); + + e.m_remoteVertex = this; + e.m_remotePort = localPort; + e.m_type = type; + remoteVertex->GetInputs()->SetEdge(remotePort, e); +} + +void DrVertex::DisconnectOutput(int localPort, bool disconnectRemote) +{ + DrEdge e; + e.m_type = DCT_Tombstone; + + if (disconnectRemote) + { + RemoteOutputVertex(localPort)->GetInputs()->SetEdge(RemoteOutputPort(localPort), e); + } + m_outputEdges->SetEdge(localPort, e); +} + +void DrVertex::DisconnectInput(int localPort, bool disconnectRemote) +{ + DrEdge e; + e.m_type = DCT_Tombstone; + + if (disconnectRemote) + { + RemoteInputVertex(localPort)->GetOutputs()->SetEdge(RemoteInputPort(localPort), e); + } + m_inputEdges->SetEdge(localPort, e); +} + + + +DrActiveVertex::DrActiveVertex(DrStageManagerPtr stage, DrProcessTemplatePtr processTemplate, + 
DrVertexTemplatePtr vertexTemplate) : DrVertex(stage) +{ + m_vertexTemplate = vertexTemplate; + m_processTemplate = processTemplate; + + m_affinity = DrNew DrAffinity(); + + m_argument = DrNew DrStringList(); + + m_totalOutputSizeHint = 0; + m_maxOpenInputChannelCount = 0; + m_maxOpenOutputChannelCount = 0; + + m_numberOfReportedCompletions = 0; + m_registeredWithGraph = false; + + m_runningVertex = DrNew DrVertexRecordList(); + m_spareCompletedRecord = DrNew DrCompletedVertexList(); +} + +void DrActiveVertex::DiscardDerived() +{ + if (m_cohort != DrNull) + { + m_cohort->Discard(); + } + m_cohort = DrNull; + + if (m_startClique != DrNull) + { + m_startClique->Discard(); + } + m_startClique = DrNull; + + m_pendingVersion = DrNull; + + if (m_runningVertex != DrNull) + { + int i; + for (i=0; iSize(); ++i) + { + m_runningVertex[i]->Discard(); + } + } + m_runningVertex = DrNull; + + m_completedRecord = DrNull; + + m_spareCompletedRecord = DrNull; +} + +DrVertexRef DrActiveVertex::MakeCopy(int suffix, DrStageManagerPtr stage) +{ + DrActiveVertexRef other = DrNew DrActiveVertex(stage, m_processTemplate, m_vertexTemplate); + other->InitializeFromOther(this, suffix); + + int i; + for (i=0; iSize(); ++i) + { + other->m_argument->Add(m_argument[i]); + } + + other->m_affinity = m_affinity; + + other->m_outputSizeHint = m_outputSizeHint; + other->m_totalOutputSizeHint = m_totalOutputSizeHint; + other->m_maxOpenInputChannelCount = m_maxOpenInputChannelCount; + other->m_maxOpenOutputChannelCount = m_maxOpenOutputChannelCount; + + DrVertexRef vo = other; + return vo; +} + +void DrActiveVertex::SetStartClique(DrStartCliquePtr startClique) +{ + m_startClique = startClique; +} + +DrStartCliquePtr DrActiveVertex::GetStartClique() +{ + return m_startClique; +} + +void DrActiveVertex::SetCohort(DrCohortPtr cohort) +{ + m_cohort = cohort; +} + +DrCohortPtr DrActiveVertex::GetCohort() +{ + return m_cohort; +} + + +void DrActiveVertex::AddArgument(DrNativeString argument) +{ + 
AddArgumentInternal(DrString(argument)); +} + +void DrActiveVertex::AddArgumentInternal(DrString argument) +{ + m_argument->Add(argument); +} + +DrAffinityPtr DrActiveVertex::GetAffinity() +{ + return m_affinity; +} + +DrVertexOutputGeneratorPtr DrActiveVertex::GetOutputGenerator(int edgeIndex, DrConnectorType type, int version) +{ + DrEdge e = m_outputEdges->GetEdge(edgeIndex); + DrAssert(e.m_type == type); + + int i; + switch (type) + { + case DCT_File: + /* doesn't matter what downstream version wants to run, just give it the completed file, + if there is one */ + return m_completedRecord; + + case DCT_Output: + /* nobody should be asking to read from this port */ + DrLogA("Someone is trying to get a generator for the source of an output edge"); + break; + + case DCT_Pipe: + case DCT_Fifo: + case DCT_FifoNonBlocking: + /* the downstream version can only connect to the matching upstream version */ + for (i=0; iSize(); ++i) + { + DrActiveVertexOutputGeneratorPtr g = m_runningVertex[i]->GetGenerator(); + if (g != DrNull) + { + /* the vertex has finished waiting for an available process and actually started running */ + if (g->GetVersion() == version) + { + return g; + } + } + } + return DrNull; + } + + return DrNull; +} + +DrString DrActiveVertex::GetURIForWrite(int port, + int /* unused id */, int /* unused version */, + int /* unused outputPort */, + DrResourcePtr /* unused runningResource */, + DrMetaDataRef /* unused metaData */) +{ + /* only output vertices can supply this */ + DrLogA("Someone is trying to get a URI for write on active vertex %d port %d", m_id, port); + return DrString(); +} + +DrVertexCommandBlockRef DrActiveVertex::MakeVertexStartCommand(DrVertexVersionGeneratorPtr generators, + DrResourcePtr runningResource) +{ + DrVertexCommandBlockRef cmd = DrNew DrVertexCommandBlock(); + cmd->SetVertexCommand(DrVC_Start); + + cmd->SetArgumentCount(m_argument->Size()); + int i; + for (i=0; iSize(); ++i) + { + cmd->GetArgumentVector()[i] = m_argument[i]; + 
} + + DrVertexProcessStatusPtr ps = cmd->GetProcessStatus(); + + ps->SetVertexId(m_id); + ps->SetVertexInstanceVersion(generators->GetVersion()); + + DrAssert(m_inputEdges->GetNumberOfEdges() == generators->GetNumberOfInputs()); + + ps->SetInputChannelCount(m_inputEdges->GetNumberOfEdges()); + ps->SetMaxOpenInputChannelCount(m_maxOpenInputChannelCount); + for (i=0; iGetNumberOfEdges(); ++i) + { + DrVertexOutputGeneratorPtr g = generators->GetGenerator(i); + DrEdge e = m_inputEdges->GetEdge(i); + DrChannelDescriptionPtr c = ps->GetInputChannels()[i]; + + DrString uri = g->GetURIForRead(e.m_remotePort, e.m_type, runningResource); + DrAffinityRef affinity = g->GetOutputAffinity(e.m_remotePort); + + c->SetChannelState(S_OK); + c->SetChannelURI(uri); + c->SetChannelTotalLength(affinity->GetWeight()); + } + + ps->SetOutputChannelCount(m_outputEdges->GetNumberOfEdges()); + ps->SetMaxOpenOutputChannelCount(m_maxOpenOutputChannelCount); + + if (m_outputSizeHint != DrNull) + { + DrAssert(m_outputSizeHint->Allocated() == m_outputEdges->GetNumberOfEdges()); + } + + for (i=0; iGetNumberOfEdges(); ++i) + { + DrMetaDataRef metaData = DrNew DrMetaData(); + if (m_outputSizeHint != DrNull || m_totalOutputSizeHint > 0) + { + UINT64 sizeHint; + if (m_outputSizeHint == DrNull) + { + sizeHint = m_totalOutputSizeHint / (UINT64) (m_outputEdges->GetNumberOfEdges()); + } + else + { + sizeHint = m_outputSizeHint[i]; + } + + metaData->Append(DrNew DrMTagUInt64(DrProp_InitialChannelWriteSize, sizeHint)); + } + + DrConnectorType t = m_outputEdges->GetEdge(i).m_type; + + DrString uri = + DrActiveVertexOutputGenerator::GetURIForWrite(m_outputEdges, runningResource, + m_id, generators->GetVersion(), i, t, metaData); + + DrChannelDescriptionPtr c = ps->GetOutputChannels()[i]; + c->SetChannelState(S_OK); + c->SetChannelURI(uri); + c->SetChannelMetaData(metaData); + } + + return cmd; +} + +void DrActiveVertex::KickStateMachine() +{ + if (m_stage->GetGraph()->IsRunning() == false) + { + /* the graph 
is aborting so no new versions are desired */ + return; + } + + /* Get the gang to set up a new pending version for us and everyone else + in the gang if necessary. */ + m_cohort->GetGang()->EnsurePendingVersion(0); +} + +void DrActiveVertex::RequestDuplicate(int versionToDuplicate) +{ + if (m_stage->GetGraph()->IsRunning() == false) + { + /* the graph is aborting so no new versions are desired */ + return; + } + + /* Get the gang to set up a new pending version for us and everyone else + in the gang if necessary. */ + m_cohort->GetGang()->EnsurePendingVersion(versionToDuplicate); +} + +void DrActiveVertex::NotifyVertexIsReady() +{ + DrAssert(m_stage->VertexIsReady(this)); + m_cohort->GetGang()->DecrementUnreadyVertices(); +} + +void DrActiveVertex::InitializeForGraphExecution() +{ + DrLogI("Vertex %d(%s) %s", m_id, m_name.GetChars(), + (m_registeredWithGraph) ? "already initialized" : "initializing now"); + + if (!m_registeredWithGraph) + { + m_stage->GetGraph()->IncrementActiveVertexCount(); + m_registeredWithGraph = true; + + m_cohort = DrNew DrCohort(m_processTemplate, this); + m_startClique = DrNew DrStartClique(this); + /* this sets the gang in the run clique and cohort, so we don't need to keep a reference afterwards */ + DrGangRef gang = DrNew DrGang(m_cohort, m_startClique); + m_cohort->SetGang(gang); + m_startClique->SetGang(gang); + + if (m_stage->VertexIsReady(this) == false) + { + m_cohort->GetGang()->IncrementUnreadyVertices(); + } + } +} + +void DrActiveVertex::RemoveFromGraphExecution() +{ + DrAssert(m_registeredWithGraph); + + DrLogI("Vertex %d(%s)", m_id, m_name.GetChars()); + m_stage->GetGraph()->DecrementActiveVertexCount(); + + m_registeredWithGraph = false; +} + +void DrActiveVertex::ReportFinalTopology(DrVertexTopologyReporterPtr reporter) +{ + DrResourcePtr location = DrNull; + DrTimeInterval time = 0; + if (m_completedRecord != DrNull) + { + location = m_completedRecord->GetResource(); + time = m_completedRecord->GetRunningTime(); + } + 
reporter->ReportFinalTopology(this, location, time); +} + +DrVertexRecordPtr DrActiveVertex::GetRunningVersion(int version) +{ + int i; + for (i=0; iSize(); ++i) + { + if (m_runningVertex[i]->GetVersion() == version) + { + return m_runningVertex[i]; + } + } + return DrNull; +} + +void DrActiveVertex::GrowPendingVersion(int numberOfEdgesToGrow) +{ + if (m_pendingVersion != DrNull) + { + m_pendingVersion->Grow(numberOfEdgesToGrow); + } +} + +/* !!! This is only ever called via DrGang::EnsurePendingVersion */ +void DrActiveVertex::InstantiateVersion(int version) +{ + DrAssert(m_pendingVersion == DrNull); + + DrVertexVersionGeneratorRef p = DrNew DrVertexVersionGenerator(version, m_inputEdges->GetNumberOfEdges()); + + int externalReadyInputs = 0; + int i; + for (i=0; iGetNumberOfEdges(); ++i) + { + DrEdge e = m_inputEdges->GetEdge(i); + DrVertexOutputGeneratorPtr g = e.m_remoteVertex->GetOutputGenerator(e.m_remotePort, e.m_type, + version); + if (g != DrNull) + { + p->SetGenerator(i, g); + if (e.IsStartCliqueEdge() == false) + { + ++externalReadyInputs; + } + } + } + + m_pendingVersion = p; + + if (externalReadyInputs > 0) + { + /* this just decrements the count of ready inputs in the start clique: it won't actually trigger + anything to happen even if all the inputs are ready */ + m_startClique->NotifyExternalInputsReady(version, externalReadyInputs); + } +} + +void DrActiveVertex::AddCurrentAffinitiesToList(int version, DrAffinityListPtr list) +{ + /* add our vertex affinity if any */ + if (m_affinity != DrNull) + { + list->Add(m_affinity); + } + + if (m_pendingVersion == DrNull || m_pendingVersion->GetVersion() != version) + { + DrLogA("Requested version %d doesn't exist in vertex %d(%s)", + version, GetId(), GetName().GetChars()); + } + + DrAssert(m_pendingVersion->GetNumberOfInputs() == m_inputEdges->GetNumberOfEdges()); + + int i; + for (i=0; iGetNumberOfInputs(); ++i) + { + /* now add the affinities from all the generators we have managed to accumulate at + this 
point. Some of them may be null if we aren't ready to run yet: this can happen + if we are being called because someone else in the cohort is ready */ + DrVertexOutputGeneratorPtr inputGenerator = m_pendingVersion->GetGenerator(i); + if (inputGenerator != DrNull) + { + DrEdge e = m_inputEdges->GetEdge(i); + if (e.IsGangEdge()) + { + int inputVersion = inputGenerator->GetVersion(); + /* gang edges should be paired to have both ends running the same version */ + DrAssert(inputVersion == version); + } + + DrAffinityRef inputAffinity = inputGenerator->GetOutputAffinity(e.m_remotePort); + if (inputAffinity != DrNull) + { + list->Add(inputAffinity); + } + } + } +} + +void DrActiveVertex::StartProcess(int version) +{ + DrAssert(m_stage->VertexIsReady(this)); + + /* this will start the process if we are the first vertex in the cohort to want to start */ + DrCohortProcessRef cohortProcess = m_cohort->EnsureProcess(m_stage->GetGraph(), version); + + DrVertexVersionGeneratorRef generator = m_pendingVersion; + m_pendingVersion = DrNull; + DrAssert(generator != DrNull && generator->GetVersion() == version); + + DrVertexRecordRef execution = DrNew DrVertexRecord(m_stage->GetGraph()->GetXCompute(), this, + cohortProcess, generator, m_vertexTemplate); + + m_runningVertex->Add(execution); +} + +void DrActiveVertex::CheckForProcessAlreadyStarted(int version) +{ + DrCohortProcessRef cohortProcess = m_cohort->GetProcessForVersion(version); + DrAssert(cohortProcess != DrNull); + + if (cohortProcess->ProcessHasStarted()) + { + ReactToStartedProcess(version, cohortProcess->GetProcess()); + } +} + +void DrActiveVertex::ReactToStartedProcess(int version, DrLockBox process) +{ + DrVertexRecordPtr record = GetRunningVersion(version); + if (record == DrNull) + { + if (m_pendingVersion == DrNull || m_pendingVersion->GetVersion() != version) + { + DrLogA("Vertex %d(%s) has no pending or running version %d", m_id, m_name.GetChars(), version); + } + return; + } + + DrVertexVersionGeneratorPtr 
inputs = record->NotifyProcessHasStarted(process); + if (inputs->Ready()) + { + record->StartRunning(); + } + + int i; + for (i=0; iGetNumberOfEdges(); ++i) + { + DrEdge e = m_outputEdges->GetEdge(i); + e.m_remoteVertex->ReactToUpStreamRunningProcess(e.m_remotePort, e.m_type, + record->GetGenerator()); + } +} + +void DrActiveVertex::ReactToUpStreamRunningProcess(int inputPort, DrConnectorType type, + DrVertexOutputGeneratorPtr generator) +{ + DrEdge e = m_inputEdges->GetEdge(inputPort); + DrAssert(e.m_type == type); + + switch (type) + { + case DCT_File: + case DCT_FifoNonBlocking: + /* we don't care when these processes start running. The non blocking active + edges are taken care of when the vertices start running. */ + return; + + case DCT_Output: + /* there shouldn't be an edge of this type leading to an active vertex */ + DrLogA("Output edge leading to active vertex %d on port %d", m_id, inputPort); + break; + + case DCT_Pipe: + case DCT_Fifo: + break; + } + + /* this won't be called until everyone in the start clique has had its StartProcess + method called, so there ought to be a running version */ + DrVertexRecordPtr record = GetRunningVersion(generator->GetVersion()); + DrAssert(record != DrNull); + + /* if this was the last active input it was waiting for, it will start itself running */ + record->SetActiveInput(inputPort, generator); +} + +void DrActiveVertex::ReactToStartedVertex(DrVertexRecordPtr record, DrVertexExecutionStatisticsPtr stats) +{ + DrActiveVertexOutputGeneratorPtr generator = record->GetGenerator(); + DrAssert(generator != DrNull); + + m_stage->NotifyVertexRunning(this, record->GetVersion(), generator->GetResource(), stats); + + int i; + for (i=0; iGetNumberOfEdges(); ++i) + { + DrEdge e = m_outputEdges->GetEdge(i); + e.m_remoteVertex->ReactToUpStreamRunningVertex(e.m_remotePort, e.m_type, generator); + } +} + +void DrActiveVertex::ReactToUpStreamRunningVertex(int inputPort, DrConnectorType type, + DrVertexOutputGeneratorPtr generator) 
+{ + DrEdge e = m_inputEdges->GetEdge(inputPort); + DrAssert(e.m_type == type); + + switch (type) + { + case DCT_File: + case DCT_Pipe: + case DCT_Fifo: + /* we don't care when these vertices start running. The blocking active edges + were taken care of when the process started running. */ + return; + + case DCT_Output: + /* there shouldn't be an edge of this type leading to an active vertex */ + DrLogA("Output edge leading to active vertex %d on port %d", m_id, inputPort); + break; + + case DCT_FifoNonBlocking: + break; + } + + /* nobody in the gang should be running unless we're all ready */ + DrAssert(m_stage->VertexIsReady(this)); + + /* the versions have to match up at the ends of an active edge */ + DrAssert(m_pendingVersion != DrNull && m_pendingVersion->GetVersion() == generator->GetVersion()); + + DrAssert(m_pendingVersion->GetGenerator(inputPort) == DrNull); + m_pendingVersion->SetGenerator(inputPort, generator); + + /* this just decrements the count of ready inputs in the start clique: it won't actually trigger + anything to happen even if all the inputs are ready */ + m_startClique->NotifyExternalInputsReady(m_pendingVersion->GetVersion(), 1); + + /* now actually start everyone in our clique if this was the last edge we were waiting for */ + m_startClique->StartVersionIfReady(m_pendingVersion->GetVersion()); +} + +void DrActiveVertex::ReactToRunningVertexUpdate(DrVertexRecordPtr /* unused record */, + HRESULT exitStatus, DrVertexProcessStatusPtr status) +{ + m_stage->NotifyVertexStatus(this, exitStatus, status); +} + +void DrActiveVertex::ReactToCompletedVertex(DrVertexRecordPtr record, DrVertexExecutionStatisticsPtr stats) +{ + DrLogI("Reacting to completed vertex %d.%d", this->m_id, record->GetVersion()); + DrActiveVertexOutputGeneratorRef newCompletedRecord = record->GetGenerator(); + DrAssert(newCompletedRecord != DrNull); + + bool becomingComplete = false; + if (m_completedRecord == DrNull) + { + DrLogI("Becoming complete"); + becomingComplete = 
true; + m_completedRecord = newCompletedRecord; + } + else + { + DrLogI("Adding spare completed record"); + m_spareCompletedRecord->Add(newCompletedRecord); + } + + // + // go down output edges and let the start clique notify external + // inputs ready, this just does bookkeeping on the invariants for each + // downstream vertex about whether its inupts are ready or not + // + int i; + for (i=0; iGetNumberOfEdges(); ++i) + { + DrEdge e = m_outputEdges->GetEdge(i); + e.m_remoteVertex->NotifyUpStreamCompletedVertex(e.m_remotePort, e.m_type, newCompletedRecord); + } + + // + // during this call the graph may be rewritten!!! The set of output edges may be different, + // in particular + // + DrLogI("Notifying stage of vertex %d.%d completion", this->m_id, record->GetVersion()); + m_stage->NotifyVertexCompleted(this, record->GetVersion(), newCompletedRecord->GetResource(), stats); + ++m_numberOfReportedCompletions; + + DrString message; + message.SetF("completed vertex %s", (PCSTR)m_name.GetChars()); + m_stage->GetGraph()->GetXCompute()->IncrementProgress(message.GetChars()); + + // + // go down output edges and actually prod the vertices to start running if + // their external inputs are ready. This is separated from the bookkeeping above + // because invariant-checking is done during graph rewrites and it makes things simpler + // to have the number of ready inputs correct before the rewrite. 
However we can't + // actually start things running (in the following loop) until after the rewrite, obviously + // + for (i=0; iGetNumberOfEdges(); ++i) + { + DrEdge e = m_outputEdges->GetEdge(i); + e.m_remoteVertex->ReactToUpStreamCompletedVertex(e.m_remotePort, e.m_type, newCompletedRecord, stats); + } + + m_runningVertex->Remove(record); + + if (becomingComplete) + { + DrLogI("Notifying graph of vertex %d.%d completion", this->m_id, record->GetVersion()); + m_stage->GetGraph()->NotifyActiveVertexComplete(); + } + + // + // The gang keeps track of the number of vertices in any given version that have successfully + // completed. If everyone in the gang manages to get a completed vertex with the same version + // then we declare success and, for example, kill of any duplicate executions within the gang + // + DrLogI("Calling ReactToCompletedVertex for gang"); + m_cohort->GetGang()->ReactToCompletedVertex(newCompletedRecord->GetVersion()); +} + +void DrActiveVertex::NotifyUpStreamCompletedVertex(int inputPort, DrConnectorType type, DrVertexOutputGeneratorPtr generator) +{ + DrEdge e = m_inputEdges->GetEdge(inputPort); + DrAssert(e.m_type == type); + + if (type != DCT_File) + { + /* we don't care when these complete */ + return; + } + + if (m_pendingVersion != DrNull) + { + /* there shouldn't be pending versions if we aren't ready to run */ + DrAssert(m_stage->VertexIsReady(this)); + + if (m_pendingVersion->GetGenerator(inputPort) == DrNull) + { + m_pendingVersion->SetGenerator(inputPort, generator); + } + else + { + DrAssert(m_pendingVersion->Ready() == false); + } + + /* this just decrements the count of ready inputs in the start clique: it + won't actually trigger anything to happen even if all the inputs are ready */ + m_startClique->NotifyExternalInputsReady(m_pendingVersion->GetVersion(), 1); + } +} + +void DrActiveVertex::ReactToUpStreamCompletedVertex(int inputPort, DrConnectorType type, + DrVertexOutputGeneratorPtr /* unused generator */, + 
DrVertexExecutionStatisticsPtr /* unused stats */) +{ + DrEdge e = m_inputEdges->GetEdge(inputPort); + DrAssert(e.m_type == type); + + if (type != DCT_File) + { + /* we don't care when these complete */ + return; + } + + if (m_pendingVersion != DrNull) + { + /* someone's waiting to run: let them all know we at least are ready */ + + /* there shouldn't be pending versions if we aren't ready to run */ + DrAssert(m_stage->VertexIsReady(this)); + + /* now actually start everyone in our clique if this was the last edge + we were waiting for. If the version is ready to start then m_pendingVersion will + be DrNull when the call returns */ + m_startClique->StartVersionIfReady(m_pendingVersion->GetVersion()); + } +} + +bool DrActiveVertex::HasPendingVersion() +{ + return (m_pendingVersion != DrNull); +} + +bool DrActiveVertex::HasRunningVersion(int version) +{ + int i; + for (i=0; iSize(); ++i) + { + if (m_runningVertex[i]->GetVersion() == version) + { + return true; + } + } + + return false; +} + +bool DrActiveVertex::HasCompletedVersion(int version) +{ + if (m_completedRecord == DrNull) + { + DrAssert(m_spareCompletedRecord->Size() == 0); + return false; + } + + if (m_completedRecord->GetVersion() == version) + { + return true; + } + + int i; + for (i=0; iSize(); ++i) + { + if (m_spareCompletedRecord[i]->GetVersion() == version) + { + return true; + } + } + + return false; +} + +/* !!! This is only called from the Gang's CancelVersion method. It should stay that way, so that + the gang's invariants about which version is completed etc. 
track the state of the member + vertices */ +void DrActiveVertex::CancelVersion(int version, DrErrorPtr error, DrCohortProcessPtr cohortProcess) +{ + DrLogI("Canceling version %d for vertex %d", version, m_id); + if (m_pendingVersion != DrNull && m_pendingVersion->GetVersion() == version) + { + /* we weren't ready to start so we haven't got a vertex record */ + DrLogI("We weren't ready to start so we haven't got a vertex record"); + if (cohortProcess != DrNull) + { + if (cohortProcess->GetProcess().IsNull() == false) + { + /* however the process did already start, so we have to tell it we're never going + to run */ + DrLogI("However the process did already start, so we have to tell it we're never going to run"); + DrVertexRecord::SendTerminateCommand(m_id, version, cohortProcess->GetProcess()); + } + cohortProcess->NotifyVertexCompletion(); + } + m_pendingVersion = DrNull; + + return; + } + + int i; + for (i=0; iSize(); ++i) + { + if (m_runningVertex[i]->GetVersion() == version) + { + /* this will handle all the cleanup, killing the vertex record and calling back in + to ReactToFailedVertex, all on this callstack, so nothing more to do */ + DrLogI("Found version %d in running vertices, triggering failure", version); + m_runningVertex[i]->TriggerFailure(error); + + return; + } + } + + if (m_completedRecord != DrNull && m_completedRecord->GetVersion() == version) + { + /* this will handle the cleanup including setting the completed record to + NULL, managing the spare completed records, etc., so nothing more to do */ + DrLogI("Found version %d in completed record, calling ReactToFailedVertex", version); + ReactToFailedVertex(m_completedRecord, DrNull, DrNull, DrNull, error); + + return; + } + + for (i=0; iSize(); ++i) + { + if (m_spareCompletedRecord[i]->GetVersion() == version) + { + /* this will handle the cleanup including managing the spare completed records, + etc., so nothing more to do */ + DrLogI("Found version %d in spare completed records, calling 
ReactToFailedVertex", version); + ReactToFailedVertex(m_spareCompletedRecord[i], DrNull, DrNull, DrNull, error); + + return; + } + } + + /* we had already stopped: nothing to do */ + DrLogI("Version %d already stopped, nothing to do", version); +} + +void DrActiveVertex::ReactToDownStreamFailure(int port, + DrConnectorType type, + int /* unused downStreamVersion */) +{ + DrEdge e = m_outputEdges->GetEdge(port); + DrAssert(e.m_type == type); + + /* a downstream output should not be failing */ + DrAssert(type != DCT_Output); + + /* nothing to do: vertices connected by active edges will be failed by the gang anyway */ +} + +void DrActiveVertex::ReactToUpStreamFailure(int port, DrConnectorType type, + DrVertexOutputGeneratorPtr /* unused generator*/, + int /* unused downStreamVersion */) +{ + DrEdge e = m_inputEdges->GetEdge(port); + DrAssert(e.m_type == type); + + DrAssert(type != DCT_Output); + + if (m_pendingVersion != DrNull && e.m_type == DCT_File) + { + /* oldGenerator is the one we used to have from the vertex, if any */ + DrVertexOutputGeneratorRef oldGenerator = m_pendingVersion->GetGenerator(port); + if (oldGenerator != DrNull) + { + DrVertexPtr remoteVertex = e.m_remoteVertex; + int remotePort = e.m_remotePort; + DrVertexOutputGeneratorRef failedVertexGenerator = + remoteVertex->GetOutputGenerator(remotePort, e.m_type, 0); + if (failedVertexGenerator == DrNull) + { + /* we were holding a generator from the upstream vertex but it has now + been revoked */ + + m_pendingVersion->SetGenerator(port, DrNull); + m_startClique->GrowExternalInputs(1); + } + } + } +} + +void DrActiveVertex::ReactToFailedVertex(DrVertexOutputGeneratorPtr failedGenerator, + DrVertexVersionGeneratorPtr inputs, + DrVertexExecutionStatisticsPtr stats, DrVertexProcessStatusPtr status, + DrErrorPtr originalReason) +{ + int version = failedGenerator->GetVersion(); + + DrLogI("Vertex %d.%d", m_id, version); + + m_stage->NotifyVertexFailed(this, version, failedGenerator->GetResource(), stats); 
+ + bool foundVersion = false; + + int i; + for (i=0; iSize(); ++i) + { + DrVertexRecordRef record = m_runningVertex[i]; + if (record->GetVersion() == version) + { + m_runningVertex->RemoveAt(i); + + int j; + for (j=0; jGetNumberOfEdges(); ++j) + { + DrEdge e = m_inputEdges->GetEdge(j); + e.m_remoteVertex->ReactToDownStreamFailure(e.m_remotePort, e.m_type, version); + } + + DrActiveVertexOutputGeneratorPtr generator = record->GetGenerator(); + if (generator != DrNull) + { + /* we have actually started running so tell any downstream guys that might care */ + for (j=0; jGetNumberOfEdges(); ++j) + { + DrEdge e = m_outputEdges->GetEdge(j); + + e.m_remoteVertex->ReactToUpStreamFailure(e.m_remotePort, e.m_type, + generator, version); + } + } + + foundVersion = true; + break; + } + } + + if (m_completedRecord != DrNull && m_completedRecord->GetVersion() == version) + { + DrActiveVertexOutputGeneratorRef generator = m_completedRecord; + m_completedRecord = DrNull; + + for (i=0; iGetNumberOfEdges(); ++i) + { + DrEdge e = m_outputEdges->GetEdge(i); + + e.m_remoteVertex->ReactToUpStreamFailure(e.m_remotePort, e.m_type, generator, version); + } + + int nSpares = m_spareCompletedRecord->Size(); + if (nSpares > 0) + { + m_completedRecord = m_spareCompletedRecord[nSpares-1]; + m_spareCompletedRecord->RemoveAt(nSpares-1); + } + else + { + m_stage->GetGraph()->NotifyActiveVertexRevoked(); + } + + foundVersion = true; + } + + for (i=0; iSize(); ++i) + { + if (m_spareCompletedRecord[i]->GetVersion() == version) + { + DrActiveVertexOutputGeneratorRef generator = m_spareCompletedRecord[i]; + m_spareCompletedRecord->RemoveAt(i); + + for (i=0; iGetNumberOfEdges(); ++i) + { + DrEdge e = m_outputEdges->GetEdge(i); + + e.m_remoteVertex->ReactToUpStreamFailure(e.m_remotePort, e.m_type, generator, version); + } + + foundVersion = true; + break; + } + } + + if (foundVersion) + { + int inputToKill = m_stage->GetGraph()->ReportFailure(this, version, status, originalReason); + + if (inputToKill 
>= 0) + { + DrAssert(inputs != DrNull); + + /* if the failure manager thinks we should be killing one of our upstream vertices, + tell it so */ + DrEdge inputEdgeToKill = m_inputEdges->GetEdge(inputToKill); + DrVertexPtr vertex = inputEdgeToKill.m_remoteVertex; + DrVertexOutputGeneratorPtr g = inputs->GetGenerator(inputToKill); + + DrString reason = "Downstream vertex reported a read error: invalidating completed version"; + DrErrorRef error = DrNew DrError(DrError_BadOutputReported, "DrActiveVertex", reason); + error->AddProvenance(originalReason); + + DrLogI("Vertex %d.%d: %s calling ReactToFailedVertex", this->m_id, version, reason.GetChars()); + vertex->ReactToFailedVertex(g, DrNull, DrNull, DrNull, error); + } + + if (originalReason->m_code != DrError_CohortShutdown) + { + /* send out a failed state to everyone. anyone who has already completed successfully will ignore it */ + DrString reason = "Cohort being cancelled"; + DrErrorRef error = DrNew DrError(DrError_CohortShutdown, "DrCohort", reason); + error->AddProvenance(originalReason); + + m_cohort->GetGang()->CancelVersion(version, error); + } + } +} + + +void DrActiveVertex::CompactPendingVersion(DrEdgeHolderPtr edgesBeingCompacted, int numberOfEdgesAfterCompaction) +{ + // a) assert that the m_runningVertex list is empty and m_completedRecord is NULL + // (to make sure nobody is changing the vertex after it has + // started running) + DrAssert(m_runningVertex->Size() == 0); + DrAssert(m_completedRecord == DrNull); + + // b) compact the m_pendingVersion if any + if (m_pendingVersion != DrNull) + { + m_pendingVersion->Compact(edgesBeingCompacted, numberOfEdgesAfterCompaction); + } +} + +void DrActiveVertex::CancelAllVersions(DrErrorPtr reason) +{ + m_cohort->GetGang()->CancelAllVersions(reason); +} + +int DrActiveVertex::GetNumberOfReportedCompletions() +{ + return m_numberOfReportedCompletions; +} + +DrString DrActiveVertex::GetDescription() +{ + DrString description; + description.SetF("%d (%s)", GetId(), 
GetName().GetChars()); + return description; +} + + +DrStorageVertex::DrStorageVertex(DrStageManagerPtr stage, int partitionIndex, + DrIInputPartitionReaderPtr reader) : DrVertex(stage) +{ + m_generator = DrNew DrStorageVertexOutputGenerator(partitionIndex, reader); + m_registeredInputReady = false; +} + +DrStorageVertex::DrStorageVertex(DrStageManagerPtr stage, DrStorageVertexOutputGeneratorPtr generator) + : DrVertex(stage) +{ + m_generator = generator; + m_registeredInputReady = false; +} + +void DrStorageVertex::DiscardDerived() +{ + m_generator = DrNull; +} + +DrVertexRef DrStorageVertex::MakeCopy(int suffix, DrStageManagerPtr stage) +{ + DrStorageVertexRef other = DrNew DrStorageVertex(stage, m_generator); + other->InitializeFromOther(this, suffix); + + DrVertexRef vo = other; + return vo; +} + +DrVertexOutputGeneratorPtr DrStorageVertex::GetOutputGenerator(int /* unused edgeIndex */, + DrConnectorType type, + int /* unused version */) +{ + DrAssert(type == DCT_File); + return m_generator; +} + +DrString DrStorageVertex::GetURIForWrite(int port, + int /* unused id */, int /* unused version */, + int /* unused outputPort */, + DrResourcePtr /* unused runningResource */, + DrMetaDataRef /* unused metaData */) +{ + /* only output vertices can supply this */ + DrLogA("Storage vertex %d asked for write URI on port %d", m_id, port); + return DrString(); +} + +void DrStorageVertex::KickStateMachine() +{ + /* nothing to do */ +} + +void DrStorageVertex::InitializeForGraphExecution() +{ + DrAssert(m_registeredInputReady == false); + m_registeredInputReady = true; + + DrAffinityRef affinity = m_generator->GetOutputAffinity(0); + m_stage->NotifyInputReady(this, affinity); +} + +void DrStorageVertex::RemoveFromGraphExecution() +{ + /* nothing to do */ +} + +void DrStorageVertex::ReportFinalTopology(DrVertexTopologyReporterPtr reporter) +{ + reporter->ReportFinalTopology(this, DrNull, 0); +} + +void DrStorageVertex::ReactToUpStreamRunningProcess(int inputPort, + 
DrConnectorType /* unused type */, + DrVertexOutputGeneratorPtr /* unused generator */) +{ + /* for now a storage vertex shouldn't have any upstream neighbours */ + DrLogA("Storage vertex %d has upstream neighbor on port %d", m_id, inputPort); +} + +void DrStorageVertex::ReactToUpStreamRunningVertex(int inputPort, + DrConnectorType /* unused type */, + DrVertexOutputGeneratorPtr /* unused generator */) +{ + /* for now a storage vertex shouldn't have any upstream neighbours */ + DrLogA("Storage vertex %d has upstream neighbor on port %d", m_id, inputPort); +} + +void DrStorageVertex::NotifyUpStreamCompletedVertex(int /*unused inputPort*/, DrConnectorType /* unused type */, DrVertexOutputGeneratorPtr /* unused generator */) +{ +} + +void DrStorageVertex::ReactToUpStreamCompletedVertex(int inputPort, + DrConnectorType /* unused type */, + DrVertexOutputGeneratorPtr /* unused generator */, + DrVertexExecutionStatisticsPtr /* unused stats */) +{ + /* for now a storage vertex shouldn't have any upstream neighbours */ + DrLogA("Storage vertex %d has upstream neighbor on port %d", m_id, inputPort); +} + +void DrStorageVertex::ReactToDownStreamFailure(int /* unused port */, + DrConnectorType /* unused type */, + int /* unused downStreamVersion */) +{ + /* we don't care */ +} + +void DrStorageVertex::ReactToUpStreamFailure(int inputPort, + DrConnectorType /* unused type */, + DrVertexOutputGeneratorPtr /* unused generator */, + int /* unused downStreamVersion */) +{ + /* for now a storage vertex shouldn't have any upstream neighbours */ + DrLogA("Storage vertex %d has upstream neighbor on port %d", m_id, inputPort); +} + +void DrStorageVertex::ReactToFailedVertex(DrVertexOutputGeneratorPtr /* unused failedGenerator */, + DrVertexVersionGeneratorPtr /* unused inputs */, + DrVertexExecutionStatisticsPtr /* unused stats */, + DrVertexProcessStatusPtr /* unused status */, + DrErrorPtr originalReason) +{ + m_stage->GetGraph()->ReportStorageFailure(this, originalReason); +} + + 
+void DrStorageVertex::CompactPendingVersion(DrEdgeHolderPtr /* unused edgesBeingCompacted*/, int /* unused numberOfEdgesAfterCompaction*/) +{ + // + // Only DrActiveVertex can do this but when this method is called + // we only have a DrVertex so this needs to be part of the + // interface. + // +} + +void DrStorageVertex::CancelAllVersions(DrErrorPtr /* unused reason*/) +{ + // + // Only DrActiveVertex can do this but when this method is called + // we only have a DrVertex so this needs to be part of the + // interface. + // +} + + +bool DrOutputPartition::operator==(DrOutputPartitionR other) +{ + return + (m_id == other.m_id && + m_version == other.m_version && + m_outputPort == other.m_outputPort && + m_resource == other.m_resource && + m_size == other.m_size); +} + + +DrOutputVertex::DrOutputVertex(DrStageManagerPtr stage, int partitionIndex, + DrIOutputPartitionGeneratorPtr generator) : DrVertex(stage) +{ + m_generator = generator; + m_partitionIndex = partitionIndex; + m_runningVersion = DrNew DrOutputPartitionList(); + m_successfulVersion = DrNew DrOutputPartitionList(); +} + +void DrOutputVertex::DiscardDerived() +{ + m_generator = DrNull; + m_runningVersion = DrNull; + m_successfulVersion = DrNull; +} + +DrVertexRef DrOutputVertex::MakeCopy(int suffix, DrStageManagerPtr stage) +{ + DrOutputVertexRef other = DrNew DrOutputVertex(stage, suffix, m_generator); + other->InitializeFromOther(this, suffix); + + m_generator->AddDynamicSplitVertex(this); + + DrVertexRef vo = other; + return vo; +} + +DrOutputPartition DrOutputVertex::FinalizeVersions() +{ + int i; + + /* abandon all but one successful version */ + for (i=1; iSize(); ++i) + { + m_generator->AbandonVersion(m_partitionIndex, + m_successfulVersion[i].m_id, + m_successfulVersion[i].m_version, + m_successfulVersion[i].m_outputPort, + m_successfulVersion[i].m_resource); + } + + /* abandon all versions that are 'still running' */ + for (i=0; iSize(); ++i) + { + m_generator->AbandonVersion(m_partitionIndex, 
+ m_runningVersion[i].m_id, + m_runningVersion[i].m_version, + m_runningVersion[i].m_outputPort, + m_runningVersion[i].m_resource); + } + + return m_successfulVersion[0]; +} + +DrVertexOutputGeneratorPtr DrOutputVertex::GetOutputGenerator(int edgeIndex, + DrConnectorType /* unused type */, + int /* unused version */) +{ + /* we never run, and we're never upstream of anyone, so this shouldn't get called */ + DrLogA("Output generator called for output vertex %d port %d", m_id, edgeIndex); + return DrNull; +} + +DrString DrOutputVertex::GetURIForWrite(int port, int id, int version, int outputPort, + DrResourcePtr runningResource, DrMetaDataRef metaData) +{ + DrLogI("Output vertex %d(%s) making URI for write %d %d %d %d %d", + m_id, m_name.GetChars(), port, id, version, outputPort, m_partitionIndex); + + int i; + for (i=0; iSize(); ++i) + { + DrAssert(m_runningVersion[i].m_version != version); + } + for (i=0; iSize(); ++i) + { + DrAssert(m_successfulVersion[i].m_version != version); + } + + DrOutputPartition p; + p.m_id = id; + p.m_version = version; + p.m_outputPort = outputPort; + p.m_resource = runningResource; + p.m_size = 0; + m_runningVersion->Add(p); + + return m_generator->GetURIForWrite(m_partitionIndex, id, version, port, runningResource, metaData); +} + +void DrOutputVertex::KickStateMachine() +{ + /* nothing to do */ +} + +void DrOutputVertex::InitializeForGraphExecution() +{ + /* nothing to do */ +} + +void DrOutputVertex::RemoveFromGraphExecution() +{ + /* nothing to do */ +} + +void DrOutputVertex::ReportFinalTopology(DrVertexTopologyReporterPtr reporter) +{ + reporter->ReportFinalTopology(this, DrNull, 0); +} + +void DrOutputVertex::ReactToUpStreamRunningProcess(int /* unused inputPort */, + DrConnectorType /* unused type */, + DrVertexOutputGeneratorPtr /* unused generator */) +{ + /* nothing to do */ +} + +void DrOutputVertex::ReactToUpStreamRunningVertex(int /* unused inputPort */, + DrConnectorType /* unused type */, + DrVertexOutputGeneratorPtr /* 
unused generator */) +{ + /* nothing to do */ +} + +void DrOutputVertex::NotifyUpStreamCompletedVertex(int /*unused inputPort*/, DrConnectorType /*unused type*/, DrVertexOutputGeneratorPtr /*unused generator*/) +{ +} + +void DrOutputVertex::ReactToUpStreamCompletedVertex(int inputPort, DrConnectorType type, + DrVertexOutputGeneratorPtr generator, + DrVertexExecutionStatisticsPtr stats) +{ + DrEdge e = m_inputEdges->GetEdge(inputPort); + DrAssert(e.m_type == type); + + int i; + for (i=0; iSize(); ++i) + { + DrOutputPartition p = m_runningVersion[i]; + if (p.m_version == generator->GetVersion()) + { + p.m_size = stats->m_outputData[e.m_remotePort]->m_dataWritten; + m_successfulVersion->Add(p); + m_runningVersion->RemoveAt(i); + return; + } + } + DrLogA("Output vertex %d couldn't find matching running record for completion version %d", + m_id, generator->GetVersion()); +} + +void DrOutputVertex::ReactToDownStreamFailure(int port, + DrConnectorType /* unused type */, + int /* unused downStreamVersion */) +{ + /* for now we shouldn't have any downstream neighbours */ + DrLogA("Output vertex %d has downstream neighbor on port %d", m_id, port); +} + +void DrOutputVertex::ReactToUpStreamFailure(int /* unused port */, + DrConnectorType /* unused type */, + DrVertexOutputGeneratorPtr generator, + int /* unused downStreamVersion */) +{ + int i; + for (i=0; iSize(); ++i) + { + if (m_runningVersion[i].m_version == generator->GetVersion()) + { + m_runningVersion->RemoveAt(i); + return; + } + } + + /* there's no particular reason to invalidate a successful version we've already got */ +} + +void DrOutputVertex::ReactToFailedVertex(DrVertexOutputGeneratorPtr /* unused failedGenerator */, + DrVertexVersionGeneratorPtr /* unused inputs */, + DrVertexExecutionStatisticsPtr /* unused stats */, + DrVertexProcessStatusPtr /* unused status */, + DrErrorPtr /* unused originalReason */) +{ + /* we should never appear to fail */ + DrLogA("Output vertex %d claiming to fail", m_id); +} + + 
+void DrOutputVertex::CompactPendingVersion(DrEdgeHolderPtr /* unused edgesBeingCompacted*/, int /* unused numberOfEdgesAfterCompaction*/) +{ + // + // Only DrActiveVertex can do this but when this method is called + // we only have a DrVertex so this needs to be part of the + // interface. + // +} + +void DrOutputVertex::CancelAllVersions(DrErrorPtr /* unused reason*/) +{ + // + // Only DrActiveVertex can do this but when this method is called + // we only have a DrVertex so this needs to be part of the + // interface. + // +} + + +DrTeeVertex::DrTeeVertex(DrStageManagerPtr stage) : DrVertex(stage) +{ +} + +void DrTeeVertex::DiscardDerived() +{ + m_generator = DrNull; +} + +DrVertexRef DrTeeVertex::MakeCopy(int suffix, DrStageManagerPtr stage) +{ + DrTeeVertexRef other = DrNew DrTeeVertex(stage); + other->InitializeFromOther(this, suffix); + + DrVertexRef vo = other; + return vo; +} + +DrVertexOutputGeneratorPtr DrTeeVertex::GetOutputGenerator(int /* unused edgeIndex */, + DrConnectorType type, + int /* unused version */) +{ + DrAssert(type == DCT_File); + return m_generator; +} + +DrString DrTeeVertex::GetURIForWrite(int /* unused port */, + int /* unused id */, + int /* unused version */, + int /* unused outputPort */, + DrResourcePtr /* unused runningResource */, + DrMetaDataRef /* unused metaData */) +{ + /* only output vertices can supply this */ + DrLogA("Tee vertex %d asked for write URI", m_id); + return DrString(); +} + +void DrTeeVertex::KickStateMachine() +{ + /* nothing to do */ +} + +void DrTeeVertex::InitializeForGraphExecution() +{ + DrAssert(m_inputEdges->GetNumberOfEdges() == 1); + DrEdge e = m_inputEdges->GetEdge(0); + + /* fill in the generator if it's already there */ + DrVertexOutputGeneratorRef generator = + e.m_remoteVertex->GetOutputGenerator(e.m_remotePort, e.m_type, 0); + if (generator != DrNull) + { + ReactToUpStreamCompletedVertex(0, e.m_type, generator, DrNull); + } +} + +void DrTeeVertex::RemoveFromGraphExecution() +{ + /* nothing to 
do */ +} + +void DrTeeVertex::ReportFinalTopology(DrVertexTopologyReporterPtr reporter) +{ + reporter->ReportFinalTopology(this, DrNull, 0); +} + +void DrTeeVertex::ReactToUpStreamRunningProcess(int /* unused inputPort */, + DrConnectorType /* unused type */, + DrVertexOutputGeneratorPtr /* unused generator */) +{ + /* we don't care */ +} + +void DrTeeVertex::ReactToUpStreamRunningVertex(int /* unused inputPort */, + DrConnectorType /* unused type */, + DrVertexOutputGeneratorPtr /* unused generator */) +{ + /* we don't care */ +} + +void DrTeeVertex::NotifyUpStreamCompletedVertex(int /* unused inputPort*/, DrConnectorType /* unused type*/, DrVertexOutputGeneratorPtr /* unused generator*/) +{ +} + +void DrTeeVertex::ReactToUpStreamCompletedVertex(int inputPort, + DrConnectorType /* unused type */, + DrVertexOutputGeneratorPtr generator, + DrVertexExecutionStatisticsPtr stats) +{ + /* we'll steal their generator to use as our own, but wrap it to return the value for + inputPort 0 on all requests */ + DrAssert(inputPort == 0); + m_generator = DrNew DrTeeVertexOutputGenerator(generator); + + int i; + for (i=0; iGetNumberOfEdges(); ++i) + { + DrEdge e = m_outputEdges->GetEdge(i); + e.m_remoteVertex->NotifyUpStreamCompletedVertex(e.m_remotePort, e.m_type, m_generator); + e.m_remoteVertex->ReactToUpStreamCompletedVertex(e.m_remotePort, e.m_type, m_generator, stats); + } +} + +void DrTeeVertex::ReactToDownStreamFailure(int /* unused port */, + DrConnectorType /* unused type */, + int /* unused downStreamVersion */) +{ + /* we don't care */ +} + +void DrTeeVertex::ReactToUpStreamFailure(int port, DrConnectorType type, + DrVertexOutputGeneratorPtr generator, int /* unused downStreamVersion */) +{ + DrAssert(port == 0); + DrAssert(type == DCT_File); + + if (m_generator && (generator == m_generator->GetWrappedGenerator())) + { + DrVertexOutputGeneratorRef oldGenerator = m_generator; + + /* propagate the fact that this file is unreadable */ + m_generator = DrNull; + + int i; + 
for (i=0; iGetNumberOfEdges(); ++i) + { + DrEdge e = m_outputEdges->GetEdge(i); + e.m_remoteVertex->ReactToUpStreamFailure(e.m_remotePort, e.m_type, oldGenerator, 0); + } + } +} + + +void DrTeeVertex::ReactToFailedVertex(DrVertexOutputGeneratorPtr failedGenerator, + DrVertexVersionGeneratorPtr inputs, + DrVertexExecutionStatisticsPtr stats, DrVertexProcessStatusPtr status, + DrErrorPtr originalReason) +{ + DrAssert(failedGenerator != DrNull); + DrAssert(inputs == DrNull); + DrAssert(stats == DrNull); + DrAssert(status == DrNull); + + /* the upstream vertex has called this because its read failed: propagate it down to the source */ + if (m_generator == failedGenerator) + { + m_generator = DrNull; + } + + DrAssert(m_inputEdges->GetNumberOfEdges() == 1); + + DrEdge e = m_inputEdges->GetEdge(0); + DrLogI("Tee vertex %d.%d: calling ReactToFailedVertex on remote edge", this->m_id, GetVersion()); + e.m_remoteVertex->ReactToFailedVertex(failedGenerator, DrNull, DrNull, DrNull, originalReason); + + /* fill in a new generator if it's already there, e.g. if the upstream vertex is a DrStorageVertex */ + DrVertexOutputGeneratorRef newGenerator = + e.m_remoteVertex->GetOutputGenerator(e.m_remotePort, e.m_type, 0); + if (newGenerator != DrNull) + { + /* don�t call ReactToUpStreamCompletedVertex here because we don�t actually want to start + calling our downstream neighbors with ReactToUpStreamCompletedVertex just at the moment + they called into us with ReactToFailedVertex. */ + m_generator = DrNew DrTeeVertexOutputGenerator(newGenerator); + } +} + + +void DrTeeVertex::CompactPendingVersion(DrEdgeHolderPtr /* unused edgesBeingCompacted*/, int /* unused numberOfEdgesAfterCompaction*/) +{ + // + // Only DrActiveVertex can do this but when this method is called + // we only have a DrVertex so this needs to be part of the + // interface. 
+ // +} + +void DrTeeVertex::CancelAllVersions(DrErrorPtr /* unused reason*/) +{ + // + // Only DrActiveVertex can do this but when this method is called + // we only have a DrVertex so this needs to be part of the + // interface. + // +} diff --git a/GraphManager/vertex/DrVertex.h b/GraphManager/vertex/DrVertex.h new file mode 100644 index 0000000..3ba2a44 --- /dev/null +++ b/GraphManager/vertex/DrVertex.h @@ -0,0 +1,436 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +DRCLASS(DrVertexIdSource) +{ +public: + static int GetNextId(); + +private: + static int s_nextId; +}; + +DRDECLARECLASS(DrVertex); +DRREF(DrVertex); + +typedef DrSet DrActiveVertexSet; +DRREF(DrActiveVertexSet); + +DRDECLARECLASS(DrStageManager); +DRREF(DrStageManager); + +DRINTERFACE(DrVertexTopologyReporter) +{ +public: + virtual void ReportFinalTopology(DrVertexPtr vertex, DrResourcePtr runningMachine, + DrTimeInterval runningTime) = 0; +}; +DRIREF(DrVertexTopologyReporter); + +DRCLASS(DrVertex abstract) : public DrSharedCritSec +{ +public: + DrVertex(DrStageManagerPtr stage); + + void Discard(); + + DrVertexRef MakeCopy(int suffix); + virtual DrVertexRef MakeCopy(int suffix, DrStageManagerPtr stage) = 0; + + DrStageManagerPtr GetStageManager(); + int GetId(); + + void SetName(DrString name); + DrString GetName(); + + DrEdgeHolderPtr GetInputs(); + DrEdgeHolderPtr GetOutputs(); + + void SetNumberOfSubgraphInputs(int numberOfInputs); + int GetNumberOfSubgraphInputs(); + void SetNumberOfSubgraphOutputs(int numberOfOutputs); + int GetNumberOfSubgraphOutputs(); + + DrVertexPtr RemoteInputVertex(int localPort); + int RemoteInputPort(int localPort); + DrVertexPtr RemoteOutputVertex(int localPort); + int RemoteOutputPort(int localPort); + + void ConnectOutput(int localPort, DrVertexPtr remoteVertex, int remotePort, DrConnectorType type); + void DisconnectInput(int localPort, bool disconnectRemote); + void DisconnectOutput(int localPort, bool disconnectRemote); + + virtual void InitializeForGraphExecution() = 0; + virtual void KickStateMachine() = 0; + virtual void RemoveFromGraphExecution() = 0; + virtual void ReportFinalTopology(DrVertexTopologyReporterPtr reporter) = 0; + + virtual DrVertexOutputGeneratorPtr GetOutputGenerator(int edgeIndex, DrConnectorType type, + int version) = 0; + virtual DrString GetURIForWrite(int port, int id, int version, int outputPort, + DrResourcePtr runningResource, DrMetaDataRef metaData) = 0; + + virtual void 
ReactToUpStreamRunningProcess(int inputPort, DrConnectorType type, + DrVertexOutputGeneratorPtr generator) = 0; + virtual void ReactToUpStreamRunningVertex(int inputPort, DrConnectorType type, + DrVertexOutputGeneratorPtr generator) = 0; + virtual void NotifyUpStreamCompletedVertex(int inputPort, DrConnectorType type, DrVertexOutputGeneratorPtr generator) = 0; + virtual void ReactToUpStreamCompletedVertex(int inputPort, DrConnectorType type, + DrVertexOutputGeneratorPtr generator, + DrVertexExecutionStatisticsPtr stats) = 0; + virtual void ReactToDownStreamFailure(int port, DrConnectorType type, + int downStreamVersion) = 0; + virtual void ReactToUpStreamFailure(int port, DrConnectorType type, + DrVertexOutputGeneratorPtr generator, int downStreamVersion) = 0; + virtual void ReactToFailedVertex(DrVertexOutputGeneratorPtr failedGenerator, + DrVertexVersionGeneratorPtr inputs, + DrVertexExecutionStatisticsPtr stats, DrVertexProcessStatusPtr status, + DrErrorPtr originalReason) = 0; + virtual void CompactPendingVersion(DrEdgeHolderPtr edgesBeingCompacted, int numberOfEdgesAfterCompaction) = 0; + + virtual void CancelAllVersions(DrErrorPtr reason) = 0; + +protected: + void InitializeFromOther(DrVertexPtr other, int suffix); + virtual void DiscardDerived() = 0; + + DrStageManagerRef m_stage; + DrString m_name; + int m_id; + + DrEdgeHolderRef m_inputEdges; + DrEdgeHolderRef m_outputEdges; + int m_numberOfSubgraphInputs; + int m_numberOfSubgraphOutputs; +}; + +typedef DrArrayList DrVertexList; +DRAREF(DrVertexList,DrVertexRef); + +typedef DrSet DrVertexSet; +DRREF(DrVertexSet); + +typedef DrDictionary DrVertexVListMap; +DRREF(DrVertexVListMap); + +typedef DrArrayList DrCompletedVertexList; +DRAREF(DrCompletedVertexList,DrActiveVertexOutputGeneratorRef); + +DRDECLARECLASS(DrStartClique); +DRREF(DrStartClique); + +DRDECLARECLASS(DrCohort); +DRREF(DrCohort); + +DRCLASS(DrActiveVertex) : public DrVertex +{ +public: + DrActiveVertex(DrStageManagerPtr stage, 
DrProcessTemplatePtr processTemplate, + DrVertexTemplatePtr vertexTemplate); + + virtual void DiscardDerived() DROVERRIDE; + + virtual DrVertexRef MakeCopy(int suffix, DrStageManagerPtr stage) DROVERRIDE; + + void AddArgument(DrNativeString argument); + void AddArgumentInternal(DrString argument); + + DrAffinityPtr GetAffinity(); + + virtual void InitializeForGraphExecution() DROVERRIDE; + virtual void KickStateMachine() DROVERRIDE; + virtual void RemoveFromGraphExecution() DROVERRIDE; + virtual void ReportFinalTopology(DrVertexTopologyReporterPtr reporter) DROVERRIDE; + + virtual DrVertexOutputGeneratorPtr GetOutputGenerator(int edgeIndex, DrConnectorType type, + int version) DROVERRIDE; + virtual DrString GetURIForWrite(int port, int id, int version, int outputPort, + DrResourcePtr runningResource, DrMetaDataRef metaData) DROVERRIDE; + + virtual void ReactToUpStreamRunningProcess(int inputPort, DrConnectorType type, + DrVertexOutputGeneratorPtr generator) DROVERRIDE; + virtual void ReactToUpStreamRunningVertex(int inputPort, DrConnectorType type, + DrVertexOutputGeneratorPtr generator) DROVERRIDE; + virtual void NotifyUpStreamCompletedVertex(int inputPort, DrConnectorType type, DrVertexOutputGeneratorPtr generator) DROVERRIDE; + virtual void ReactToUpStreamCompletedVertex(int inputPort, DrConnectorType type, + DrVertexOutputGeneratorPtr generator, + DrVertexExecutionStatisticsPtr stats) DROVERRIDE; + virtual void ReactToDownStreamFailure(int port, DrConnectorType type, + int downStreamVersion) DROVERRIDE; + virtual void ReactToUpStreamFailure(int port, DrConnectorType type, + DrVertexOutputGeneratorPtr generator, + int downStreamVersion) DROVERRIDE; + virtual void ReactToFailedVertex(DrVertexOutputGeneratorPtr failedGenerator, + DrVertexVersionGeneratorPtr inputs, + DrVertexExecutionStatisticsPtr stats, DrVertexProcessStatusPtr status, + DrErrorPtr originalReason) DROVERRIDE; + virtual void CompactPendingVersion(DrEdgeHolderPtr edgesBeingCompacted, int 
numberOfEdgesAfterCompaction) DROVERRIDE; + + virtual void CancelAllVersions(DrErrorPtr reason) DROVERRIDE; + + DrVertexCommandBlockRef MakeVertexStartCommand(DrVertexVersionGeneratorPtr pending, + DrResourcePtr runningResource); + + void RequestDuplicate(int versionToDuplicate); + + void NotifyVertexIsReady(); + bool HasPendingVersion(); + bool HasRunningVersion(int i); + bool HasCompletedVersion(int i); + void GrowPendingVersion(int numberOfEdgesToGrow); + void InstantiateVersion(int version); + void AddCurrentAffinitiesToList(int version, DrAffinityListPtr list); + void StartProcess(int version); + void CheckForProcessAlreadyStarted(int version); + void ReactToStartedProcess(int version, DrLockBox process); + void ReactToStartedVertex(DrVertexRecordPtr record, DrVertexExecutionStatisticsPtr stats); + void ReactToRunningVertexUpdate(DrVertexRecordPtr record, + HRESULT exitStatus, DrVertexProcessStatusPtr status); + void ReactToCompletedVertex(DrVertexRecordPtr record, DrVertexExecutionStatisticsPtr stats); + void CancelVersion(int version, DrErrorPtr error, DrCohortProcessPtr cohortProcess); + + int GetNumberOfReportedCompletions(); + + virtual DrString GetDescription(); + + void SetStartClique(DrStartCliquePtr cohort); + virtual DrStartCliquePtr GetStartClique(); + + void SetCohort(DrCohortPtr cohort); + virtual DrCohortPtr GetCohort(); + +private: + DrVertexRecordPtr GetRunningVersion(int version); + + DrCohortRef m_cohort; + DrStartCliqueRef m_startClique; + DrVertexTemplateRef m_vertexTemplate; + DrProcessTemplateRef m_processTemplate; + DrAffinityRef m_affinity; + + DrStringListRef m_argument; + DrUINT64ArrayRef m_outputSizeHint; + UINT64 m_totalOutputSizeHint; + int m_maxOpenInputChannelCount; + int m_maxOpenOutputChannelCount; + + int m_numberOfReportedCompletions; + bool m_registeredWithGraph; + + DrVertexVersionGeneratorRef m_pendingVersion; + DrVertexRecordListRef m_runningVertex; + DrActiveVertexOutputGeneratorRef m_completedRecord; + 
DrCompletedVertexListRef m_spareCompletedRecord; +}; + +typedef DrArrayList DrActiveVertexList; +DRAREF(DrActiveVertexList,DrActiveVertexRef); + + +DRCLASS(DrStorageVertex) : public DrVertex +{ +public: + DrStorageVertex(DrStageManagerPtr stage, int partitionIndex, DrIInputPartitionReaderPtr reader); + + virtual void DiscardDerived() DROVERRIDE; + + virtual DrVertexRef MakeCopy(int suffix, DrStageManagerPtr stage) DROVERRIDE; + + virtual void InitializeForGraphExecution() DROVERRIDE; + virtual void KickStateMachine() DROVERRIDE; + virtual void RemoveFromGraphExecution() DROVERRIDE; + virtual void ReportFinalTopology(DrVertexTopologyReporterPtr reporter) DROVERRIDE; + + virtual DrVertexOutputGeneratorPtr GetOutputGenerator(int edgeIndex, DrConnectorType type, + int version) DROVERRIDE; + virtual DrString GetURIForWrite(int port, int id, int version, int outputPort, + DrResourcePtr runningResource, DrMetaDataRef metaData) DROVERRIDE; + + virtual void ReactToUpStreamRunningProcess(int inputPort, DrConnectorType type, + DrVertexOutputGeneratorPtr generator) DROVERRIDE; + virtual void ReactToUpStreamRunningVertex(int inputPort, DrConnectorType type, + DrVertexOutputGeneratorPtr generator) DROVERRIDE; + virtual void NotifyUpStreamCompletedVertex(int inputPort, DrConnectorType type, DrVertexOutputGeneratorPtr generator) DROVERRIDE; + virtual void ReactToUpStreamCompletedVertex(int inputPort, DrConnectorType type, + DrVertexOutputGeneratorPtr generator, + DrVertexExecutionStatisticsPtr stats) DROVERRIDE; + virtual void ReactToDownStreamFailure(int port, DrConnectorType type, + int downStreamVersion) DROVERRIDE; + virtual void ReactToUpStreamFailure(int port, DrConnectorType type, + DrVertexOutputGeneratorPtr generator, + int downStreamVersion) DROVERRIDE; + virtual void ReactToFailedVertex(DrVertexOutputGeneratorPtr failedGenerator, + DrVertexVersionGeneratorPtr inputs, + DrVertexExecutionStatisticsPtr stats, DrVertexProcessStatusPtr status, + DrErrorPtr originalReason) 
DROVERRIDE; + virtual void CompactPendingVersion(DrEdgeHolderPtr edgesBeingCompacted, int numberOfEdgesAfterCompaction) DROVERRIDE; + + virtual void CancelAllVersions(DrErrorPtr reason) DROVERRIDE; + +private: + DrStorageVertex(DrStageManagerPtr stage, DrStorageVertexOutputGeneratorPtr generator); + + DrStorageVertexOutputGeneratorRef m_generator; + bool m_registeredInputReady; +}; +DRREF(DrStorageVertex); + +typedef DrArrayList DrStorageVertexList; +DRAREF(DrStorageVertexList, DrStorageVertexRef); + + +DRDECLARECLASS(DrOutputVertex); +DRREF(DrOutputVertex); + +DRDECLAREVALUECLASS(DrOutputPartition); +DRRREF(DrOutputPartition); + +DRVALUECLASS(DrOutputPartition) +{ +public: + bool operator==(DrOutputPartitionR other); + + int m_id; + int m_version; + int m_outputPort; + DrResourceRef m_resource; + UINT64 m_size; +}; + +DRMAKEARRAY(DrOutputPartition); +DRMAKEARRAYLIST(DrOutputPartition); + + +DRINTERFACE(DrIOutputPartitionGenerator) +{ +public: + virtual void AddDynamicSplitVertex(DrOutputVertexPtr newVertex) DRABSTRACT; + virtual HRESULT FinalizeSuccessfulPartitions() DRABSTRACT; + virtual DrString GetURIForWrite(int partitionIndex, int id, int version, int outputPort, + DrResourcePtr runningResource, DrMetaDataRef metaData) DRABSTRACT; + virtual void AbandonVersion(int partitionIndex, int id, int version, int outputPort, + DrResourcePtr runningResource) DRABSTRACT; + virtual void ExtendLease(DrTimeInterval) DRABSTRACT; +}; +DRIREF(DrIOutputPartitionGenerator); + +typedef DrArrayList DrPartitionGeneratorList; +DRAREF(DrPartitionGeneratorList,DrIOutputPartitionGeneratorIRef); + + +DRCLASS(DrOutputVertex) : public DrVertex +{ +public: + DrOutputVertex(DrStageManagerPtr stage, int partitionIndex, DrIOutputPartitionGeneratorPtr generator); + + virtual void DiscardDerived() DROVERRIDE; + + virtual DrVertexRef MakeCopy(int suffix, DrStageManagerPtr stage) DROVERRIDE; + + DrOutputPartition FinalizeVersions(); + + virtual void InitializeForGraphExecution() DROVERRIDE; + 
virtual void KickStateMachine() DROVERRIDE; + virtual void RemoveFromGraphExecution() DROVERRIDE; + virtual void ReportFinalTopology(DrVertexTopologyReporterPtr reporter) DROVERRIDE; + + virtual DrVertexOutputGeneratorPtr GetOutputGenerator(int edgeIndex, DrConnectorType type, + int version) DROVERRIDE; + virtual DrString GetURIForWrite(int port, int id, int version, int outputPort, + DrResourcePtr runningResource, DrMetaDataRef metaData) DROVERRIDE; + + virtual void ReactToUpStreamRunningProcess(int inputPort, DrConnectorType type, + DrVertexOutputGeneratorPtr generator) DROVERRIDE; + virtual void ReactToUpStreamRunningVertex(int inputPort, DrConnectorType type, + DrVertexOutputGeneratorPtr generator) DROVERRIDE; + virtual void NotifyUpStreamCompletedVertex(int inputPort, DrConnectorType type, DrVertexOutputGeneratorPtr generator) DROVERRIDE; + virtual void ReactToUpStreamCompletedVertex(int inputPort, DrConnectorType type, + DrVertexOutputGeneratorPtr generator, + DrVertexExecutionStatisticsPtr stats) DROVERRIDE; + virtual void ReactToDownStreamFailure(int port, DrConnectorType type, + int downStreamVersion) DROVERRIDE; + virtual void ReactToUpStreamFailure(int port, DrConnectorType type, + DrVertexOutputGeneratorPtr generator, + int downStreamVersion) DROVERRIDE; + virtual void ReactToFailedVertex(DrVertexOutputGeneratorPtr failedGenerator, + DrVertexVersionGeneratorPtr inputs, + DrVertexExecutionStatisticsPtr stats, DrVertexProcessStatusPtr status, + DrErrorPtr originalReason) DROVERRIDE; + virtual void CompactPendingVersion(DrEdgeHolderPtr edgesBeingCompacted, int numberOfEdgesAfterCompaction) DROVERRIDE; + + virtual void CancelAllVersions(DrErrorPtr reason) DROVERRIDE; + +private: + DrIOutputPartitionGeneratorIRef m_generator; + int m_partitionIndex; + DrOutputPartitionListRef m_runningVersion; + DrOutputPartitionListRef m_successfulVersion; +}; + +typedef DrArrayList DrOutputVertexList; +DRAREF(DrOutputVertexList,DrOutputVertexRef); + + +DRCLASS(DrTeeVertex) 
: public DrVertex +{ +public: + DrTeeVertex(DrStageManagerPtr stage); + + virtual void DiscardDerived() DROVERRIDE; + + virtual DrVertexRef MakeCopy(int suffix, DrStageManagerPtr stage) DROVERRIDE; + + virtual void InitializeForGraphExecution() DROVERRIDE; + virtual void KickStateMachine() DROVERRIDE; + virtual void RemoveFromGraphExecution() DROVERRIDE; + virtual void ReportFinalTopology(DrVertexTopologyReporterPtr reporter) DROVERRIDE; + + virtual DrVertexOutputGeneratorPtr GetOutputGenerator(int edgeIndex, DrConnectorType type, + int version) DROVERRIDE; + virtual DrString GetURIForWrite(int port, int id, int version, int outputPort, + DrResourcePtr runningResource, DrMetaDataRef metaData) DROVERRIDE; + + virtual void ReactToUpStreamRunningProcess(int inputPort, DrConnectorType type, + DrVertexOutputGeneratorPtr generator) DROVERRIDE; + virtual void ReactToUpStreamRunningVertex(int inputPort, DrConnectorType type, + DrVertexOutputGeneratorPtr generator) DROVERRIDE; + virtual void NotifyUpStreamCompletedVertex(int inputPort, DrConnectorType type, DrVertexOutputGeneratorPtr generator) DROVERRIDE; + virtual void ReactToUpStreamCompletedVertex(int inputPort, DrConnectorType type, + DrVertexOutputGeneratorPtr generator, + DrVertexExecutionStatisticsPtr stats) DROVERRIDE; + virtual void ReactToDownStreamFailure(int port, DrConnectorType type, + int downStreamVersion) DROVERRIDE; + virtual void ReactToUpStreamFailure(int port, DrConnectorType type, + DrVertexOutputGeneratorPtr generator, + int downStreamVersion) DROVERRIDE; + virtual void ReactToFailedVertex(DrVertexOutputGeneratorPtr failedGenerator, + DrVertexVersionGeneratorPtr inputs, + DrVertexExecutionStatisticsPtr stats, DrVertexProcessStatusPtr status, + DrErrorPtr originalReason) DROVERRIDE; + virtual void CompactPendingVersion(DrEdgeHolderPtr edgesBeingCompacted, int numberOfEdgesAfterCompaction) DROVERRIDE; + + virtual void CancelAllVersions(DrErrorPtr reason) DROVERRIDE; + +private: + 
DrTeeVertexOutputGeneratorRef m_generator; +}; +DRREF(DrTeeVertex); diff --git a/GraphManager/vertex/DrVertexCommand.cpp b/GraphManager/vertex/DrVertexCommand.cpp new file mode 100644 index 0000000..4db4346 --- /dev/null +++ b/GraphManager/vertex/DrVertexCommand.cpp @@ -0,0 +1,871 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include + +static const char* s_StatusPropertyLabel = "DVertexStatus"; +static const char* s_CommandPropertyLabel = "DVertexCommand"; + +DrChannelDescription::DrChannelDescription(bool isInputChannel) +{ + m_state = DrError_ChannelAbort; + m_totalLength = 0; + m_processedLength = 0; + m_isInputChannel = isInputChannel; +} + +DrChannelDescription::~DrChannelDescription() +{ +} + +HRESULT DrChannelDescription::GetChannelState() +{ + return m_state; +} + +void DrChannelDescription::SetChannelState(HRESULT state) +{ + m_state = state; +} + +DrString DrChannelDescription::GetChannelURI() +{ + return m_URI; +} + +void DrChannelDescription::SetChannelURI(DrString channelURI) +{ + m_URI = channelURI; +} + +DrMetaDataPtr DrChannelDescription::GetChannelMetaData() +{ + return m_metaData; +} + +void DrChannelDescription::SetChannelMetaData(DrMetaDataPtr metaData) +{ + m_metaData = metaData; +} + +UINT64 DrChannelDescription::GetChannelTotalLength() +{ + return m_totalLength; +} + +void 
DrChannelDescription::SetChannelTotalLength(UINT64 totalLength) +{ + m_totalLength = totalLength; +} + +UINT64 DrChannelDescription::GetChannelProcessedLength() +{ + return m_processedLength; +} + +void DrChannelDescription::SetChannelProcessedLength(UINT64 processedLength) +{ + m_processedLength = processedLength; +} + +void DrChannelDescription::Serialize(DrPropertyWriterPtr writer) +{ + UINT16 tagValue = (m_isInputChannel) ? + DrTag_InputChannelDescription : + DrTag_OutputChannelDescription; + + writer->WriteProperty(DrProp_BeginTag, tagValue); + + writer->WriteProperty(DrProp_ChannelState, m_state); + writer->WriteProperty(DrProp_ChannelURI, m_URI); + writer->WriteProperty(DrProp_ChannelTotalLength, m_totalLength); + writer->WriteProperty(DrProp_ChannelProcessedLength, m_processedLength); + if (m_metaData != DrNull) + { + writer->WriteProperty(DrProp_BeginTag, DrTag_ChannelMetaData); + m_metaData->Serialize(writer); + writer->WriteProperty(DrProp_EndTag, DrTag_ChannelMetaData); + } + + writer->WriteProperty(DrProp_EndTag, tagValue); +} + +HRESULT DrChannelDescription::ParseProperty(DrPropertyReaderPtr reader, + UINT16 enumID, UINT32 /* unused dataLen */) +{ + HRESULT err; + + switch (enumID) + { + default: + DrLogW("Unknown property in channel description enumID %u", (UINT32) enumID); + err = reader->SkipNextPropertyOrAggregate(); + break; + + case DrProp_ChannelState: + err = reader->ReadNextProperty(enumID, m_state); + break; + + case DrProp_ChannelURI: + { + DrString URI; + err = reader->ReadNextProperty(enumID, URI); + if (err == S_OK) + { + SetChannelURI(URI); + } + } + break; + + case DrProp_ChannelTotalLength: + err = reader->ReadNextProperty(enumID, m_totalLength); + break; + + case DrProp_ChannelProcessedLength: + err = reader->ReadNextProperty(enumID, m_processedLength); + break; + + case DrProp_BeginTag: + { + UINT16 tagID; + err = reader->PeekNextAggregateTag(&tagID); + if (err == S_OK) + { + if (tagID == DrTag_ChannelMetaData) + { + DrMetaDataRef 
mData = DrNew DrMetaData(); + err = reader->ReadAggregate(tagID, mData); + if (err == S_OK) + { + SetChannelMetaData(mData); + } + } + else + { + DrLogW("Unknown aggregate in channel description tagID %u", (UINT32) tagID); + } + } + } + break; + } + + return err; +} + +void DrChannelDescription::CopyFrom(DrChannelDescriptionPtr src, bool includeLengths) +{ + DrAssert(m_isInputChannel == src->m_isInputChannel); + + SetChannelURI(src->GetChannelURI()); + SetChannelState(src->GetChannelState()); + SetChannelMetaData(src->GetChannelMetaData()); + if (includeLengths) + { + SetChannelProcessedLength(src->GetChannelProcessedLength()); + SetChannelTotalLength(src->GetChannelTotalLength()); + } +} + + +DrInputChannelDescription::DrInputChannelDescription() : + DrChannelDescription(true) +{ +} + + +DrOutputChannelDescription::DrOutputChannelDescription() : + DrChannelDescription(false) +{ +} + + +DrVertexProcessStatus::DrVertexProcessStatus() +{ + m_id = 0; + m_version = 0; + m_maxInputChannels = 0; + m_maxOutputChannels = 0; + m_canShareWorkQueue = false; + + m_nextInputChannelToRead = 0; + m_nextOutputChannelToRead = 0; +} + +int DrVertexProcessStatus::GetVertexId() +{ + return m_id; +} + +void DrVertexProcessStatus::SetVertexId(int vertexId) +{ + m_id = vertexId; +} + +int DrVertexProcessStatus::GetVertexInstanceVersion() +{ + return m_version; +} + +void DrVertexProcessStatus::SetVertexInstanceVersion(int version) +{ + m_version = version; +} + +DrMetaDataPtr DrVertexProcessStatus::GetVertexMetaData() +{ + return m_metaData; +} + +void DrVertexProcessStatus::SetVertexMetaData(DrMetaDataPtr metaData) +{ + m_metaData = metaData; +} + +DrInputChannelArrayRef DrVertexProcessStatus::GetInputChannels() +{ + return m_inputChannel; +} + +void DrVertexProcessStatus::SetInputChannelCount(int nInputChannels) +{ + DrAssert(nInputChannels >= 0); + m_inputChannel = DrNew DrInputChannelArray(nInputChannels); + int i; + for (i=0; i= 0); + m_maxInputChannels = channelCount; +} + 
+DrOutputChannelArrayRef DrVertexProcessStatus::GetOutputChannels() +{ + return m_outputChannel; +} + +void DrVertexProcessStatus::SetOutputChannelCount(int nOutputChannels) +{ + DrAssert(nOutputChannels >= 0); + m_outputChannel = DrNew DrOutputChannelArray(nOutputChannels); + int i; + for (i=0; i= 0); + m_maxOutputChannels = channelCount; +} + +bool DrVertexProcessStatus::GetCanShareWorkQueue() +{ + return m_canShareWorkQueue; +} + +void DrVertexProcessStatus::SetCanShareWorkQueue(bool canShareWorkQueue) +{ + m_canShareWorkQueue = canShareWorkQueue; +} + +void DrVertexProcessStatus::Serialize(DrPropertyWriterPtr writer) +{ + int i; + + writer->WriteProperty(DrProp_BeginTag, DrTag_VertexProcessStatus); + + writer->WriteProperty(DrProp_VertexId, (UINT32) m_id); + writer->WriteProperty(DrProp_VertexVersion, (UINT32) m_version); + if (m_metaData != DrNull) + { + writer->WriteProperty(DrProp_BeginTag, DrTag_VertexMetaData); + m_metaData->Serialize(writer); + writer->WriteProperty(DrProp_EndTag, DrTag_VertexMetaData); + } + + if (m_inputChannel == DrNull) + { + writer->WriteProperty(DrProp_VertexInputChannelCount, (UINT32) 0); + } + else + { + writer->WriteProperty(DrProp_VertexInputChannelCount, (UINT32) m_inputChannel->Allocated()); + for (i=0; iAllocated(); ++i) + { + m_inputChannel[i]->Serialize(writer); + } + } + writer->WriteProperty(DrProp_VertexMaxOpenInputChannelCount, (UINT32) m_maxInputChannels); + + if (m_outputChannel == DrNull) + { + writer->WriteProperty(DrProp_VertexOutputChannelCount, (UINT32) 0); + } + else + { + writer->WriteProperty(DrProp_VertexOutputChannelCount, (UINT32) m_outputChannel->Allocated()); + for (i=0; iAllocated(); ++i) + { + m_outputChannel[i]->Serialize(writer); + } + } + writer->WriteProperty(DrProp_VertexMaxOpenOutputChannelCount, (UINT32) m_maxOutputChannels); + + writer->WriteProperty(DrProp_CanShareWorkQueue, m_canShareWorkQueue); + + writer->WriteProperty(DrProp_EndTag, DrTag_VertexProcessStatus); +} + +HRESULT 
DrVertexProcessStatus::ParseProperty(DrPropertyReaderPtr reader, + UINT16 enumID, UINT32 /* unused dataLen */) +{ + HRESULT err; + + switch (enumID) + { + default: + DrLogW("Unknown property in vertex status message enumID %u", (UINT32) enumID); + err = reader->SkipNextPropertyOrAggregate(); + break; + + case DrProp_VertexId: + UINT32 id; + err = reader->ReadNextProperty(enumID, id); + if (err == S_OK) + { + if (id >= 0x80000000) + { + DrLogW("Vertex ID out of range %u", id); + err = HRESULT_FROM_WIN32(ERROR_INVALID_PARAMETER); + } + else + { + m_id = id; + } + } + break; + + case DrProp_VertexVersion: + UINT32 version; + err = reader->ReadNextProperty(enumID, version); + if (err == S_OK) + { + if (version >= 0x80000000) + { + DrLogW("Vertex version out of range %u", version); + err = HRESULT_FROM_WIN32(ERROR_INVALID_PARAMETER); + } + else + { + m_version = version; + } + } + break; + + case DrProp_VertexInputChannelCount: + UINT32 nInputChannels; + err = reader->ReadNextProperty(enumID, nInputChannels); + if (err == S_OK) + { + if (nInputChannels >= 0x80000000) + { + DrLogW("Too many input channels %u", nInputChannels); + err = HRESULT_FROM_WIN32(ERROR_INVALID_PARAMETER); + } + else + { + SetInputChannelCount((int) nInputChannels); + } + } + break; + + case DrProp_VertexMaxOpenInputChannelCount: + UINT32 maxInputChannels; + err = reader->ReadNextProperty(enumID, maxInputChannels); + if (err == S_OK) + { + if (maxInputChannels >= 0x80000000) + { + DrLogW("Too many max input channels %u", maxInputChannels); + err = HRESULT_FROM_WIN32(ERROR_INVALID_PARAMETER); + } + else + { + m_maxInputChannels = (int) maxInputChannels; + } + } + break; + + case DrProp_VertexOutputChannelCount: + UINT32 nOutputChannels; + err = reader->ReadNextProperty(enumID, nOutputChannels); + if (err == S_OK) + { + if (nOutputChannels >= 0x80000000) + { + DrLogW("Too many output channels %u", nOutputChannels); + err = HRESULT_FROM_WIN32(ERROR_INVALID_PARAMETER); + } + else + { + 
SetOutputChannelCount((int) nOutputChannels); + } + } + break; + + case DrProp_VertexMaxOpenOutputChannelCount: + UINT32 maxOutputChannels; + err = reader->ReadNextProperty(enumID, maxOutputChannels); + if (err == S_OK) + { + if (maxOutputChannels >= 0x80000000) + { + DrLogW("Too many max output channels %d", maxOutputChannels); + err = HRESULT_FROM_WIN32(ERROR_INVALID_PARAMETER); + } + else + { + m_maxOutputChannels = (int) maxOutputChannels; + } + } + break; + + case DrProp_CanShareWorkQueue: + err = reader->ReadNextProperty(enumID, m_canShareWorkQueue); + break; + + case DrProp_BeginTag: + UINT16 tagValue; + err = reader->PeekNextAggregateTag(&tagValue); + if (err != S_OK) + { + DrLogW("Error reading DrProp_BeginTag %d", err); + } + else + { + switch (tagValue) + { + case DrTag_InputChannelDescription: + if (m_nextInputChannelToRead >= m_inputChannel->Allocated()) + { + DrLogW("Too many input channel descriptions nextInputChannelToRead=%d, nInputChannels=%d", + m_nextInputChannelToRead, m_inputChannel->Allocated()); + err = HRESULT_FROM_WIN32(ERROR_INVALID_PARAMETER); + } + else + { + err = reader->ReadAggregate(tagValue, m_inputChannel[m_nextInputChannelToRead]); + if (err == S_OK) + { + ++m_nextInputChannelToRead; + } + } + break; + + case DrTag_OutputChannelDescription: + if (m_nextOutputChannelToRead >= m_outputChannel->Allocated()) + { + DrLogW("Too many output channel descriptions nextOutputChannelToRead=%d, nOutputChannels=%d", + m_nextOutputChannelToRead, m_outputChannel->Allocated()); + err = HRESULT_FROM_WIN32(ERROR_INVALID_PARAMETER); + } + else + { + err = reader->ReadAggregate(tagValue, m_outputChannel[m_nextOutputChannelToRead]); + if (err == S_OK) + { + ++m_nextOutputChannelToRead; + } + } + break; + + case DrTag_VertexMetaData: + { + DrMetaDataRef metaData = DrNew DrMetaData(); + err = reader->ReadAggregate(tagValue, metaData); + if (err == S_OK) + { + m_metaData = metaData; + } + } + break; + + default: + DrLogW("Unexpected tag %d", tagValue); + 
err = reader->SkipNextPropertyOrAggregate(); + } + } + break; + } + + return err; +} + +void DrVertexProcessStatus::CopyFrom(DrVertexProcessStatusPtr src, + bool includeLengths) +{ + int i; + + SetVertexId(src->GetVertexId()); + SetVertexInstanceVersion(src->GetVertexInstanceVersion()); + SetVertexMetaData(src->GetVertexMetaData()); + + SetInputChannelCount(src->GetInputChannels()->Allocated()); + DrInputChannelArrayRef srcInputs = src->GetInputChannels(); + for (i=0; iAllocated(); ++i) + { + m_inputChannel[i]->CopyFrom(srcInputs[i], includeLengths); + } + SetMaxOpenInputChannelCount(src->GetMaxOpenInputChannelCount()); + + SetOutputChannelCount(src->GetOutputChannels()->Allocated()); + DrOutputChannelArrayRef srcOutputs = src->GetOutputChannels(); + for (i=0; iAllocated(); ++i) + { + m_outputChannel[i]->CopyFrom(srcOutputs[i], includeLengths); + } + SetMaxOpenOutputChannelCount(src->GetMaxOpenOutputChannelCount()); +} + + +DrVertexStatus::DrVertexStatus() +{ + m_state = S_OK; + m_processStatus = DrNew DrVertexProcessStatus(); +} + +HRESULT DrVertexStatus::GetVertexState() +{ + return m_state; +} + +void DrVertexStatus::SetVertexState(HRESULT state) +{ + m_state = state; +} + +DrVertexProcessStatusPtr DrVertexStatus::GetProcessStatus() +{ + return m_processStatus; +} + +void DrVertexStatus::SetProcessStatus(DrVertexProcessStatusPtr processStatus) +{ + m_processStatus = processStatus; +} + +void DrVertexStatus::Serialize(DrPropertyWriterPtr writer) +{ + writer->WriteProperty(DrProp_BeginTag, DrTag_VertexStatus); + + writer->WriteProperty(DrProp_VertexState, m_state); + m_processStatus->Serialize(writer); + + writer->WriteProperty(DrProp_EndTag, DrTag_VertexStatus); +} + +HRESULT DrVertexStatus::ParseProperty(DrPropertyReaderPtr reader, UINT16 enumID, + UINT32 /* unused dataLen */) +{ + HRESULT err; + + switch (enumID) + { + default: + DrLogW("Unknown property in vertex status message enumID %u", (UINT32) enumID); + err = reader->SkipNextPropertyOrAggregate(); + 
break; + + case DrProp_VertexState: + err = reader->ReadNextProperty(enumID, m_state); + break; + + case DrProp_BeginTag: + UINT16 tagValue; + err = reader->PeekNextAggregateTag(&tagValue); + if (err != S_OK) + { + DrLogW("Error reading DrProp_BeginTag %d", err); + } + else + { + switch (tagValue) + { + case DrTag_VertexProcessStatus: + err = reader->ReadAggregate(tagValue, m_processStatus); + break; + + default: + DrLogW("Unexpected tag %d", tagValue); + err = reader->SkipNextPropertyOrAggregate(); + } + } + break; + } + + return err; +} + +DrString DrVertexStatus::GetPropertyLabel(int vertexId, int vertexVersion) +{ + DrString s; + s.SetF("%s-%d.%d", s_StatusPropertyLabel, vertexId, vertexVersion); + return s; +} + + +DrVertexCommandBlock::DrVertexCommandBlock() +{ + m_command = DrVC_Terminate; + m_processStatus = DrNew DrVertexProcessStatus(); + m_setBreakpointOnCommandArrival = false; + m_nextArgumentToRead = 0; +} + +DrVertexCommand DrVertexCommandBlock::GetVertexCommand() +{ + return m_command; +} + +void DrVertexCommandBlock::SetVertexCommand(DrVertexCommand command) +{ + m_command = command; +} + +DrVertexProcessStatusPtr DrVertexCommandBlock::GetProcessStatus() +{ + return m_processStatus; +} + +void DrVertexCommandBlock::SetProcessStatus(DrVertexProcessStatusPtr processStatus) +{ + m_processStatus = processStatus; +} + +void DrVertexCommandBlock::SetArgumentCount(int nArguments) +{ + m_argument = DrNew DrStringArray(nArguments); + m_nextArgumentToRead = 0; +} + +DrStringArrayRef DrVertexCommandBlock::GetArgumentVector() +{ + return m_argument; +} + +DrByteArrayPtr DrVertexCommandBlock::GetRawSerializedBlock() +{ + return m_serializedBlock; +} + +void DrVertexCommandBlock::SetRawSerializedBlock(DrByteArrayPtr block) +{ + m_serializedBlock = block; +} + +void DrVertexCommandBlock::SetDebugBreak(bool setBreakpointOnCommandArrival) +{ + m_setBreakpointOnCommandArrival = setBreakpointOnCommandArrival; +} + +bool DrVertexCommandBlock::GetDebugBreak() +{ + 
return m_setBreakpointOnCommandArrival; +} + +void DrVertexCommandBlock::Serialize(DrPropertyWriterPtr writer) +{ + int i; + + writer->WriteProperty(DrProp_BeginTag, DrTag_VertexCommand); + + writer->WriteProperty(DrProp_VertexCommand, (UINT32) m_command); + + m_processStatus->Serialize(writer); + + if (m_argument == DrNull) + { + writer->WriteProperty(DrProp_VertexArgumentCount, 0); + } + else + { + writer->WriteProperty(DrProp_VertexArgumentCount, (UINT32) m_argument->Allocated()); + for (i=0; iAllocated(); ++i) + { + writer->WriteProperty(DrProp_VertexArgument, m_argument[i]); + } + } + + if (m_serializedBlock != DrNull) + { + DRPIN(BYTE) block = &(m_serializedBlock[0]); + writer->WriteProperty(DrProp_VertexSerializedBlock, m_serializedBlock->Allocated(), block); + } + + writer->WriteProperty(DrProp_DebugBreak, m_setBreakpointOnCommandArrival); + + writer->WriteProperty(DrProp_EndTag, DrTag_VertexCommand); +} + +HRESULT DrVertexCommandBlock::ParseProperty(DrPropertyReaderPtr reader, UINT16 enumID, + UINT32 /* unused dataLen */) +{ + HRESULT err; + + switch (enumID) + { + default: + DrLogW("Unknown property in vertex command message enumID %u", (UINT32) enumID); + err = reader->SkipNextPropertyOrAggregate(); + break; + + case DrProp_VertexCommand: + UINT32 marshaledCommand; + err = reader->ReadNextProperty(enumID, marshaledCommand); + if (err == S_OK) + { + if (marshaledCommand < DrVC_Max) + { + m_command = (DrVertexCommand) marshaledCommand; + } + else + { + DrLogW("Unknown vertex command %u", marshaledCommand); + err = HRESULT_FROM_WIN32(ERROR_INVALID_PARAMETER); + } + } + break; + + case DrProp_VertexArgumentCount: + UINT32 nArguments; + err = reader->ReadNextProperty(enumID, nArguments); + if (err == S_OK) + { + if (nArguments < 0x80000000) + { + SetArgumentCount((int) nArguments); + } + else + { + DrLogW("Too large argument count %u", nArguments); + err = HRESULT_FROM_WIN32(ERROR_INVALID_PARAMETER); + } + } + break; + + case DrProp_VertexArgument: + if 
(m_nextArgumentToRead >= m_argument->Allocated()) + { + DrLogW("Too many arguments nextArgumentToRead=%d, nArguments=%d", + m_nextArgumentToRead, m_argument->Allocated()); + err = HRESULT_FROM_WIN32(ERROR_INVALID_PARAMETER); + } + else + { + DrString arg; + err = reader->ReadNextProperty(enumID, arg); + if (err == S_OK) + { + m_argument[m_nextArgumentToRead] = arg; + ++m_nextArgumentToRead; + } + } + break; + + case DrProp_VertexSerializedBlock: + UINT32 blockLength; + err = reader->PeekNextPropertyTag(&enumID, &blockLength); + if (err == S_OK) + { + if (blockLength < 0x80000000) + { + DrByteArrayRef block = DrNew DrByteArray((int) blockLength); + { + DRPIN(BYTE) data = &(block[0]); + err = reader->ReadNextProperty(enumID, (UINT32) blockLength, data); + } + if (err == S_OK) + { + m_serializedBlock = block; + } + } + else + { + DrLogW("Block too large %u", blockLength); + err = HRESULT_FROM_WIN32(ERROR_INVALID_PARAMETER); + } + } + break; + + case DrProp_DebugBreak: + err = reader->ReadNextProperty(enumID, m_setBreakpointOnCommandArrival); + break; + + case DrProp_BeginTag: + UINT16 tagValue; + err = reader->PeekNextAggregateTag(&tagValue); + if (err != S_OK) + { + DrLogW("Error reading DrProp_BeginTag %d", err); + } + else + { + switch (tagValue) + { + case DrTag_VertexProcessStatus: + err = reader->ReadAggregate(tagValue, m_processStatus); + break; + + default: + DrLogW("Unexpected tag %d", tagValue); + err = reader->SkipNextPropertyOrAggregate(); + } + } + break; + } + + return err; +} + +DrString DrVertexCommandBlock::GetPropertyLabel(int vertexId, int vertexVersion) +{ + DrString s; + s.SetF("%s-%d.%d", s_CommandPropertyLabel, vertexId, vertexVersion); + return s; +} diff --git a/GraphManager/vertex/DrVertexCommand.h b/GraphManager/vertex/DrVertexCommand.h new file mode 100644 index 0000000..d85fd37 --- /dev/null +++ b/GraphManager/vertex/DrVertexCommand.h @@ -0,0 +1,197 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +DRDECLARECLASS(DrChannelDescription); +DRREF(DrChannelDescription); + +DRBASECLASS(DrChannelDescription), public DrPropertyParser +{ +public: + DrChannelDescription(bool isInputChannel); + virtual ~DrChannelDescription(); + + HRESULT GetChannelState(); + void SetChannelState(HRESULT state); + + DrString GetChannelURI(); + void SetChannelURI(DrString uri); + + DrMetaDataPtr GetChannelMetaData(); + void SetChannelMetaData(DrMetaDataPtr metaData); + + UINT64 GetChannelTotalLength(); + void SetChannelTotalLength(UINT64 totalLength); + + UINT64 GetChannelProcessedLength(); + void SetChannelProcessedLength(UINT64 processedLength); + + void Serialize(DrPropertyWriterPtr writer); + virtual HRESULT ParseProperty(DrPropertyReaderPtr reader, UINT16 enumID, UINT32 dataLen); + + void CopyFrom(DrChannelDescriptionPtr src, bool includeLengths); + +private: + HRESULT m_state; + DrString m_URI; + DrMetaDataRef m_metaData; + UINT64 m_totalLength; + UINT64 m_processedLength; + bool m_isInputChannel; +}; + +DRCLASS(DrInputChannelDescription) : public DrChannelDescription +{ +public: + DrInputChannelDescription(); +}; +DRREF(DrInputChannelDescription); + +DRCLASS(DrOutputChannelDescription) : public DrChannelDescription +{ +public: + DrOutputChannelDescription(); +}; +DRREF(DrOutputChannelDescription); + +typedef DrArray DrInputChannelArray; 
+DRAREF(DrInputChannelArray,DrInputChannelDescriptionRef); + +typedef DrArray DrOutputChannelArray; +DRAREF(DrOutputChannelArray,DrOutputChannelDescriptionRef); + +DRDECLARECLASS(DrVertexProcessStatus); +DRREF(DrVertexProcessStatus); + +DRBASECLASS(DrVertexProcessStatus), public DrPropertyParser +{ +public: + DrVertexProcessStatus(); + + int GetVertexId(); + void SetVertexId(int vertexId); + + int GetVertexInstanceVersion(); + void SetVertexInstanceVersion(int instanceVersion); + + DrMetaDataPtr GetVertexMetaData(); + void SetVertexMetaData(DrMetaDataPtr metaData); + + void SetInputChannelCount(int channelCount); + DrInputChannelArrayRef GetInputChannels(); + + int GetMaxOpenInputChannelCount(); + void SetMaxOpenInputChannelCount(int channelCount); + + void SetOutputChannelCount(int channelCount); + DrOutputChannelArrayRef GetOutputChannels(); + + int GetMaxOpenOutputChannelCount(); + void SetMaxOpenOutputChannelCount(int channelCount); + + bool GetCanShareWorkQueue(); + void SetCanShareWorkQueue(bool canShareWorkQueue); + + void Serialize(DrPropertyWriterPtr writer); + virtual HRESULT ParseProperty(DrPropertyReaderPtr reader, UINT16 enumID, UINT32 dataLen); + + void CopyFrom(DrVertexProcessStatusPtr src, bool includeLengths); + +private: + int m_id; + int m_version; + DrMetaDataRef m_metaData; + int m_maxInputChannels; + DrInputChannelArrayRef m_inputChannel; + int m_maxOutputChannels; + DrOutputChannelArrayRef m_outputChannel; + bool m_canShareWorkQueue; + + int m_nextInputChannelToRead; + int m_nextOutputChannelToRead; +}; + + +DRBASECLASS(DrVertexStatus), public DrPropertyParser +{ +public: + DrVertexStatus(); + + HRESULT GetVertexState(); + void SetVertexState(HRESULT state); + + DrVertexProcessStatusPtr GetProcessStatus(); + void SetProcessStatus(DrVertexProcessStatusPtr status); + + void Serialize(DrPropertyWriterPtr writer); + virtual HRESULT ParseProperty(DrPropertyReaderPtr reader, UINT16 enumID, UINT32 dataLen); + + static DrString GetPropertyLabel(int 
vertexId, int vertexVersion); + +private: + HRESULT m_state; + DrVertexProcessStatusRef m_processStatus; +}; +DRREF(DrVertexStatus); + +DRENUM(DrVertexCommand) +{ + DrVC_Start = 0, + DrVC_ReOpenChannels, + DrVC_Terminate, + DrVC_Max +}; + +DRBASECLASS(DrVertexCommandBlock) , public DrPropertyParser +{ +public: + DrVertexCommandBlock(); + + DrVertexCommand GetVertexCommand(); + void SetVertexCommand(DrVertexCommand command); + + DrVertexProcessStatusPtr GetProcessStatus(); + void SetProcessStatus(DrVertexProcessStatusPtr status); + + void SetArgumentCount(int nArguments); + DrStringArrayRef GetArgumentVector(); + + DrByteArrayPtr GetRawSerializedBlock(); + void SetRawSerializedBlock(DrByteArrayPtr block); + + void SetDebugBreak(bool setBreakpointOnCommandArrival); + bool GetDebugBreak(); + + void Serialize(DrPropertyWriterPtr writer); + virtual HRESULT ParseProperty(DrPropertyReaderPtr reader, UINT16 enumID, UINT32 dataLen); + + static DrString GetPropertyLabel(int vertexId, int vertexVersion); + +private: + DrVertexCommand m_command; + DrVertexProcessStatusRef m_processStatus; + DrStringArrayRef m_argument; + DrByteArrayRef m_serializedBlock; + bool m_setBreakpointOnCommandArrival; + + int m_nextArgumentToRead; +}; +DRREF(DrVertexCommandBlock); \ No newline at end of file diff --git a/GraphManager/vertex/DrVertexHeaders.h b/GraphManager/vertex/DrVertexHeaders.h new file mode 100644 index 0000000..2368a20 --- /dev/null +++ b/GraphManager/vertex/DrVertexHeaders.h @@ -0,0 +1,35 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include + +#include + +#include +#include +#include +#include +#include + +#include + +#include diff --git a/GraphManager/vertex/DrVertexRecord.cpp b/GraphManager/vertex/DrVertexRecord.cpp new file mode 100644 index 0000000..df182ab --- /dev/null +++ b/GraphManager/vertex/DrVertexRecord.cpp @@ -0,0 +1,873 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#include + +DrVertexTemplate::DrVertexTemplate() +{ + m_statusBlockTime = 0; + m_listenerList = DrNew DrVertexListenerIRefList(); +} + +void DrVertexTemplate::SetStatusBlockTime(DrTimeInterval statusBlockTime) +{ + m_statusBlockTime = statusBlockTime; +} + +DrTimeInterval DrVertexTemplate::GetStatusBlockTime() +{ + return m_statusBlockTime; +} + +DrVertexListenerIRefListPtr DrVertexTemplate::GetListenerList() +{ + return m_listenerList; +} + + +DrInputChannelExecutionStatistics::DrInputChannelExecutionStatistics() +{ + m_dataRead = 0; + m_tempDataRead = 0; + m_tempDataReadCrossMachine = 0; + m_tempDataReadCrossPod = 0; +} + +DrOutputChannelExecutionStatistics::DrOutputChannelExecutionStatistics() +{ + Clear(); +} + +void DrOutputChannelExecutionStatistics::Clear() +{ + m_dataWritten = 0; + m_dataIntraPod = 0; + m_dataCrossPod = 0; +} + +DrVertexExecutionStatistics::DrVertexExecutionStatistics() +{ + m_totalLocalInputData = 0UL; + + m_creationTime = DrDateTime_Never; + m_startTime = DrDateTime_Never; + m_runningTime = DrDateTime_Never; + m_completionTime = DrDateTime_Never; + + m_exitCode = STILL_ACTIVE; + m_exitStatus = S_OK; +} + +void DrVertexExecutionStatistics::SetNumberOfChannels(int numberOfInputs, int numberOfOutputs) +{ + m_inputData = DrNew DrInputChannelStatsArray(numberOfInputs); + m_outputData = DrNew DrOutputChannelStatsArray(numberOfOutputs); + + int i; + for (i=0; iAllocated(); +} + +void DrVertexVersionGenerator::ResetVersion(int version) +{ + DrAssert(m_version == 0); + m_version = version; +} + +int DrVertexVersionGenerator::GetVersion() +{ + return m_version; +} + +void DrVertexVersionGenerator::SetGenerator(int inputIndex, DrVertexOutputGeneratorPtr generator) +{ + if (generator == DrNull) + { + DrAssert(m_generator[inputIndex] != DrNull); + m_generator[inputIndex] = DrNull; + ++m_unfilledCount; + } + else + { + DrAssert(m_generator[inputIndex] == DrNull); + m_generator[inputIndex] = generator; + DrAssert(m_unfilledCount > 0); + 
--m_unfilledCount; + } +} + +DrVertexOutputGeneratorPtr DrVertexVersionGenerator::GetGenerator(int inputIndex) +{ + return m_generator[inputIndex]; +} + +bool DrVertexVersionGenerator::Ready() +{ + return (m_unfilledCount == 0); +} + + +void DrVertexVersionGenerator::Compact(DrEdgeHolderPtr edgesBeingCompacted, int numberOfEdgesAfterCompaction) +{ + DrAssert(m_generator->Allocated() == edgesBeingCompacted->GetNumberOfEdges()); + + DrGeneratorArrayRef newArray = DrNew DrGeneratorArray(numberOfEdgesAfterCompaction); + + int nextCompactedLocation = 0; + int i; + for (i=0; iAllocated(); ++i) + { + DrEdge e = edgesBeingCompacted->GetEdge(i); + if (e.m_type == DCT_Tombstone) + { + if (m_generator[i] == DrNull) + { + /* we were waiting for this edge before being able to run, + but we can stop waiting for it since it has been pruned */ + DrAssert(m_unfilledCount > 0); + --m_unfilledCount; + } + } + else + { + newArray[nextCompactedLocation] = m_generator[i]; + ++nextCompactedLocation; + } + } + DrAssert(nextCompactedLocation == numberOfEdgesAfterCompaction); + m_generator = newArray; + +} + + +void DrVertexVersionGenerator::Grow(int numberOfEdgesToGrow) +{ + int numberExisiting = m_generator->Allocated(); + DrGeneratorArrayRef newArray = DrNew DrGeneratorArray(numberExisiting+ numberOfEdgesToGrow); + + int i; + for (i=0; iGetMessagePump(), parent) +{ + m_parent = parent; + m_inputs = inputs; + m_vertexTemplate = vertexTemplate; + m_cohort = cohort; + m_state = DVS_NotStarted; + m_lastSeenVersion = 0; + + m_creationTime = xcompute->GetCurrentTimeStamp(); + m_startTime = DrDateTime_Never; + m_runningTime = DrDateTime_Never; + m_completionTime = DrDateTime_Never; + + DrVertexListenerIRefListRef listeners = vertexTemplate->GetListenerList(); + int i; + for (i=0; iSize(); ++i) + { + AddListener(listeners[i]); + } +} + +void DrVertexRecord::Discard() +{ + m_parent = DrNull; + m_generator = DrNull; + m_inputs = DrNull; + m_cohort = DrNull; + m_process = DrNull; +} + +void 
DrVertexRecord::MakeGenerator() +{ + DrAssert(m_generator == DrNull); + m_generator = DrNew DrActiveVertexOutputGenerator(); + + if (m_process.IsEmpty() == false) + { + DrLockBoxKey process(m_process); + + m_generator->SetProcess(process->GetInfo()->m_state->m_process, + m_parent->GetId(), m_inputs->GetVersion()); + } + else + { + m_generator->SetProcess(DrNull, m_parent->GetId(), m_inputs->GetVersion()); + } +} + +DrActiveVertexOutputGeneratorPtr DrVertexRecord::GetGenerator() +{ + return m_generator; +} + +int DrVertexRecord::GetVersion() +{ + return m_inputs->GetVersion(); +} + +DrVertexStatusRef DrVertexRecord::TryToParseProperty(DrPropertyStatusPtr prop) +{ + if (prop->m_statusBlock == DrNull) + { + return DrNull; + } + + if (prop->m_statusVersion < m_lastSeenVersion) + { + DrLogW("Vertex record got out of order property %I64u < %I64u: ignoring", + prop->m_statusVersion, m_lastSeenVersion); + return DrNull; + } + + /* update this even if the parsing fails below so we don't keep requesting the same + unparseable property in a tight loop */ + m_lastSeenVersion = prop->m_statusVersion; + + DrVertexStatusRef status = DrNew DrVertexStatus(); + DrPropertyReaderRef parser = DrNew DrPropertyReader(prop->m_statusBlock); + HRESULT err = parser->ReadAggregate(DrTag_VertexStatus, status); + if (SUCCEEDED(err)) + { + DrLogI("Vertex %d.%d parsed property new version is %I64u", + m_parent->GetId(), GetVersion(), m_lastSeenVersion); + } + else + { + DrLogW("Vertex record got unparseable property version %I64u error %s: ignoring", + prop->m_statusVersion, DRERRORSTRING(err)); + return DrNull; + } + + return status; +} + +bool IsLocalPath(DrString machineName, DrString channelURI) +{ + if ((machineName.GetCharsLength() == 0) || + (channelURI.GetCharsLength() == 0)) + { + return false; + } + + if (channelURI.Compare("\\\\", 2) != 0) + { + return true; + } + else { + + int firstIndexOfSlash = channelURI.IndexOfChar('\\', 2); + if (firstIndexOfSlash < 0) + { + return false; + } + 
+ if(firstIndexOfSlash - 2 == machineName.GetCharsLength()) + { +#ifdef _MANAGED + if (System::String::Compare(channelURI.GetString(), 2, machineName.GetString(), 0, machineName.GetCharsLength(), + true) == 0) + { + // + // If length of host name is same and strings match, then file is local + // + return true; + } +#else + const char* uriSuffix = channelURI.GetChars() + 2; + if (machineName.Compare(uriSuffix, machineName.GetCharsLength(), false) == 0) + { + return true; + } +#endif + } + } + return false; +} + +DrVertexExecutionStatisticsRef DrVertexRecord::MakeExecutionStatistics(DrVertexStatusPtr status, + UINT32 exitCode, HRESULT exitStatus) +{ + DrVertexExecutionStatisticsRef stats = DrNew DrVertexExecutionStatistics(); + + stats->m_creationTime = m_creationTime; + stats->m_startTime = m_startTime; + stats->m_runningTime = m_runningTime; + stats->m_completionTime = m_completionTime; + + stats->m_exitCode = exitCode; + stats->m_exitStatus = exitStatus; + + if (status != DrNull) + { + DrString machineName = "(no computer)"; + + if (m_process.IsEmpty() == false) + { + DrLockBoxKey process(m_process); + DrProcessHandlePtr handle = process->GetInfo()->m_state->m_process; + if (handle != DrNull) + { + if (handle->GetAssignedNode() != DrNull) + { + machineName = handle->GetAssignedNode()->GetName(); + } + } + } + + DrVertexProcessStatusPtr ps = status->GetProcessStatus(); + + stats->SetNumberOfChannels(ps->GetInputChannels()->Allocated(), ps->GetOutputChannels()->Allocated()); + + int i; + if (ps->GetInputChannels()->Allocated() == m_inputs->GetNumberOfInputs()) + { + stats->m_totalInputData = DrNew DrInputChannelExecutionStatistics(); + for (i=0; iGetInputChannels()->Allocated(); ++i) + { + DrChannelDescriptionPtr c = ps->GetInputChannels()[i]; + + UINT64 dataRead = c->GetChannelProcessedLength(); + UINT64 tempData, inPod, crossPod; + + DrResourcePtr tempSource = m_inputs->GetGenerator(i)->GetResource(); + if (tempSource == DrNull) + { + tempData = 0; + inPod = 0; + 
crossPod = 0; + if (IsLocalPath(machineName, c->GetChannelURI())) + { + stats->m_totalLocalInputData += dataRead; + } + } + else if (tempSource == m_generator->GetResource()) + { + tempData = dataRead; + inPod = 0; + crossPod = 0; + } + else if (tempSource->GetParent() == m_generator->GetResource()->GetParent()) + { + tempData = 0; + inPod = dataRead; + crossPod = 0; + } + else + { + tempData = 0; + inPod = 0; + crossPod = dataRead; + } + + stats->m_totalInputData->m_dataRead += dataRead; + stats->m_totalInputData->m_tempDataRead += tempData; + stats->m_totalInputData->m_tempDataReadCrossMachine += inPod; + stats->m_totalInputData->m_tempDataReadCrossPod += crossPod; + + DrInputChannelExecutionStatisticsPtr cs = stats->m_inputData[i]; + cs->m_remoteMachine = tempSource; + cs->m_dataRead = dataRead; + cs->m_tempDataRead = tempData; + cs->m_tempDataReadCrossMachine = inPod; + cs->m_tempDataReadCrossPod = crossPod; + } + } + + stats->m_totalOutputData = DrNew DrOutputChannelExecutionStatistics(); + for (i=0; iGetOutputChannels()->Allocated(); ++i) + { + DrChannelDescriptionPtr c = ps->GetOutputChannels()[i]; + + UINT64 dataWritten = c->GetChannelProcessedLength(); + + stats->m_totalOutputData->m_dataWritten += dataWritten; + + DrOutputChannelExecutionStatisticsPtr cs = stats->m_outputData[i]; + cs->m_dataWritten = dataWritten; + } + } + + return stats; +} + +void DrVertexRecord::TriggerFailure(DrErrorPtr error) +{ + /* fake a failed state message to ourselves */ + ReceiveMessage(DrNew DrPropertyStatus(DPBS_Failed, STILL_ACTIVE, error)); +} + +DrVertexVersionGeneratorPtr DrVertexRecord::NotifyProcessHasStarted(DrLockBox process) +{ + DrAssert(m_state == DVS_NotStarted); + + /* this is the first sign that we have started running: save the process */ + DrAssert(process.IsEmpty() == false); + m_process = process; + + /* we now need a generator record for upstream vertices connected using active + edges to use */ + MakeGenerator(); + + return m_inputs; +} + +void 
DrVertexRecord::SetActiveInput(int inputPort, DrVertexOutputGeneratorPtr generator) +{ + DrAssert(m_inputs->GetGenerator(inputPort) == DrNull); + m_inputs->SetGenerator(inputPort, generator); + + if (m_inputs->Ready()) + { + StartRunning(); + } +} + +void DrVertexRecord::StartRunning() +{ + DrAssert(m_state == DVS_NotStarted); + + /* send the command to start the vertex running */ + SendStartCommand(); + + m_state = DVS_Starting; + + DrLogI("Vertex %d.%d transition to starting", m_parent->GetId(), GetVersion()); + + /* we have changed state so tell our listeners */ + DrVertexInfoRef info = DrNew DrVertexInfo(); + info->m_name = m_parent->GetName(); + info->m_state = m_state; + info->m_info = DrNull; + info->m_statistics = DrNull; + info->m_process = m_process; + + DeliverNotification(info); + + /* now that there is a process waiting, ask it for a property status */ + RequestStatus(); +} + +/* there are two places a message could come from: the cohort or the process. The cohort sends a Running state + message when the process starts up, after which we are responsible for sending the process status requests + which are the source of the process' responses. 
The cohort may also send a Failed state message, which we + ignore if we have already seen successful completion from the process status, and believe otherwise */ +void DrVertexRecord::ReceiveMessage(DrPropertyStatusRef prop) +{ + DrLogI("Vertex %d.%d receiving message while in state %d", m_parent->GetId(), GetVersion(), (int) m_state); + + if (m_state > DVS_RunningStatus) + { + /* we don't really care: we've already determined that we are completed or failed */ + return; + } + + DrErrorRef error; + DrString reason; + + bool requestStatus = false; + bool sendNotification = false; + DrVertexStatusRef status = TryToParseProperty(prop); + if (status != DrNull) + { + /* tell the listeners every time there's a new status block */ + sendNotification = true; + } + + HRESULT exitStatus = DrError_VertexRunning; + + if (status == DrNull) + { + DrLogI("Vertex %d.%d no status block in message", m_parent->GetId(), GetVersion()); + + /* there was no status property block */ + switch (prop->m_processState) + { + case DPBS_NotStarted: + DrLogA("Vertex record unexpectedly received message in NotStarted state"); + break; + + case DPBS_Running: + DrAssert(m_state > DVS_NotStarted); + + if (m_state > DVS_Starting) + { + DrLogW("Vertex %d.%d sent empty status after it started running", + m_parent->GetId(), GetVersion()); + + reason = "Process sent status with no parseable block"; + error = DrNew DrError(DrError_Unexpected, "DrVertexRecord", reason); + + m_state = DVS_Failed; + exitStatus = error->m_code; + /* we are changing state, so tell the listeners */ + sendNotification = true; + + DrLogI("Vertex %d.%d transition to failed", m_parent->GetId(), GetVersion()); + } + /* else we're in the starting state, so it's ok to get a message without status + since the remote vertex may not have written a property yet */ + break; + + case DPBS_Completed: + case DPBS_Failed: + DrLogI("Vertex %d.%d received message with processState %d", m_parent->GetId(), GetVersion(), prop->m_processState); + 
error = prop->m_status; + if (error == DrNull) + { + reason = "Process ended with no error status"; + error = DrNew DrError(DrError_Unexpected, "DrVertexRecord", reason); + } + + if (error->m_code == DrError_CohortShutdown && m_process.IsEmpty() == false) + { + /* this is the cohort or graph killing us, so tell the process to terminate us */ + DrLogI("Vertex %d.%d this is the cohort or graph killing us, tell process to terminate us", m_parent->GetId(), GetVersion()); + SendTerminateCommand(m_parent->GetId(), m_inputs->GetVersion(), m_process); + } + + m_state = DVS_Failed; + exitStatus = error->m_code; + /* we are changing state, so tell the listeners */ + sendNotification = true; + + DrLogI("Vertex %d.%d transition to failed", m_parent->GetId(), GetVersion()); + break; + } + } + else + { + DrLogI("Vertex %d.%d has status block", m_parent->GetId(), GetVersion()); + + if (m_state == DVS_Starting) + { + DrLogI("Vertex %d.%d transition to running", m_parent->GetId(), GetVersion()); + + m_state = DVS_Running; + m_runningTime = m_parent->GetStageManager()->GetGraph()->GetXCompute()->GetCurrentTimeStamp(); + + DrVertexExecutionStatisticsRef startStats = + MakeExecutionStatistics(status, prop->m_exitCode, exitStatus); + /* we started running at some point: tell our parent. 
It may be that we have already + completed or failed, in which case the state will be updated again below and the parent + will get another message */ + m_parent->ReactToStartedVertex(this, startStats); + + DrVertexInfoRef info = DrNew DrVertexInfo(); + info->m_name = m_parent->GetName(); + info->m_state = m_state; + info->m_info = status; + info->m_statistics = startStats; + info->m_process = m_process; + DeliverNotification(info); + + /* we only send the DVS_Running state to listeners once: after that we repeatedly send + DVS_RunningStatus messages until the vertex finishes */ + m_state = DVS_RunningStatus; + } + + DrAssert(m_state == DVS_RunningStatus); + DrAssert(m_process.IsEmpty() == false); + + exitStatus = status->GetVertexState(); + + DrLogI("Vertex %d.%d processState %d vertexState %s", m_parent->GetId(), GetVersion(), + prop->m_processState, DRERRORSTRING(exitStatus)); + + switch (status->GetVertexState()) + { + case DrError_VertexRunning: + switch (prop->m_processState) + { + case DPBS_Running: + /* no change of state, but request the next status message and send the status to + our listeners */ + DrLogI("Vertex %d.%d process running, request next status message and send status to listeners", m_parent->GetId(), GetVersion()); + requestStatus = true; + sendNotification = true; + break; + + case DPBS_Completed: + case DPBS_Failed: + /* the process stopped without the property being updated, so + fail the vertex */ + DrLogI("Vertex %d.%d process stopped without property being updated, fail the vertex", m_parent->GetId(), GetVersion()); + error = prop->m_status; + if (error == DrNull) + { + error = DrNew DrError(DrError_Unexpected, "DrVertexRecord", + DrString("Process ended with no error status")); + } + + if (error->m_code == DrError_CohortShutdown) + { + /* this is the cohort or graph killing us, so tell the process to terminate us */ + DrLogI("Vertex %d.%d killed by cohort or graph, tell process to terminate us", m_parent->GetId(), GetVersion()); + 
SendTerminateCommand(m_parent->GetId(), m_inputs->GetVersion(), m_process); + } + m_state = DVS_Failed; + exitStatus = error->m_code; + /* we are changing state, so tell the listeners */ + sendNotification = true; + + DrLogI("Vertex %d.%d transition to failed unexpectedly", m_parent->GetId(), GetVersion()); + break; + + default: + DrLogA("Unknown state"); + } + break; + + case DrError_VertexCompleted: + DrLogI("Vertex %d.%d transition to completed, old m_state %d", m_parent->GetId(), GetVersion(), m_state); + m_state = DVS_Completed; + + DrAssert(m_generator != DrNull); + m_generator->StoreOutputLengths(status->GetProcessStatus(), + m_parent->GetStageManager()->GetGraph()-> + GetXCompute()->GetCurrentTimeStamp() - + m_runningTime); + + /* we are changing state, so tell the listeners */ + sendNotification = true; + + DrLogI("Vertex %d.%d transition to completed, new m_state %d", m_parent->GetId(), GetVersion(), m_state); + break; + + default: + /* any error state */ + DrLogI("Vertex %d.%d transition to completed, old m_state %d", m_parent->GetId(), GetVersion(), m_state); + m_state = DVS_Failed; + + reason.SetF("Vertex ended cleanly reporting error %s", DRERRORSTRING(status->GetVertexState())); + error = DrNew DrError(status->GetVertexState(), "DrVertexRecord", DrString()); + + if (status->GetProcessStatus()->GetVertexMetaData() != DrNull) + { + DrMetaDataPtr md = status->GetProcessStatus()->GetVertexMetaData(); + DrMTagHRESULTPtr code = dynamic_cast(md->LookUp(DrProp_ErrorCode)); + DrMTagStringPtr stringTag = dynamic_cast(md->LookUp(DrProp_ErrorString)); + + if (code != DrNull) + { + DrString eString; + if (stringTag == DrNull || stringTag->GetValue().GetString() == DrNull) + { + eString = "No reason supplied"; + } + else + { + eString = stringTag->GetValue(); + } + + DrErrorRef subReason = DrNew DrError(code->GetValue(), "RemoteVertex", eString); + error->AddProvenance(subReason); + + DrString txt = DrError::ToShortText(subReason); + DrLogI("Vertex reported error 
'%s' as failure reason", txt.GetChars()); + } + } + + /* we are changing state, so tell the listeners */ + sendNotification = true; + + DrLogI("Vertex %d.%d transition to failed neatly, m_state = %d", m_parent->GetId(), GetVersion(), m_state); + break; + } + } + + if (m_state > DVS_Running) + { + m_completionTime = m_parent->GetStageManager()->GetGraph()->GetXCompute()->GetCurrentTimeStamp(); + } + + DrVertexExecutionStatisticsRef stats = MakeExecutionStatistics(status, prop->m_exitCode, exitStatus); + + if (sendNotification) + { + if (status == DrNull) + { + /* make sure the receivers of our notification have some idea what's going on */ + DrLogI("Vertex %d.%d status is null, creating one for receivers of notification", m_parent->GetId(), GetVersion()); + + status = DrNew DrVertexStatus(); + status->GetProcessStatus()->SetVertexId(m_parent->GetId()); + status->GetProcessStatus()->SetVertexInstanceVersion(GetVersion()); + status->GetProcessStatus()->SetInputChannelCount(0); + status->GetProcessStatus()->SetOutputChannelCount(0); + } + + DrVertexInfoRef info = DrNew DrVertexInfo(); + info->m_name = m_parent->GetName(); + info->m_state = m_state; + info->m_info = status; + info->m_statistics = stats; + info->m_process = m_process; + DrLogI("Vertex %d.%d delivering notification", m_parent->GetId(), GetVersion()); + DeliverNotification(info); + } + + if (requestStatus) + { + DrLogI("Vertex %d.%d requesting status", m_parent->GetId(), GetVersion()); + RequestStatus(); + } + + switch (m_state) + { + case DVS_NotStarted: + case DVS_Starting: + /* do nothing */ + break; + + case DVS_Running: + DrLogA("Vertex %d.%d logic wrongly ended up in DVS_Running state", m_parent->GetId(), GetVersion()); + break; + + case DVS_RunningStatus: + /* we have updated status: tell our parent */ + m_parent->ReactToRunningVertexUpdate(this, status->GetVertexState(), status->GetProcessStatus()); + break; + + case DVS_Completed: + case DVS_Failed: + DrLogI("Vertex %d.%d completed, state is now 
%d", m_parent->GetId(), GetVersion(), m_state); + m_cohort->NotifyVertexCompletion(); + DrLogI("Vertex %d.%d notified cohort of completion, state is now %d", m_parent->GetId(), GetVersion(), m_state); + + m_cohort = DrNull; + + if (m_state == DVS_Completed) + { + DrLogI("Vertex %d.%d completed, calling ReactToCompletedVertex", m_parent->GetId(), GetVersion()); + m_parent->ReactToCompletedVertex(this, stats); + } + else + { + DrLogI("Vertex %d.%d failed, calling ReactToFailedVertex", m_parent->GetId(), GetVersion()); + DrAssert(error != DrNull); + DrVertexProcessStatusPtr ps = DrNull; + if (status != DrNull) + { + ps = status->GetProcessStatus(); + } + + if (m_generator == DrNull) + { + /* make a fake generator to send to the failure machinery */ + MakeGenerator(); + } + + m_parent->ReactToFailedVertex(m_generator, m_inputs, stats, ps, error); + } + } +} + +void DrVertexRecord::RequestStatus() +{ + DrString label = DrVertexStatus::GetPropertyLabel(m_parent->GetId(), m_inputs->GetVersion()); + + DrAssert(m_process.IsEmpty() == false); + { + DrLockBoxKey process(m_process); + process->RequestProperty(m_lastSeenVersion, label, + m_vertexTemplate->GetStatusBlockTime(), this); + } +} + +void DrVertexRecord::SendStartCommand() +{ + DrString label = DrVertexCommandBlock::GetPropertyLabel(m_parent->GetId(), m_inputs->GetVersion()); + DrString description; + description.SetF("Start command for vertex %d.%d", m_parent->GetId(), m_inputs->GetVersion()); + + DrVertexCommandBlockRef startCommand = + m_parent->MakeVertexStartCommand(m_inputs, m_generator->GetResource()); + + DrPropertyWriterRef writer = DrNew DrPropertyWriter(); + startCommand->Serialize(writer); + DrByteArrayRef block = writer->GetBuffer(); + + DrAssert(m_process.IsEmpty() == false); + { + DrLockBoxKey process(m_process); + process->SendCommand(1, label, description, block); + } + + m_startTime = m_parent->GetStageManager()->GetGraph()->GetXCompute()->GetCurrentTimeStamp(); +} + +void 
DrVertexRecord::SendTerminateCommand(int id, int version, DrLockBox process) +{ + DrString label = DrVertexCommandBlock::GetPropertyLabel(id, version); + DrString description; + description.SetF("Terminate command for vertex %d.%d", id, version); + + DrVertexCommandBlockRef cmd = DrNew DrVertexCommandBlock(); + + cmd->SetVertexCommand(DrVC_Terminate); + cmd->GetProcessStatus()->SetVertexId(id); + cmd->GetProcessStatus()->SetVertexInstanceVersion(version); + + DrPropertyWriterRef writer = DrNew DrPropertyWriter(); + cmd->Serialize(writer); + DrByteArrayRef block = writer->GetBuffer(); + + DrAssert(process.IsEmpty() == false); + { + DrLockBoxKey p(process); + p->SendCommand(2, label, description, block); + } +} diff --git a/GraphManager/vertex/DrVertexRecord.h b/GraphManager/vertex/DrVertexRecord.h new file mode 100644 index 0000000..4cce563 --- /dev/null +++ b/GraphManager/vertex/DrVertexRecord.h @@ -0,0 +1,248 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + + +DRENUM(DrVertexState) +{ + DVS_NotStarted, + DVS_Starting, + DVS_Running, + DVS_RunningStatus, + DVS_Completed, + DVS_Failed +}; + +DRBASECLASS(DrInputChannelExecutionStatistics) +{ +public: + /* the constructor zeros all the fields */ + DrInputChannelExecutionStatistics(); + + /* The data read on an input channel */ + UINT64 m_dataRead; + + /* The remote machine if any of the temp data values below is + non-zero. Otherwise, DrNull. */ + DrResourceRef m_remoteMachine; + + /* The temp data read on an input channel. This is zero if the + channel is reading from a stable storage vertex, e.g. an input + stream. */ + UINT64 m_tempDataRead; + + /* The temp data read across machines in the same pod. This is + equal to m_tempDataRead if the temp data was read from + another machine in the same pod, zero otherwise. */ + UINT64 m_tempDataReadCrossMachine; + + /* The temp data read across machines in different pods. This is + equal to m_tempDataRead if the temp data was read from + another machine in a different pod, zero otherwise. 
*/ + UINT64 m_tempDataReadCrossPod; +}; +DRREF(DrInputChannelExecutionStatistics); + +typedef DrArray DrInputChannelStatsArray; +DRAREF(DrInputChannelStatsArray,DrInputChannelExecutionStatisticsRef); + +DRBASECLASS(DrOutputChannelExecutionStatistics) +{ +public: + /* the constructor zeros all the fields */ + DrOutputChannelExecutionStatistics(); + + /* zero all fields */ + void Clear(); + + /* The output data written */ + UINT64 m_dataWritten; + + /* data written within pod machines */ + UINT64 m_dataIntraPod; + + /* data written across pods */ + UINT64 m_dataCrossPod; +}; +DRREF(DrOutputChannelExecutionStatistics); + +typedef DrArray DrOutputChannelStatsArray; +DRAREF(DrOutputChannelStatsArray,DrOutputChannelExecutionStatisticsRef); + +DRBASECLASS(DrVertexExecutionStatistics) +{ +public: + DrVertexExecutionStatistics(); + + void SetNumberOfChannels(int numberOfInputs, int numberOfOutputs); + + /* The total data read on all input channels */ + DrInputChannelExecutionStatisticsRef m_totalInputData; + /* The total local input data read on all input channels */ + UINT64 m_totalLocalInputData; + + /* The total data written on all output channels */ + DrOutputChannelExecutionStatisticsRef m_totalOutputData; + + + + /* the data read on each input channel broken down by channel */ + DrInputChannelStatsArrayRef m_inputData; + + /* the data written on each output channel broken down by channel */ + DrOutputChannelStatsArrayRef m_outputData; + + /* the time the record was created */ + DrDateTime m_creationTime; + /* the time the vertex start command was sent */ + DrDateTime m_startTime; + /* the time we first learned the vertex was running */ + DrDateTime m_runningTime; + /* the time we first learned the vertex had finished */ + DrDateTime m_completionTime; + + UINT32 m_exitCode; + HRESULT m_exitStatus; + DrMetaDataRef m_metaData; +}; +DRREF(DrVertexExecutionStatistics); + +DRBASECLASS(DrVertexInfo) +{ +public: + DrString m_name; + DrVertexState m_state; + DrVertexStatusRef 
m_info; + DrVertexExecutionStatisticsRef m_statistics; + DrLockBox m_process; +}; +DRREF(DrVertexInfo); + +typedef DrListener DrVertexListener; +DRIREF(DrVertexListener); + +DRMAKEARRAYLIST(DrVertexListenerIRef); + +typedef DrMessage DrVertexMessage; +DRREF(DrVertexMessage); + +typedef DrNotifier DrVertexNotifier; + +DRBASECLASS(DrVertexTemplate) +{ +public: + DrVertexTemplate(); + + void SetStatusBlockTime(DrTimeInterval delay); + DrTimeInterval GetStatusBlockTime(); + + DrVertexListenerIRefListPtr GetListenerList(); + +private: + DrTimeInterval m_statusBlockTime; + DrVertexListenerIRefListRef m_listenerList; +}; +DRREF(DrVertexTemplate); + + +DRBASECLASS(DrVertexVersionGenerator) +{ +public: + DrVertexVersionGenerator(int version, int numberOfInputs); + + int GetNumberOfInputs(); + void ResetVersion(int version); + int GetVersion(); + + void SetGenerator(int inputIndex, DrVertexOutputGeneratorPtr generator); + DrVertexOutputGeneratorPtr GetGenerator(int inputIndex); + + bool Ready(); + void Compact(DrEdgeHolderPtr edgesBeingCompacted, int numberOfEdgesAfterCompaction); + void Grow(int numberOfEdgesToGrow); + +private: + int m_version; + int m_unfilledCount; + DrGeneratorArrayRef m_generator; +}; +DRREF(DrVertexVersionGenerator); + + +DRDECLARECLASS(DrCohortProcess); +DRREF(DrCohortProcess); + +DRDECLARECLASS(DrActiveVertex); +DRREF(DrActiveVertex); + +DRCLASS(DrVertexRecord) : public DrVertexNotifier, public DrPropertyListener +{ +public: + DrVertexRecord(DrXComputePtr xcompute, DrActiveVertexPtr parent, DrCohortProcessPtr cohort, + DrVertexVersionGeneratorPtr generator, DrVertexTemplatePtr vertexTemplate); + + void Discard(); + + int GetVersion(); + DrActiveVertexOutputGeneratorPtr GetGenerator(); + + DrVertexVersionGeneratorPtr NotifyProcessHasStarted(DrLockBox process); + void SetActiveInput(int inputPort, DrVertexOutputGeneratorPtr generator); + void StartRunning(); + void TriggerFailure(DrErrorPtr originalReason); + + virtual void 
ReceiveMessage(DrPropertyStatusRef prop); + + static void SendTerminateCommand(int id, int version, DrLockBox process); + +private: + DrVertexStatusRef TryToParseProperty(DrPropertyStatusPtr prop); + DrVertexExecutionStatisticsRef MakeExecutionStatistics(DrVertexStatusPtr status, + UINT32 exitCode, HRESULT exitStatus); + + void MakeGenerator(); + + void SendStartCommand(); + void RequestStatus(); + + DrActiveVertexRef m_parent; + DrActiveVertexOutputGeneratorRef m_generator; + DrVertexVersionGeneratorRef m_inputs; + DrVertexTemplateRef m_vertexTemplate; + DrCohortProcessRef m_cohort; + DrLockBox m_process; + + DrVertexState m_state; + UINT64 m_lastSeenVersion; + + /* the time the record was created */ + DrDateTime m_creationTime; + /* the time the vertex start command was sent */ + DrDateTime m_startTime; + /* the time we first learned the vertex was running */ + DrDateTime m_runningTime; + /* the time we first learned the vertex had finished */ + DrDateTime m_completionTime; +}; +DRREF(DrVertexRecord); + +typedef DrArrayList DrVertexRecordList; +DRAREF(DrVertexRecordList,DrVertexRecordRef); \ No newline at end of file diff --git a/Hdfs/HdfsBridgeManaged/HdfsBridgeManaged.cpp b/Hdfs/HdfsBridgeManaged/HdfsBridgeManaged.cpp new file mode 100644 index 0000000..42f7240 --- /dev/null +++ b/Hdfs/HdfsBridgeManaged/HdfsBridgeManaged.cpp @@ -0,0 +1,319 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// HdfsBridgeManaged.cpp : main project file. + +//#include "stdafx.h" + +#pragma warning( disable: 4793 ) + +#include "HdfsBridgeNative.h" + +#include "HdfsBridgeManaged.h" + +//#include +//#include + + +#define SUCCESS 0 +#define FAILURE -1 + +using namespace System; +using namespace Microsoft::Research::Dryad::Hdfs; + + +//--------------------------------------------------------------------------------------------------- +#if 0 +int main(array ^args) +{ + Console::WriteLine(L"Hello World"); + + Uri^ headUri = gcnew Uri("hpchdfs://svc-d1-17:9000/"); + String^ hdfsFileSetName = "/data"; + String^ hdfsFileSetName1 = "/data/inputPart1.txt"; + + bool ret = HdfsInstance::Initialize(); + + HdfsInstance^ managedHdfsClient1 = gcnew HdfsInstance(headUri); + + HdfsFileInfo^ managedFileInfo = managedHdfsClient1->GetFileInfo(hdfsFileSetName, true); + + HdfsInstance^ managedHdfsClient2 = gcnew HdfsInstance(headUri); + bool exists = managedHdfsClient2->IsFileExists(hdfsFileSetName1); + + //String^ hdfsFileSetName3 = "/data/tpch/customer_1G_128MB.txt"; + //HdfsInstance^ managedHdfsClient3 = gcnew HdfsInstance(headNode, hdfsPort); + //array^ blocks = managedHdfsClient3->GetBlocks(hdfsFileSetName3); + + //String^ hdfsFileSetName4 = "/data/tpch/customer_1G_128MB.txt"; + //HdfsInstance^ managedHdfsClient4 = gcnew HdfsInstance(headNode, hdfsPort); + //String^ blockContent = managedHdfsClient4->ReadBlock(hdfsFileSetName4, 0, 0); + + //String^ hdfsFileSetName3 = "/data/tpch/customer_1G_128MB.txt"; + //HdfsInstance^ managedHdfsClient6 = gcnew HdfsInstance(headNode, hdfsPort); + //bool fileExists2 = managedHdfsClient6->IsFileExists(hdfsFileSetName3); + + //String^ hdfsFileSetName7 = "/data/tpch/customer_1G_128MB.txt"; + //HdfsInstance^ managedHdfsClient7 = gcnew HdfsInstance(headNode, hdfsPort); + //bool fileExists3 = 
managedHdfsClient7->IsFileExists(hdfsFileSetName7); + + //String^ hdfsFileSetName8 = "/data/tpch/customer_1G_128MB.txt"; + //HdfsInstance^ managedHdfsClient8 = gcnew HdfsInstance(headNode, hdfsPort); + //bool fileExists4 = managedHdfsClient8->IsFileExists(hdfsFileSetName8); + + return 0; +} +#endif + +namespace Microsoft { namespace Research { namespace Dryad { namespace Hdfs +{ + bool HdfsInstance::Initialize() + { + return HdfsBridgeNative::Initialize(); + } + + HdfsInstance::HdfsInstance(String^ hdfsString) + { + bool ret = HdfsInstance::Initialize(); + + if (!ret) + { + throw gcnew ApplicationException("Unable to initialize Hdfs bridge"); + } + + Uri^ hdfsUri = gcnew Uri(hdfsString); + ret = Open(hdfsUri->Host, hdfsUri->Port); + + if (!ret) + { + throw gcnew + ApplicationException(String::Format("Unable to connect to Hdfs at {0}:{1}", hdfsUri->Host, hdfsUri->Port)); + } + + m_serviceUri = hdfsUri->Scheme + "://" + hdfsUri->Host + ":" + hdfsUri->Port + "/"; + } + + HdfsInstance::~HdfsInstance() + { + Close(); + } + + bool HdfsInstance::Open(String^ headNode, long hdfsPort) + { + char* cHeadNode = (char *) Marshal::StringToHGlobalAnsi(headNode).ToPointer(); + + HdfsBridgeNative::Instance* instance; + bool ret = HdfsBridgeNative::OpenInstance(cHeadNode, hdfsPort, &instance); + + Marshal::FreeHGlobal(IntPtr(cHeadNode)); + + if (ret) + { + m_instance = IntPtr(instance); + } + else + { + m_instance = IntPtr::Zero; + } + + return ret; + } + + void HdfsInstance::Close() + { + if (m_instance != IntPtr::Zero) + { + HdfsBridgeNative::InstanceAccessor ia((HdfsBridgeNative::Instance *) m_instance.ToPointer()); + + ia.Dispose(); + + m_instance = IntPtr::Zero; + } + + m_serviceUri = nullptr; + } + + HdfsFileInfo^ HdfsInstance::GetFileInfo(String^ fileName, bool getBlockArray) + { + // Marshal the managed string to unmanaged memory. 
+ char* cFileName = (char*) Marshal::StringToHGlobalAnsi(fileName).ToPointer(); + + HdfsBridgeNative::InstanceAccessor ia((HdfsBridgeNative::Instance *) m_instance.ToPointer()); + + HdfsBridgeNative::FileStat* fileStat; + bool ret = ia.OpenFileStat(cFileName, getBlockArray, &fileStat); + + // free the unmanaged string. + Marshal::FreeHGlobal(IntPtr(cFileName)); + + if (!ret) + { + char* msg = ia.GetExceptionMessage(); + String^ errorMsg = Marshal::PtrToStringAnsi((IntPtr) msg); + HdfsBridgeNative::DisposeString(msg); + throw gcnew ApplicationException("Hdfs GetFileInfo: " + errorMsg); + } + + HdfsBridgeNative::FileStatAccessor fs(fileStat); + + HdfsFileInfo^ fileInfo = gcnew HdfsFileInfo(); + + fileInfo->Name = fileName; + fileInfo->IsDirectory = fs.IsDir(); + fileInfo->Size = fs.GetFileLength(); + fileInfo->LastModified = fs.GetFileLastModified(); + fileInfo->Replication = fs.GetFileReplication(); + fileInfo->BlockSize = fs.GetFileBlockSize(); + + long numberOfFiles = fs.GetNumberOfFiles(); + fileInfo->fileNameArray = gcnew array(numberOfFiles); + + char** cArray = fs.GetFileNameArray(); + for (long i=0; ifileNameArray[i] = Marshal::PtrToStringAnsi((IntPtr) cArray[i]); + } + + fs.DisposeFileNameArray(numberOfFiles, cArray); + + if (getBlockArray) + { + fileInfo->blockArray = gcnew array(fs.GetNumberOfBlocks()); + + for (long i=0; iblockArray->Length; ++i) + { + HdfsBridgeNative::HdfsBlockLocInfo* info = fs.GetBlockInfo(i); + + fileInfo->blockArray[i] = gcnew HdfsBlockInfo(); + + fileInfo->blockArray[i]->fileIndex = info->fileIndex; + fileInfo->blockArray[i]->Size = info->Size; + fileInfo->blockArray[i]->Offset = info->Offset; + fileInfo->blockArray[i]->Hosts = gcnew array(info->numberOfHosts); + + for (int j=0; jnumberOfHosts; ++j) + { + String^ h = Marshal::PtrToStringAnsi((IntPtr) info->Hosts[j]); + fileInfo->blockArray[i]->Hosts[j] = h; + } + + fs.DisposeBlockInfo(info); + } + + fileInfo->totalSize = fs.GetTotalFileLength(); + } + + fs.Dispose(); + + return 
fileInfo; + } + + bool HdfsInstance::IsFileExists(String^ fileName) + { + // Marshal the managed string to unmanaged memory. + char* cFileName = (char*) Marshal::StringToHGlobalAnsi(fileName).ToPointer(); + + HdfsBridgeNative::InstanceAccessor ia((HdfsBridgeNative::Instance *) m_instance.ToPointer()); + + bool exists = false; + bool result = ia.IsFileExists(cFileName, &exists); + + // free the unmanaged string. + Marshal::FreeHGlobal(IntPtr(cFileName)); + + if (!result) + { + char* msg = ia.GetExceptionMessage(); + String^ errorMsg = Marshal::PtrToStringAnsi((IntPtr) msg); + HdfsBridgeNative::DisposeString(msg); + throw gcnew ApplicationException("Hdfs IsFileExists: " + errorMsg); + } + + return exists; + } + + bool HdfsInstance::DeleteFile(String^ fileName, bool recursive) + { + // Marshal the managed string to unmanaged memory. + char* cFileName = (char*) Marshal::StringToHGlobalAnsi(fileName).ToPointer(); + + HdfsBridgeNative::InstanceAccessor ia((HdfsBridgeNative::Instance *) m_instance.ToPointer()); + + bool deleted = false; + bool result = ia.DeleteFileOrDir(cFileName, recursive, &deleted); + + // free the unmanaged string. + Marshal::FreeHGlobal(IntPtr(cFileName)); + + if (!result) + { + char* msg = ia.GetExceptionMessage(); + String^ errorMsg = Marshal::PtrToStringAnsi((IntPtr) msg); + HdfsBridgeNative::DisposeString(msg); + throw gcnew ApplicationException("Hdfs DeleteFile: " + errorMsg); + } + + return deleted; + } + + bool HdfsInstance::RenameFile(String^ dstFileName, String^ srcFileName) + { + // Marshal the managed strings to unmanaged memory. + char* cDstFileName = (char*) Marshal::StringToHGlobalAnsi(dstFileName).ToPointer(); + char* cSrcFileName = (char*) Marshal::StringToHGlobalAnsi(srcFileName).ToPointer(); + + HdfsBridgeNative::InstanceAccessor ia((HdfsBridgeNative::Instance *) m_instance.ToPointer()); + + bool renamed; + bool result = ia.RenameFileOrDir(cDstFileName, cSrcFileName, &renamed); + + // free the unmanaged strings. 
+ Marshal::FreeHGlobal(IntPtr(cSrcFileName)); + Marshal::FreeHGlobal(IntPtr(cDstFileName)); + + if (!result) + { + char* msg = ia.GetExceptionMessage(); + String^ errorMsg = Marshal::PtrToStringAnsi((IntPtr) msg); + HdfsBridgeNative::DisposeString(msg); + throw gcnew ApplicationException("Hdfs RenameFile: " + errorMsg); + } + + return renamed; + } + + String^ HdfsInstance::FromInternalUri(String^ inputString) + { + if (inputString->StartsWith(m_serviceUri)) + { + return inputString->Substring(m_serviceUri->Length); + } + else + { + throw gcnew ApplicationException(inputString + " doesn't start with " + m_serviceUri); + return nullptr; + } + } + + String^ HdfsInstance::ToInternalUri(String^ inputString) + { + return m_serviceUri + inputString; + } +}}}} diff --git a/Hdfs/HdfsBridgeManaged/HdfsBridgeManaged.h b/Hdfs/HdfsBridgeManaged/HdfsBridgeManaged.h new file mode 100644 index 0000000..6f1799c --- /dev/null +++ b/Hdfs/HdfsBridgeManaged/HdfsBridgeManaged.h @@ -0,0 +1,92 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + + +//#pragma once +//#include + +using namespace System; +using namespace System::Collections::Generic; +using namespace System::Runtime::InteropServices; + +namespace Microsoft { namespace Research { namespace Dryad { namespace Hdfs +{ + //--------------------------------------------------------------------------------------------------- + + public ref class HdfsBlockInfo + { + public: + array^ Hosts; + long long Size; + long long Offset; + int fileIndex; + }; + + //--------------------------------------------------------------------------------------------------- + + public ref class HdfsFileInfo + { + public: + String^ Name; + bool IsDirectory; + long long Size; + long long LastModified; + short Replication; + long long BlockSize; + + long long totalSize; + + array^ blockArray; + array^ fileNameArray; + }; + + //--------------------------------------------------------------------------------------------------- + + public ref class HdfsInstance : public IDisposable + { + public: + static bool Initialize(); + + HdfsInstance(String^ hdfsUri); + ~HdfsInstance(); + + void Close(); + + bool IsFileExists(String^ fileName); + + HdfsFileInfo^ GetFileInfo(String^ fileName, bool getBlockArray); + + bool DeleteFile(String^ fileName, bool recursive); + + bool RenameFile(String^ dstFileName, String^ srcFileName); + + String^ ToInternalUri(String^ fileName); + + String^ FromInternalUri(String^ fileName); + + private: + bool Open(String^ headNode, long hdfsPort); + + IntPtr m_instance; + String^ m_serviceUri; + }; + + //--------------------------------------------------------------------------------------------------- +}}}} diff --git a/Hdfs/HdfsBridgeManaged/HdfsBridgeManaged.vcxproj b/Hdfs/HdfsBridgeManaged/HdfsBridgeManaged.vcxproj new file mode 100644 index 0000000..eed1295 --- /dev/null +++ b/Hdfs/HdfsBridgeManaged/HdfsBridgeManaged.vcxproj @@ -0,0 +1,155 @@ + + + + + Debug + Win32 + + + Debug + x64 + + + Release + Win32 + + + Release + x64 + + + + ManagedCProj + 
{C0F4C1E3-1F9E-4C55-BD6A-0241D35425F5} + + + + DynamicLibrary + true + true + Unicode + + + DynamicLibrary + true + true + Unicode + + + DynamicLibrary + false + true + + + DynamicLibrary + false + true + + + + + + + + + + + + + + + + + + + true + Microsoft.Research.Dryad.Hdfs + + + false + Microsoft.Research.Dryad.Hdfs + ..\..\bin\$(Configuration)\ + + + true + + + false + + + + WIN32;_DEBUG;_WINDOWS;_USRDLL;HDFSBRIDGEMANAGED_EXPORTS;%(PreprocessorDefinitions) + MultiThreadedDebugDLL + Level3 + ProgramDatabase + Disabled + ..\HdfsBridgeNative;%(AdditionalIncludeDirectories) + + + MachineX86 + true + Windows + + + + + WIN32;_DEBUG;_WINDOWS;_USRDLL;HDFSBRIDGEMANAGED_EXPORTS;%(PreprocessorDefinitions) + MultiThreadedDebugDLL + Level3 + ProgramDatabase + Disabled + ..\HdfsBridgeNative;%(AdditionalIncludeDirectories) + + + true + Windows + + + + + WIN32;NDEBUG;_WINDOWS;_USRDLL;HDFSBRIDGEMANAGED_EXPORTS;%(PreprocessorDefinitions) + MultiThreadedDLL + Level3 + ProgramDatabase + + + MachineX86 + true + Windows + true + true + + + + + WIN32;NDEBUG;_WINDOWS;_USRDLL;HDFSBRIDGEMANAGED_EXPORTS;%(PreprocessorDefinitions) + MultiThreadedDLL + Level3 + ProgramDatabase + ..\HdfsBridgeNative;%(AdditionalIncludeDirectories) + + + true + Windows + true + true + + + + + + + + + + + + + + {95fbf9b7-9407-4554-a74a-3527839bd1b6} + + + + + + \ No newline at end of file diff --git a/Hdfs/HdfsBridgeNative/HdfsBridgeNative.cpp b/Hdfs/HdfsBridgeNative/HdfsBridgeNative.cpp new file mode 100644 index 0000000..d1589c4 --- /dev/null +++ b/Hdfs/HdfsBridgeNative/HdfsBridgeNative.cpp @@ -0,0 +1,1105 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#include "HdfsBridgeNative.h" + +#include +#include +#include + +static JavaVM* s_jvm = NULL; + +static char* GetExceptionMessageLocal(JNIEnv* env, jclass cls, jobject obj) +{ + jfieldID fidMessage = env->GetFieldID( + cls, "exceptionMessage", "Ljava/lang/String;"); + + assert(fidMessage != NULL); + + jstring message = (jstring) env->GetObjectField(obj, fidMessage); + + char* msg = NULL; + + if (message == NULL) + { + msg = _strdup(""); + } + else + { + const char* msgCopy = (const char*)(env->GetStringUTFChars(message, NULL)); + msg = _strdup(msgCopy); + env->ReleaseStringUTFChars(message, msgCopy); + } + + env->DeleteLocalRef(message); + + return msg; +} + +static JNIEnv* AttachToJvm() +{ + JNIEnv* env; + int ret = s_jvm->AttachCurrentThread((void**) &env, NULL); + + assert(ret == JNI_OK); + + return env; +} + +namespace HdfsBridgeNative +{ + struct Env + { + JNIEnv* e; + }; + + class InstanceInternal + { + public: + jclass m_clsInstance; + jobject m_obj; + Instance* m_holder; + }; + + class FileStatInternal + { + public: + jclass m_clsFileStat; + jobject m_fileStat; + jclass m_clsBlockLocations; + jobject m_blockLocations; + FileStat* m_holder; + }; + + class ReaderInternal + { + public: + jclass m_clsReader; + jobject m_reader; + jclass m_clsReaderBlock; + Reader* m_holder; + }; + + class WriterInternal + { + public: + jclass m_clsWriter; + jobject m_writer; + Writer* m_holder; + }; + + HdfsBlockLocInfo::HdfsBlockLocInfo() + { + numberOfHosts = 0; + Hosts = NULL; + Size = 0; + Offset = 0; + } + 
+ HdfsBlockLocInfo::~HdfsBlockLocInfo() + { + for (long i=0; i 1) + { + fprintf(stderr, "\nProcess already contains %d Java VMs\n", nVMs); + return false; + } + + char classPath[_MAX_ENV]; + DWORD dRet = GetEnvironmentVariableA("JNI_CLASSPATH", classPath, _MAX_ENV); + if (dRet == 0) + { + fprintf(stderr, "Failed to get 'classpath' environment variable\n"); + return false; + } + + JavaVMInitArgs vm_args; + JNI_GetDefaultJavaVMInitArgs(&vm_args); + vm_args.version = JNI_VERSION_1_6; + + JavaVMOption options[1]; // increment when turning on verbose JNI + vm_args.nOptions = 1; + vm_args.options = options; + options[0].optionString = new char[_MAX_ENV]; + sprintf_s(options[0].optionString, _MAX_ENV, "-Djava.class.path=%s", classPath); + //fprintf(stderr, "JNI_CLASSPATH:[%s]\n", options[0].optionString); + //options[1].optionString = "-verbose:jni"; + /* + vm_args.nOptions = 1; + JavaVMOption options; + options.optionString = "-verbose:jni"; + vm_args.options = &options; + */ + vm_args.ignoreUnrecognized = 0; + + JNIEnv* env; + ret = JNI_CreateJavaVM(&s_jvm, (void**) &env, &vm_args); + + delete [] options[0].optionString; + + if (ret < 0) + { + s_jvm = NULL; + printf("\nCreateJavaVM returned %d\n", ret); + return false; + } + + return true; + } + + void DisposeString(char* str) + { + free(str); + } + + bool OpenInstance(const char* headNode, long portNumber, Instance** pInstance) + { + JNIEnv* env = AttachToJvm(); + + jclass clsHdfsBridge = env->FindClass("GSLHDFS/HdfsBridge"); + if (clsHdfsBridge == NULL) + { + printf("Failed to find HdfsBridge class\n"); + return false; + } + + jmethodID midOpenInstance = env->GetStaticMethodID( + clsHdfsBridge, "OpenInstance", + "(Ljava/lang/String;J)LGSLHDFS/HdfsBridge$Instance;"); + assert(midOpenInstance != NULL); + + jstring jHeadNode = env->NewStringUTF(headNode); + jlong jPortNumber = portNumber; + + jobject localInstance = env->CallStaticObjectMethod( + clsHdfsBridge, midOpenInstance, jHeadNode, jPortNumber); + 
env->DeleteLocalRef(jHeadNode); + + if (localInstance == NULL) + { + printf("Failed to open instance %s:%d\n", headNode, portNumber); + return false; + } + + InstanceInternal* instance = new InstanceInternal(); + + instance->m_clsInstance = env->FindClass("GSLHDFS/HdfsBridge$Instance"); + assert(instance->m_clsInstance != NULL); + + instance->m_obj = env->NewGlobalRef(localInstance); + env->DeleteLocalRef(localInstance); + + Instance* holder = new Instance(); + holder->p = instance; + instance->m_holder = holder; + *pInstance = holder; + + return true; + }; + + InstanceAccessor::InstanceAccessor(Instance* instance) + { + m_env = new Env; + m_env->e = AttachToJvm(); + m_inst = (InstanceInternal* ) instance->p; + } + + InstanceAccessor::~InstanceAccessor() + { + delete m_env; + } + + void InstanceAccessor::Dispose() + { + if (m_inst->m_obj != NULL) + { + m_env->e->DeleteGlobalRef(m_inst->m_obj); + } + delete m_inst->m_holder; + delete m_inst; + } + + char* InstanceAccessor::GetExceptionMessage() + { + return GetExceptionMessageLocal(m_env->e, m_inst->m_clsInstance, m_inst->m_obj); + } + + bool InstanceAccessor::IsFileExists(char* fileName, bool* pExists) + { + jmethodID midIsFileExist = + m_env->e->GetMethodID(m_inst->m_clsInstance, "IsFileExist", "(Ljava/lang/String;)I"); + assert(midIsFileExist != NULL); + + jstring jFileName = m_env->e->NewStringUTF(fileName); + jint jFileExists = + m_env->e->CallIntMethod(m_inst->m_obj, midIsFileExist, jFileName); + m_env->e->DeleteLocalRef(jFileName); + + if (jFileExists == -1) + { + return false; + } + else + { + *pExists = (jFileExists == 1) ? 
true : false; + return true; + } + } + + bool InstanceAccessor::DeleteFileOrDir(char* fileName, bool recursive, bool* pDeleted) + { + jmethodID midDeleteFile = + m_env->e->GetMethodID(m_inst->m_clsInstance, "DeleteFile", "(Ljava/lang/String;Z)I"); + assert(midDeleteFile != NULL); + + jstring jFileName = m_env->e->NewStringUTF(fileName); + jboolean jRecursive = (jboolean) recursive; + jint jFileDeleted = + m_env->e->CallIntMethod(m_inst->m_obj, midDeleteFile, jFileName, jRecursive); + m_env->e->DeleteLocalRef(jFileName); + + if (jFileDeleted == -1) + { + return false; + } + else + { + *pDeleted = (jFileDeleted == 1) ? true : false; + return true; + } + } + + bool InstanceAccessor::RenameFileOrDir(char* dstFileName, char* srcFileName, bool* pRenamed) + { + jmethodID midRenameFile = + m_env->e->GetMethodID(m_inst->m_clsInstance, "RenameFile", "(Ljava/lang/String;Ljava/lang/String;)I"); + assert(midRenameFile != NULL); + + jstring jDstFileName = m_env->e->NewStringUTF(dstFileName); + jstring jSrcFileName = m_env->e->NewStringUTF(srcFileName); + jint jFileRenamed = + m_env->e->CallIntMethod(m_inst->m_obj, midRenameFile, jDstFileName, jSrcFileName); + m_env->e->DeleteLocalRef(jDstFileName); + m_env->e->DeleteLocalRef(jSrcFileName); + + if (jFileRenamed == -1) + { + return false; + } + else + { + *pRenamed = (jFileRenamed == 1) ? 
true : false; + return true; + } + } + + bool InstanceAccessor::OpenFileStat( + const char* fileName, + bool getBlockArray, + FileStat** pFileStat) + { + FileStatInternal* fs = new FileStatInternal(); + + fs->m_clsFileStat = m_env->e->FindClass("org/apache/hadoop/fs/FileStatus"); + assert(fs->m_clsFileStat != NULL); + + fs->m_clsBlockLocations = m_env->e->FindClass("GSLHDFS/HdfsBridge$Instance$BlockLocations"); + assert(fs->m_clsBlockLocations != NULL); + + jmethodID midOpenFileStat = m_env->e->GetMethodID( + m_inst->m_clsInstance, + "OpenFileStatus", + "(Ljava/lang/String;Z)Lorg/apache/hadoop/fs/FileStatus;"); + assert(midOpenFileStat != NULL); + + fs->m_fileStat = NULL; + fs->m_blockLocations = NULL; + + jstring jFileName = m_env->e->NewStringUTF(fileName); + jobject localFileStat = m_env->e->CallObjectMethod(m_inst->m_obj, midOpenFileStat, jFileName); + m_env->e->DeleteLocalRef(jFileName); + + if (localFileStat == NULL) + { + delete fs; + return false; + } + else + { + fs->m_fileStat = m_env->e->NewGlobalRef(localFileStat); + m_env->e->DeleteLocalRef(localFileStat); + } + + jmethodID midOpenBlockLocations = m_env->e->GetMethodID( + m_inst->m_clsInstance, + "OpenBlockLocations", + "(Lorg/apache/hadoop/fs/FileStatus;Z)LGSLHDFS/HdfsBridge$Instance$BlockLocations;"); + assert(midOpenBlockLocations != NULL); + + jboolean jGetBlockArray = (jboolean) getBlockArray; + jobject localBlockLoc = + m_env->e->CallObjectMethod( + m_inst->m_obj, + midOpenBlockLocations, + fs->m_fileStat, jGetBlockArray); + + if (localBlockLoc == NULL) + { + m_env->e->DeleteGlobalRef(fs->m_fileStat); + delete fs; + return false; + } + else + { + fs->m_blockLocations = m_env->e->NewGlobalRef(localBlockLoc); + m_env->e->DeleteLocalRef(localBlockLoc); + } + + FileStat* fileStat = new FileStat(); + fileStat->p = fs; + fs->m_holder = fileStat; + *pFileStat = fileStat; + + return true; + } + + bool InstanceAccessor::OpenReader(const char* fileName, Reader** pReader) + { + ReaderInternal* r = new 
ReaderInternal; + + r->m_clsReader = m_env->e->FindClass("GSLHDFS/HdfsBridge$Instance$Reader"); + assert(r->m_clsReader != NULL); + + r->m_clsReaderBlock = m_env->e->FindClass("GSLHDFS/HdfsBridge$Instance$Reader$Block"); + assert(r->m_clsReaderBlock != NULL); + + jmethodID midOpenReader = m_env->e->GetMethodID( + m_inst->m_clsInstance, + "OpenReader", + "(Ljava/lang/String;)LGSLHDFS/HdfsBridge$Instance$Reader;"); + assert(midOpenReader != NULL); + + jstring jFileName = m_env->e->NewStringUTF(fileName); + jobject localReader = m_env->e->CallObjectMethod(m_inst->m_obj, midOpenReader, jFileName); + m_env->e->DeleteLocalRef(jFileName); + + if (localReader == NULL) + { + delete r; + return false; + } + else + { + r->m_reader = m_env->e->NewGlobalRef(localReader); + m_env->e->DeleteLocalRef(localReader); + } + + Reader* reader = new Reader(); + reader->p = r; + r->m_holder = reader; + *pReader = reader; + + return true; + } + + bool InstanceAccessor::OpenWriter(const char* fileName, Writer** pWriter) + { + WriterInternal* w = new WriterInternal; + + w->m_clsWriter = m_env->e->FindClass("GSLHDFS/HdfsBridge$Instance$Writer"); + assert(w->m_clsWriter != NULL); + + jmethodID midOpenWriter = m_env->e->GetMethodID( + m_inst->m_clsInstance, + "OpenWriter", + "(Ljava/lang/String;)LGSLHDFS/HdfsBridge$Instance$Writer;"); + assert(midOpenWriter != NULL); + + jstring jFileName = m_env->e->NewStringUTF(fileName); + jobject localWriter = m_env->e->CallObjectMethod(m_inst->m_obj, midOpenWriter, jFileName); + m_env->e->DeleteLocalRef(jFileName); + + if (localWriter == NULL) + { + delete w; + return false; + } + else + { + w->m_writer = m_env->e->NewGlobalRef(localWriter); + m_env->e->DeleteLocalRef(localWriter); + } + + Writer* writer = new Writer(); + writer->p = w; + w->m_holder = writer; + *pWriter = writer; + + return true; + } + + + FileStatAccessor::FileStatAccessor(FileStat* fileStat) + { + m_env = new Env; + m_env->e = AttachToJvm(); + m_stat = (FileStatInternal *) fileStat->p; 
+ } + + FileStatAccessor::~FileStatAccessor() + { + delete m_env; + } + + void FileStatAccessor::Dispose() + { + if (m_stat->m_fileStat != NULL) + { + m_env->e->DeleteGlobalRef(m_stat->m_fileStat); + m_stat->m_fileStat = NULL; + } + + if (m_stat->m_blockLocations != NULL) + { + m_env->e->DeleteGlobalRef(m_stat->m_blockLocations); + m_stat->m_blockLocations = NULL; + } + + delete m_stat->m_holder; + delete m_stat; + } + + char* FileStatAccessor::GetExceptionMessage() + { + return GetExceptionMessageLocal(m_env->e, m_stat->m_clsFileStat, m_stat->m_fileStat); + } + + char* FileStatAccessor::GetBlockExceptionMessage() + { + return GetExceptionMessageLocal(m_env->e, m_stat->m_clsBlockLocations, m_stat->m_blockLocations); + } + + long long FileStatAccessor::GetFileLength() + { + jmethodID mid = m_env->e->GetMethodID( + m_stat->m_clsFileStat, "getLen", "()J"); + assert(mid != NULL); + + return m_env->e->CallLongMethod(m_stat->m_fileStat, mid); + } + + bool FileStatAccessor::IsDir() + { + jmethodID mid = m_env->e->GetMethodID( + m_stat->m_clsFileStat, "isDir", "()Z"); + assert(mid != NULL); + + jboolean isDir = m_env->e->CallBooleanMethod(m_stat->m_fileStat, mid); + return (isDir) ? 
true : false; + } + + long long FileStatAccessor::GetFileLastModified() + { + jmethodID mid = m_env->e->GetMethodID( + m_stat->m_clsFileStat, "getModificationTime", "()J"); + assert(mid != NULL); + + return m_env->e->CallLongMethod(m_stat->m_fileStat, mid); + } + + short FileStatAccessor::GetFileReplication() + { + jmethodID mid = m_env->e->GetMethodID( + m_stat->m_clsFileStat, "getReplication", "()S"); + assert(mid != NULL); + + return m_env->e->CallShortMethod(m_stat->m_fileStat, mid); + } + + long long FileStatAccessor::GetFileBlockSize() + { + jmethodID mid = m_env->e->GetMethodID( + m_stat->m_clsFileStat, "getBlockSize", "()J"); + assert(mid != NULL); + + return m_env->e->CallLongMethod(m_stat->m_fileStat, mid); + } + + long FileStatAccessor::GetNumberOfBlocks() + { + jmethodID mid = m_env->e->GetMethodID( + m_stat->m_clsBlockLocations, "GetNumberOfBlocks", "()I"); + assert(mid != NULL); + + return m_env->e->CallIntMethod(m_stat->m_blockLocations, mid); + } + + HdfsBlockLocInfo* FileStatAccessor::GetBlockInfo(long blockId) + { + HdfsBlockLocInfo* block = new HdfsBlockLocInfo(); + + jint jId = blockId; + + jmethodID mid = m_env->e->GetMethodID( + m_stat->m_clsBlockLocations, "GetBlockLength", "(I)J"); + assert(mid != NULL); + + block->Size = m_env->e->CallLongMethod(m_stat->m_blockLocations, mid, jId); + + mid = m_env->e->GetMethodID( + m_stat->m_clsBlockLocations, "GetBlockOffset", "(I)J"); + assert(mid != NULL); + + block->Offset = m_env->e->CallLongMethod(m_stat->m_blockLocations, mid, jId); + + mid = m_env->e->GetMethodID( + m_stat->m_clsBlockLocations, "GetBlockHosts", "(I)[Ljava/lang/String;"); + assert(mid != NULL); + + jobjectArray hostArray = (jobjectArray) m_env->e->CallObjectMethod(m_stat->m_blockLocations, mid, jId); + if (hostArray != NULL) + { + jsize jArrayLength = m_env->e->GetArrayLength(hostArray); + block->numberOfHosts = jArrayLength; + block->Hosts = new char* [jArrayLength]; + + const char *hostName; + for (int i=0; 
ie->GetObjectArrayElement(hostArray, i); + hostName = (const char*) m_env->e->GetStringUTFChars(jHost, NULL); + block->Hosts[i] = _strdup(hostName); + m_env->e->ReleaseStringUTFChars(jHost, hostName); + m_env->e->DeleteLocalRef(jHost); + } + + m_env->e->DeleteLocalRef(hostArray); + } + + mid = m_env->e->GetMethodID( + m_stat->m_clsBlockLocations, "GetBlockFileId", "(I)I"); + assert(mid != NULL); + + block->fileIndex = m_env->e->CallIntMethod(m_stat->m_blockLocations, mid, jId); + + return block; + } + + void FileStatAccessor::DisposeBlockInfo(HdfsBlockLocInfo* bi) + { + delete bi; + } + + long long FileStatAccessor::GetTotalFileLength() + { + jfieldID fidSize = m_env->e->GetFieldID( + m_stat->m_clsBlockLocations, "fileSize", "J"); + assert(fidSize != NULL); + + return m_env->e->GetLongField(m_stat->m_blockLocations, fidSize); + } + + long FileStatAccessor::GetNumberOfFiles() + { + jmethodID mid = m_env->e->GetMethodID( + m_stat->m_clsBlockLocations, "GetNumberOfFileNames", "()I"); + assert(mid != NULL); + + return m_env->e->CallIntMethod(m_stat->m_blockLocations, mid); + } + + char** FileStatAccessor::GetFileNameArray() + { + jmethodID mid = m_env->e->GetMethodID( + m_stat->m_clsBlockLocations, "GetFileNames", "()[Ljava/lang/String;"); + assert(mid != NULL); + + char** array = NULL; + + jobjectArray nameArray = (jobjectArray) m_env->e->CallObjectMethod(m_stat->m_blockLocations, mid); + if (nameArray != NULL) + { + jsize jArrayLength = m_env->e->GetArrayLength(nameArray); + array = new char* [jArrayLength]; + + const char *fileName; + for (int i=0; ie->GetObjectArrayElement(nameArray, i); + fileName = (const char*) m_env->e->GetStringUTFChars(jFile, NULL); + array[i] = _strdup(fileName); + m_env->e->ReleaseStringUTFChars(jFile, fileName); + m_env->e->DeleteLocalRef(jFile); + } + + m_env->e->DeleteLocalRef(nameArray); + } + + return array; + } + + void FileStatAccessor::DisposeFileNameArray(long length, char** array) + { + for (long i=0; ie = AttachToJvm(); + m_rdr = 
(ReaderInternal *) reader->p; + } + + ReaderAccessor::~ReaderAccessor() + { + delete m_env; + } + + void ReaderAccessor::Dispose() + { + if (m_rdr->m_reader != NULL) + { + m_env->e->DeleteGlobalRef(m_rdr->m_reader); + } + + delete m_rdr->m_holder; + delete m_rdr; + } + + char* ReaderAccessor::GetExceptionMessage() + { + return GetExceptionMessageLocal(m_env->e, m_rdr->m_clsReader, m_rdr->m_reader); + } + + long ReaderAccessor::ReadBlock(long long offset, char* buffer, long bufferLength) + { + jmethodID mid = m_env->e->GetMethodID( + m_rdr->m_clsReader, "ReadBlock", "(JI)LGSLHDFS/HdfsBridge$Instance$Reader$Block;"); + assert(mid != NULL); + + jlong jOffset = offset; + jint jLength = bufferLength; + jobject block = m_env->e->CallObjectMethod(m_rdr->m_reader, mid, jOffset, jLength); + + jfieldID fidret = m_env->e->GetFieldID( + m_rdr->m_clsReaderBlock, "ret", "I"); + assert(fidret != NULL); + + jint bytesRead = m_env->e->GetIntField(block, fidret); + + assert(bytesRead <= bufferLength); + + if (bytesRead > 0) + { + jfieldID fid = m_env->e->GetFieldID( + m_rdr->m_clsReaderBlock, "buffer", "[B"); + assert(fid != NULL); + + jbyteArray byteArray = (jbyteArray) m_env->e->GetObjectField(block, fid); + assert(byteArray != NULL); + + jint arrayLength = m_env->e->GetArrayLength(byteArray); + assert(arrayLength >= bytesRead); + + m_env->e->GetByteArrayRegion(byteArray, 0, bytesRead, (jbyte *) buffer); + + m_env->e->DeleteLocalRef(byteArray); + } + + m_env->e->DeleteLocalRef(block); + + return bytesRead; + } + + bool ReaderAccessor::Close() + { + jmethodID mid = m_env->e->GetMethodID( + m_rdr->m_clsReader, "Close", "()I"); + + assert(mid != NULL); + + jint ret = m_env->e->CallIntMethod(m_rdr->m_reader, mid); + + return (ret) ? 
true : false; + } + + + WriterAccessor::WriterAccessor(Writer* writer) + { + m_env = new Env; + m_env->e = AttachToJvm(); + m_wtr = (WriterInternal *) writer->p; + } + + WriterAccessor::~WriterAccessor() + { + delete m_env; + } + + void WriterAccessor::Dispose() + { + if (m_wtr->m_writer != NULL) + { + m_env->e->DeleteGlobalRef(m_wtr->m_writer); + } + + delete m_wtr->m_holder; + delete m_wtr; + } + + char* WriterAccessor::GetExceptionMessage() + { + return GetExceptionMessageLocal(m_env->e, m_wtr->m_clsWriter, m_wtr->m_writer); + } + + bool WriterAccessor::WriteBlock(char* buffer, long bufferLength, bool flushAfter) + { + jmethodID mid = m_env->e->GetMethodID( + m_wtr->m_clsWriter, "WriteBlock", "([BZ)I"); + assert(mid != NULL); + + jboolean jFlush = (jboolean) flushAfter; + jint jLength = bufferLength; + jbyteArray jBuffer = m_env->e->NewByteArray(jLength); + assert(jBuffer != NULL); + + m_env->e->SetByteArrayRegion(jBuffer, 0, jLength, (jbyte *) buffer); + + jint jRet = m_env->e->CallIntMethod(m_wtr->m_writer, mid, jBuffer, jFlush); + + m_env->e->DeleteLocalRef(jBuffer); + + return (jRet) ? true : false; + } + + bool WriterAccessor::Close() + { + jmethodID mid = m_env->e->GetMethodID( + m_wtr->m_clsWriter, "Close", "()I"); + + assert(mid != NULL); + + jint ret = m_env->e->CallIntMethod(m_wtr->m_writer, mid); + + return (ret) ? 
true : false; + } +}; + +#if 0 + +#include + +struct ReadBlock +{ + HdfsBridgeNative::Reader* r; + HdfsBridgeNative::Instance* i; + long long o; + const char* f; +}; + +unsigned __stdcall ThreadFunc(void* arg) +{ + ReadBlock* block = (ReadBlock *) arg; + + HdfsBridgeNative::ReaderAccessor r(block->r); + + int bytesRead = 0; + long long offset = 0; + char* buffer = new char[256*1024]; + do + { + bytesRead = r.ReadBlock(offset, buffer, 256*1024); + if (bytesRead > 0) + { + printf("Read from %s:%I64d:%d\n", block->f, offset, bytesRead); + offset += bytesRead; + } + if (bytesRead < -1) + { + printf("%s: %s\n", block->f, r.GetExceptionMessage()); + } + if (bytesRead == -1) + { + printf("EOF\n"); + } + } while (bytesRead > -1); + + return 0; +} + +int main(int argc, wchar_t** argv) +{ + bool ret = HdfsBridgeNative::Initialize(); + + if (!ret) + { + printf("Failed to initialize\n"); + return 0; + } + + HdfsBridgeNative::Instance* instance; + ret = HdfsBridgeNative::OpenInstance("svc-d1-17", 9000, &instance); + + if (!ret) + { + printf("failed open\n"); + return 0; + } + + HANDLE h[4]; + + HdfsBridgeNative::InstanceAccessor bridge(instance); + + ReadBlock* r0 = new ReadBlock; + r0->f = "/data/inputPart0.txt"; + + ret = bridge.OpenReader(r0->f, &r0->r); + + if (!ret) + { + printf("%s failed open %s\n", r0->f, bridge.GetExceptionMessage()); + bridge.Dispose(); + return 0; + } + + h[0] = (HANDLE) _beginthreadex( + NULL, + 0, + ThreadFunc, + r0, + 0, + NULL + ); + + h[1] = (HANDLE) _beginthreadex( + NULL, + 0, + ThreadFunc, + r0, + 0, + NULL + ); + + ReadBlock* r1 = new ReadBlock; + r1->f = "/data"; + + HdfsBridgeNative::FileStat* fileStat; + ret = bridge.OpenFileStat(r1->f, true, &fileStat); + + { + HdfsBridgeNative::FileStatAccessor fs(fileStat); + + long long ll = fs.GetFileLength(); + ll = fs.GetFileLastModified(); + ll = fs.GetFileBlockSize(); + long l = fs.GetFileReplication(); + bool b = fs.IsDir(); + + long nBlocks = fs.GetNumberOfBlocks(); + for (long i=0; if, 
&r1->r); + + if (!ret) + { + printf("%s failed open %s\n", r1->f, bridge.GetExceptionMessage()); + bridge.Dispose(); + return 0; + } + + h[2] = (HANDLE) _beginthreadex( + NULL, + 0, + ThreadFunc, + r1, + 0, + NULL + ); + + h[3] = (HANDLE) _beginthreadex( + NULL, + 0, + ThreadFunc, + r1, + 0, + NULL + ); + + WaitForMultipleObjects(4, h, TRUE, INFINITE); + + HdfsBridgeNative::ReaderAccessor ra0(r0->r); + ret = ra0.Close(); + ra0.Dispose(); + + HdfsBridgeNative::ReaderAccessor ra1(r1->r); + ret = ra1.Close(); + ra1.Dispose(); + + bridge.Dispose(); + +#if 0 + HdfsBridgeNative::Instance* instance; + ret = HdfsBridgeNative::OpenInstance("svc-d1-17", 9000, &instance); + + { + HdfsBridgeNative::InstanceAccessor bridge(instance); + + char* msg = bridge.GetExceptionMessage(); + free(msg); + + bool exists; + ret = bridge.IsFileExists("data/foo", &exists); + msg = bridge.GetExceptionMessage(); + free(msg); + + ret = bridge.IsFileExists("/data/inputPart0.txt", &exists); + msg = bridge.GetExceptionMessage(); + free(msg); + + HdfsBridgeNative::FileStat* fileStat; + + ret = bridge.OpenFileStat("data/foo", false, &fileStat); + msg = bridge.GetExceptionMessage(); + free(msg); + + ret = bridge.OpenFileStat("/data/inputPart0.txt", true, &fileStat); + msg = bridge.GetExceptionMessage(); + free(msg); + + { + HdfsBridgeNative::FileStatAccessor fs(fileStat); + + long long ll = fs.GetFileLength(); + ll = fs.GetFileLastModified(); + ll = fs.GetFileBlockSize(); + long l = fs.GetFileReplication(); + bool b = fs.IsDir(); + + long nBlocks = fs.GetNumberOfBlocks(); + for (long i=0; i 0) + { + printf("Read from %I64d:%d\n", offset, bytesRead); + fwrite(buffer, 1, bytesRead, f); + offset += bytesRead; + } + if (bytesRead < -1) + { + msg = r.GetExceptionMessage(); + printf("%s\n", msg); + free(msg); + } + if (bytesRead == -1) + { + printf("EOF\n"); + } + } while (bytesRead > -1); + + fclose(f); + + ret = r.Close(); + msg = r.GetExceptionMessage(); + free(msg); + + r.Dispose(); + } + + 
bridge.Dispose(); + } +#endif + + return 0; +} + +#endif \ No newline at end of file diff --git a/Hdfs/HdfsBridgeNative/HdfsBridgeNative.h b/Hdfs/HdfsBridgeNative/HdfsBridgeNative.h new file mode 100644 index 0000000..1aee33a --- /dev/null +++ b/Hdfs/HdfsBridgeNative/HdfsBridgeNative.h @@ -0,0 +1,183 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +#include +#include +#include + +//--------------------------------------------------------------------------------------------------- + +namespace HdfsBridgeNative +{ + struct Instance + { + void* p; + }; + class InstanceInternal; + + struct FileStat + { + void* p; + }; + class FileStatInternal; + + struct Reader + { + void* p; + }; + class ReaderInternal; + + struct Writer + { + void* p; + }; + class WriterInternal; + + struct Env; + + bool Initialize(); + void DisposeString(char* str); + + class HdfsBlockLocInfo + { + public: + HdfsBlockLocInfo(); + ~HdfsBlockLocInfo(); + + long numberOfHosts; /* length of the hosts array */ + char** Hosts; /* hosts storing block replicas, freed by destructor */ + long long Size; /* the size of the block in bytes */ + long long Offset; /* start offset of file associated with this block */ + int fileIndex; /* which file in a directory this block is part of */ + }; + + class FileStatAccessor + { + public: + FileStatAccessor(FileStat* fileStat); + 
~FileStatAccessor(); + + void Dispose(); + + char* GetExceptionMessage(); + char* GetBlockExceptionMessage(); + + long long GetFileLength(); + bool IsDir(); + long long GetFileLastModified(); + short GetFileReplication(); + long long GetFileBlockSize(); + + long long GetTotalFileLength(); + long GetNumberOfBlocks(); + HdfsBlockLocInfo* GetBlockInfo(long blockId); + void DisposeBlockInfo(HdfsBlockLocInfo* blockInfo); + long GetNumberOfFiles(); + char** GetFileNameArray(); + void DisposeFileNameArray(long length, char** array); + + private: + void* operator new( size_t ); + void* operator new[]( size_t ); + + Env* m_env; + FileStatInternal* m_stat; + }; + + class ReaderAccessor + { + public: + ReaderAccessor(Reader* reader); + ~ReaderAccessor(); + + void Dispose(); + + char* GetExceptionMessage(); + + long ReadBlock(long long offset, char* buffer, long bufferSize); + + bool Close(); + + private: + void* operator new( size_t ); + void* operator new[]( size_t ); + + Env* m_env; + ReaderInternal* m_rdr; + }; + + class WriterAccessor + { + public: + WriterAccessor(Writer* writer); + ~WriterAccessor(); + + void Dispose(); + + char* GetExceptionMessage(); + + bool WriteBlock(char* buffer, long bufferSize, bool flushAfter); + + bool Close(); + + private: + void* operator new( size_t ); + void* operator new[]( size_t ); + + Env* m_env; + WriterInternal* m_wtr; + }; + + class InstanceAccessor + { + public: + InstanceAccessor(Instance* instance); + ~InstanceAccessor(); + + void Dispose(); + + char* GetExceptionMessage(); + + bool IsFileExists(char* fileName, bool* pExists); + + bool DeleteFileOrDir(char* fileName, bool recursive, bool* pDeleted); + + bool RenameFileOrDir(char* dstFileName, char* srcFileName, bool* pRenamed); + + bool OpenFileStat(const char* fileName, bool getBlockArray, FileStat** pFileStat); + + bool OpenReader(const char* fileName, Reader** pReader); + + bool OpenWriter(const char* fileName, Writer** pWriter); + + private: + void* operator new( size_t ); + 
void* operator new[]( size_t ); + + Env* m_env; + InstanceInternal* m_inst; + }; + + bool OpenInstance(const char* headNode, long portNumber, Instance** pInstance); +}; +//--------------------------------------------------------------------------------------------------- + diff --git a/Hdfs/HdfsBridgeNative/HdfsBridgeNative.vcxproj b/Hdfs/HdfsBridgeNative/HdfsBridgeNative.vcxproj new file mode 100644 index 0000000..73622a1 --- /dev/null +++ b/Hdfs/HdfsBridgeNative/HdfsBridgeNative.vcxproj @@ -0,0 +1,157 @@ + + + + + Debug + Win32 + + + Debug + x64 + + + Release + Win32 + + + Release + x64 + + + + {95FBF9B7-9407-4554-A74A-3527839BD1B6} + Win32Proj + + + + StaticLibrary + true + v110 + + + StaticLibrary + true + v100 + Unicode + + + StaticLibrary + false + v110 + + + StaticLibrary + false + v100 + + + + + + + + + + + + + + + + + + + true + + + true + ..\..\bin\$(Configuration)\ + + + true + + + true + ..\..\bin\$(Configuration)\ + + + + WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions) + MultiThreadedDebugDLL + Level3 + ProgramDatabase + Disabled + $(JAVA_HOME)\include;$(JAVA_HOME)\include\win32 + + + MachineX86 + true + Windows + + + + + WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions) + MultiThreadedDebugDLL + Level4 + ProgramDatabase + Disabled + $(JAVA_HOME)\include;$(JAVA_HOME)\include\win32 + + + true + Windows + + + jvm.lib + + + $(JAVA_HOME)\lib + + + + + WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions) + MultiThreadedDLL + Level3 + ProgramDatabase + + + MachineX86 + true + Windows + true + true + + + + + WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions) + MultiThreadedDLL + Level4 + ProgramDatabase + $(JAVA_HOME)\include;$(JAVA_HOME)\include\win32 + + + true + Windows + true + true + + + jvm.lib + + + $(JAVA_HOME)\lib + + + + + + + + + + + + \ No newline at end of file diff --git a/Java/DryadAppMaster.java b/Java/DryadAppMaster.java new file mode 100644 index 0000000..758d3aa --- /dev/null +++ b/Java/DryadAppMaster.java @@ -0,0 +1,478 @@ +/* +Copyright (c) Microsoft 
Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +package com.microsoft.research; + +//import org.apache.hadoop.mapreduce.v2.jobhistory.JobHistoryUtils; +import java.io.File; +import java.io.IOException; +import java.lang.StringBuilder; +import java.net.InetSocketAddress; +import java.security.PrivilegedAction; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.net.NetUtils; +import org.apache.hadoop.security.SecurityInfo; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.token.Token; +import org.apache.hadoop.yarn.api.AMRMProtocol; +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment; +import org.apache.hadoop.yarn.api.ApplicationConstants; +import org.apache.hadoop.yarn.api.ContainerManager; +import org.apache.hadoop.yarn.api.protocolrecords.AllocateRequest; +import 
org.apache.hadoop.yarn.api.protocolrecords.AllocateResponse; +import org.apache.hadoop.yarn.api.protocolrecords.FinishApplicationMasterRequest; +import org.apache.hadoop.yarn.api.protocolrecords.FinishApplicationMasterResponse; +import org.apache.hadoop.yarn.api.protocolrecords.GetContainerStatusRequest; +import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterRequest; +import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse; +import org.apache.hadoop.yarn.api.protocolrecords.StartContainerRequest; +import org.apache.hadoop.yarn.api.records.ApplicationAttemptId; +import org.apache.hadoop.yarn.api.records.Container; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.api.records.ContainerLaunchContext; +import org.apache.hadoop.yarn.api.records.ContainerState; +import org.apache.hadoop.yarn.api.records.ContainerStatus; +import org.apache.hadoop.yarn.api.records.FinalApplicationStatus; +import org.apache.hadoop.yarn.api.records.NodeReport; +import org.apache.hadoop.yarn.api.records.Priority; +import org.apache.hadoop.yarn.api.records.Resource; +import org.apache.hadoop.yarn.api.records.ResourceRequest; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.exceptions.YarnRemoteException; +import org.apache.hadoop.yarn.ipc.YarnRPC; +import org.apache.hadoop.yarn.security.ContainerTokenIdentifier; +import org.apache.hadoop.yarn.security.client.ClientRMSecurityInfo; +import org.apache.hadoop.yarn.util.ConverterUtils; +import org.apache.hadoop.yarn.util.ProtoUtils; +import org.apache.hadoop.yarn.util.Records; + +public class DryadAppMaster +{ + private Log log; + private YarnConfiguration yarnConf; + private YarnRPC rpc; + private AMRMProtocol resourceManager; + private ApplicationAttemptId appAttemptID; + private String appMasterHostname; + public final String xcResources; + public final String vertexCmdLine; + public final String jniClassPath; + public final 
String dryadHome; + private ScheduledExecutorService heartbeatExec; + private ScheduledFuture heartbeatHandle; + private AtomicBoolean shuttingDown; + private AtomicBoolean scheduleProcesses; + private AtomicInteger responseId; + private AtomicInteger nextVertexId; + private Map runningContainers; + private List containersToReturn; + private List resourceRequests; + private int clusterNodeCount = -1; + private final int minMemory; + private final int maxMemory; + + private final int minNodes; + private final int maxNodes; + + private static int YTS_NA = 0; + private static int YTS_Scheduling = 1; + private static int YTS_Running = 2; + private static int YTS_Completed = 3; + private static int YTS_Failed = 4; + + private native void SendVertexState(int vertexId, int state, String nodeName); + + private class VertexInfo + { + public final int vertexId; + public final String nodeName; + + public VertexInfo(int vid, String node) { + vertexId = vid; + nodeName = node; + } + } + + static { + Log slog = LogFactory.getLog("DryadAppMaster"); + slog.info("About to load DryadYarnBridge library"); + System.loadLibrary("DryadYarnBridge"); + slog.info("Loaded DryadYarnBridge library"); + } + + public DryadAppMaster() throws YarnRemoteException, IOException + { + log = LogFactory.getLog("DryadAppMaster"); + log.info("In DryadAppMaster constructor"); + shuttingDown = new AtomicBoolean(); + scheduleProcesses = new AtomicBoolean(true); + responseId = new AtomicInteger(); + nextVertexId = new AtomicInteger(2); //first vertex id is 2 to map to Dryad Vertex Scheduler + runningContainers = new HashMap(); + + containersToReturn = Collections.synchronizedList(new ArrayList()); + resourceRequests = Collections.synchronizedList(new ArrayList()); + + Map envs = System.getenv(); + String containerIdString = envs.get(Environment.CONTAINER_ID.name()); + + if (containerIdString == null) { + // container id should always be set in the env by the framework + StringBuilder sb = new 
StringBuilder(4096); + for(Map.Entry entry : envs.entrySet()) + { + sb.append("\n\tKey: '"); + sb.append(entry.getKey()); + sb.append("'\tValue: '"); + sb.append(entry.getValue()); + sb.append("'"); + } + + log.error("Couldn't find container id in environment strings. Environment: " + sb); + throw new IllegalArgumentException("ContainerId not set in the environment"); + } + appMasterHostname = envs.get("COMPUTERNAME"); // WINDOWS ONLY + if (appMasterHostname == null) { + throw new IllegalArgumentException( + "COMPUTERNAME not set in the environment"); + } + xcResources = envs.get("XC_RESOURCEFILES"); + jniClassPath = envs.get("JNI_CLASSPATH"); + dryadHome = envs.get("DRYAD_HOME"); + + ContainerId containerId = ConverterUtils.toContainerId(containerIdString); + appAttemptID = containerId.getApplicationAttemptId(); + + minNodes = Integer.parseInt(envs.get("MINIMUM_COMPUTE_NODES")); + maxNodes = Integer.parseInt(envs.get("MAXIMUM_COMPUTE_NODES")); + + File vertexExecutable = new File(envs.get("DRYAD_HOME"), "DryadVertexService.exe"); + vertexCmdLine = vertexExecutable.getAbsolutePath(); + + yarnConf = new YarnConfiguration(); + String dest = yarnConf.get(YarnConfiguration.RM_SCHEDULER_ADDRESS,YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS); + log.warn("Configuration says to connect to ResourceManager at " + dest); + // Connect to the Scheduler of the ResourceManager. 
+ InetSocketAddress rmAddress = NetUtils.createSocketAddr(dest); + log.info("Connecting to ResourceManager at " + rmAddress); + + rpc = YarnRPC.create(yarnConf); + resourceManager = (AMRMProtocol) rpc.getProxy(AMRMProtocol.class, rmAddress, (Configuration)yarnConf); + + heartbeatExec = Executors.newScheduledThreadPool(1); + + String historyUrl = "http://localhost/foo"; // NYI JobHistoryUtils.getHistoryUrl((Configuration)yarnConf, + // appAttemptID.getApplicationId()); + log.info("History url is " + historyUrl); + + RegisterApplicationMasterRequest appMasterRequest = + Records.newRecord(RegisterApplicationMasterRequest.class); + appMasterRequest.setApplicationAttemptId(appAttemptID); + appMasterRequest.setHost(appMasterHostname); + // NYI - for now, until we learn that these are necessary, use dummy values for URL and rpc port + appMasterRequest.setRpcPort(0); + appMasterRequest.setTrackingUrl(historyUrl); + log.info("Registering AppMaster"); + RegisterApplicationMasterResponse response = + resourceManager.registerApplicationMaster(appMasterRequest); + log.info("AppMaster registered"); + minMemory = response.getMinimumResourceCapability().getMemory(); + maxMemory = response.getMaximumResourceCapability().getMemory(); + + // setup the heartbeat to the RM + Runnable heartbeatObj = new Runnable() { + public void run() { heartbeat(); } + }; + + long hbInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, + YarnConfiguration.DEFAULT_RM_AM_EXPIRY_INTERVAL_MS); + // For now, just heartbeat every second, so we learn about failures + hbInterval = 1000; //(hbInterval * 3) / 4; + log.info("Sending heartbeats to the RM every " + hbInterval + " ms."); + + // send the first heartbeat immediately, so we learn how many nodes are in the cluster + heartbeatHandle = heartbeatExec.scheduleAtFixedRate(heartbeatObj, 0, hbInterval, TimeUnit.MILLISECONDS); + } + + private float getProgress() + { + return 0.01f; // NYI + } + + private void heartbeat() + { + // check to see 
if we should cancel the heartbeat + if (shuttingDown.get()) { + heartbeatHandle.cancel(true); + } + log.info("Sending heartbeat to the RM"); + AllocateResponse response = sendAllocateRequest(); + if (response != null) { + int oldNodeCount = clusterNodeCount; + clusterNodeCount = response.getNumClusterNodes(); + if (clusterNodeCount != oldNodeCount) { + log.info("There are now " + clusterNodeCount + " available nodes on the cluster."); + } + processResponse(response); + } + + } + + private void launchContainer(Container container, ContainerManager cm) + { + ContainerLaunchContext ctx = + Records.newRecord(ContainerLaunchContext.class); + + VertexInfo vi = new VertexInfo(nextVertexId.getAndIncrement(), + container.getNodeId().getHost()); + + // set the environment variable to enable vertex debugging if desired + // also set the CCP_DRYADPROCID and XC_JOBMANAGER variables so the + // vertex knows how to find the GM and knows what its id is + Map vertexEnv = new HashMap(); + //vertexEnv.put("HPCQUERY_DEBUGVERTEXHOST", "HPCQUERY_DEBUGVERTEXHOST"); + //vertexEnv.put("CCP_SCHEDULERTYPE", "LOCAL"); + //vertexEnv.put("HPCQUERY_DEBUGVERTEXHOST", "DEBUG"); + vertexEnv.put("XCJOBMANAGER", appMasterHostname); + vertexEnv.put("CCP_JOBID", appAttemptID.getApplicationId().getId() + ""); + vertexEnv.put("CCP_TASKID", container.getId().getId() + ""); + vertexEnv.put("XC_RESOURCEFILES", xcResources); + + vertexEnv.put("JNI_CLASSPATH", jniClassPath); + vertexEnv.put("DRYAD_HOME", dryadHome); + ctx.setEnvironment(vertexEnv); + + String commandLine = vertexCmdLine + + " 1>stdout-fromcm.txt" + + " 2>stderr-fromcm.txt"; + log.info("Launching a container with command line '" + + vertexCmdLine + "'" + " for vertex " + vi.vertexId + + " on host " + vi.nodeName); + + List commands = new ArrayList(); + commands.add(commandLine); + ctx.setCommands(commands); + + runningContainers.put(container.getId(), vi); + + //SendVertexState(command.vertexId, DPS_Starting); // no need to send this when 
starting task + + // Send the start request to the ContainerManager + StartContainerRequest startReq = Records.newRecord(StartContainerRequest.class); + startReq.setContainerLaunchContext(ctx); + startReq.setContainer(container); + try { + cm.startContainer(startReq); + } catch (YarnRemoteException|IOException e) { + log.info("Error launching the container: " + e.getMessage()); + } + try { + GetContainerStatusRequest conStatusReq = Records.newRecord(GetContainerStatusRequest.class); + conStatusReq.setContainerId(container.getId()); + ContainerStatus status = cm.getContainerStatus(conStatusReq).getStatus(); + log.info("Container " + status.getContainerId() + " is in the " + status.getState() + " state"); + if (status.getState() == ContainerState.RUNNING) { + log.debug("Calling SendVertexState()"); + SendVertexState(vi.vertexId, YTS_Running, vi.nodeName); + log.debug("Returned from SendVertexState()"); + } else { + log.warn("May not send running state"); + } + } catch (YarnRemoteException|IOException e) { + log.info("Error getting container state: " + e.getMessage()); + } + } + + private void processResponse(AllocateResponse response) + { + // is this the first allocation? + if (scheduleProcesses.compareAndSet(true, false)) { + int numProcessesToStart = Math.max(response.getNumClusterNodes() - 1, maxNodes); //don't schedule a process where the graph manager is running + scheduleProcess(numProcessesToStart); + } + + boolean shouldReboot = response.getReboot(); + List newContainers = response.getAllocatedContainers(); + List finishedContainers = response.getCompletedContainersStatuses(); + List updatedNodes = response.getUpdatedNodes(); + int returnedResponseId = response.getResponseId(); // TODO - how should this be tracked? 
+ log.info(String.format("Response id %d reboot %b containing %d new containers, %d finished containers, and %d updated nodes", + returnedResponseId, shouldReboot, newContainers.size(), + finishedContainers.size(), updatedNodes.size())); + + for (ContainerStatus containerStatus : finishedContainers) { + ContainerId cid = containerStatus.getContainerId(); + log.info("Got container status for containerID= " + + cid + ", state=" + containerStatus.getState() + + ", exitStatus=" + containerStatus.getExitStatus() + + ", diagnostics=" + containerStatus.getDiagnostics()); + + // Need to notify graph manager of current state + VertexInfo vi = runningContainers.remove(cid); + if (vi != null) { + int containerState = 0; + if (containerStatus.getState() == ContainerState.COMPLETE) { + if (containerStatus.getExitStatus() == 0) { + containerState = YTS_Completed; + } else { + containerState = YTS_Failed; + } + SendVertexState(vi.vertexId, containerState, vi.nodeName); + } else { + log.error("Container finished without a COMPLETE status. 
containerID=" + cid); + } + } + } + + startContainers(newContainers); + } + + public void scheduleProcess(int vertexId, String name, String commandLine) + { + log.info(String.format("scheduleProcess called (external) for vertex %1$d name: '%2$s' commandLine: '%3$s'", + vertexId, name, commandLine)); + + } + + public void scheduleProcess(int numProcesses) + { + log.info("Scheduling " + numProcesses + " processes."); + + ResourceRequest resourceRequest = Records.newRecord(ResourceRequest.class); + + resourceRequest.setHostName("*"); + + Resource capability = Records.newRecord(Resource.class); + capability.setMemory(maxMemory); + resourceRequest.setCapability(capability); + Priority priority = Records.newRecord(Priority.class); + priority.setPriority(1); + resourceRequest.setPriority(priority); + + resourceRequest.setNumContainers(numProcesses); + synchronized(resourceRequests) { + resourceRequests.add(resourceRequest); + } + } + + private AllocateResponse sendAllocateRequest() + { + AllocateRequest request = Records.newRecord(AllocateRequest.class); + int idToSend = responseId.getAndIncrement(); + request.setResponseId(idToSend); + request.setProgress(getProgress()); + request.setApplicationAttemptId(appAttemptID); + int numReleases = 0; + + List localContainersToReturn = new ArrayList(); + synchronized(containersToReturn) { + if (containersToReturn.size() > 0) { + numReleases = containersToReturn.size(); + localContainersToReturn.addAll(containersToReturn); + request.setReleaseList(localContainersToReturn); + containersToReturn.clear(); + } + } + + synchronized (resourceRequests) { + if (resourceRequests.size() > 0) { + request.setAskList(resourceRequests); + } + log.info("Sending request to RM requesting " + resourceRequests.size() + + " nodes and releasing " + numReleases + " nodes."); + + AllocateResponse response = null; + try { + response = resourceManager.allocate(request); + resourceRequests.clear(); + log.info("Received reponse from RM - " + 
response.getNumClusterNodes() + + " nodes available in cluster"); + return response; + } catch (YarnRemoteException|IOException e) { + log.error("Error communicating with RM: " + e.getMessage() , e); + // TODO - retry communication + return null; + } + } + } + + public void shutdown(boolean immediateShutdown) + { + shuttingDown.set(true); + heartbeatHandle.cancel(immediateShutdown); // if we are shutting down, we can just interrupt the running thread, if necessary + log.info("Shutdown heartbeats to RM"); + + // send the shutdown message to the RM + FinishApplicationMasterRequest request = Records.newRecord(FinishApplicationMasterRequest.class); + request.setAppAttemptId(appAttemptID); + request.setFinishApplicationStatus(FinalApplicationStatus.SUCCEEDED); // NYI - determine success + try { + //response is currently an empty class + FinishApplicationMasterResponse response = resourceManager. finishApplicationMaster(request); + } catch (YarnRemoteException|IOException e) { + log.error("Error communicating with RM: " + e.getMessage() , e); + } + } + + private void startContainers(List newContainers) + { + // DCF TODO: Cache the connections to the cm + for (final Container container : newContainers) { + // Connect to ContainerManager on the allocated container + String cmIpPortStr = container.getNodeId().getHost() + ":" + + container.getNodeId().getPort(); + final InetSocketAddress cmAddress = NetUtils.createSocketAddr(cmIpPortStr); + log.debug("The allocated container contains a resource memory capactity of " + + container.getResource().getMemory()); + log.debug("The allocated container contains a container ID of " + container.getId()); + + // UGI example from DistributedShell AM + UserGroupInformation ugi = + UserGroupInformation.createRemoteUser(container.getId().toString()); + Token token = + ProtoUtils.convertFromProtoFormat(container.getContainerToken(), + cmAddress); + ugi.addToken(token); + ContainerManager cm = ugi.doAs(new PrivilegedAction() { + @Override + 
public ContainerManager run() { + return ((ContainerManager) rpc.getProxy(ContainerManager.class, + cmAddress, yarnConf)); + } + }); + launchContainer(container, cm); + } + + + } +} diff --git a/Java/DryadLinqYarnApp.java b/Java/DryadLinqYarnApp.java new file mode 100644 index 0000000..2df8371 --- /dev/null +++ b/Java/DryadLinqYarnApp.java @@ -0,0 +1,284 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +package com.microsoft.research; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.lang.Integer; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.parsers.ParserConfigurationException; +import javax.xml.xpath.XPath; +import javax.xml.xpath.XPathConstants; +import javax.xml.xpath.XPathExpression; +import javax.xml.xpath.XPathExpressionException; +import javax.xml.xpath.XPathFactory; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.LocalFileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.net.NetUtils; +import org.apache.hadoop.security.SecurityInfo; +import org.apache.hadoop.yarn.api.ApplicationConstants; +import org.apache.hadoop.yarn.api.ClientRMProtocol; +import org.apache.hadoop.yarn.api.protocolrecords.GetApplicationReportRequest; +import org.apache.hadoop.yarn.api.protocolrecords.GetApplicationReportResponse; +import org.apache.hadoop.yarn.api.protocolrecords.GetClusterMetricsRequest; +import org.apache.hadoop.yarn.api.protocolrecords.GetClusterMetricsResponse; +import org.apache.hadoop.yarn.api.protocolrecords.GetNewApplicationRequest; +import org.apache.hadoop.yarn.api.protocolrecords.GetNewApplicationResponse; +import org.apache.hadoop.yarn.api.protocolrecords.SubmitApplicationRequest; +import org.apache.hadoop.yarn.api.records.ApplicationId; +import org.apache.hadoop.yarn.api.records.ApplicationReport; +import org.apache.hadoop.yarn.api.records.ApplicationSubmissionContext; +import 
org.apache.hadoop.yarn.api.records.ContainerLaunchContext; +import org.apache.hadoop.yarn.api.records.LocalResource; +import org.apache.hadoop.yarn.api.records.LocalResourceType; +import org.apache.hadoop.yarn.api.records.LocalResourceVisibility; +import org.apache.hadoop.yarn.api.records.Resource; +import org.apache.hadoop.yarn.api.records.YarnApplicationState; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.exceptions.YarnRemoteException; +import org.apache.hadoop.yarn.ipc.YarnRPC; +import org.apache.hadoop.yarn.security.client.ClientRMSecurityInfo; +import org.apache.hadoop.yarn.util.ConverterUtils; +import org.apache.hadoop.yarn.util.Records; +import org.w3c.dom.Document; +import org.w3c.dom.NodeList; +import org.xml.sax.SAXException; + +public class DryadLinqYarnApp +{ + + + public static void main( String[] args ) throws YarnRemoteException, InterruptedException, + IOException, URISyntaxException, + ParserConfigurationException, SAXException, + XPathExpressionException + { + Log log = LogFactory.getLog("DryadLinqYarnClient"); + if (args.length != 1) + { + log.error("Incorrect number of arguments."); + System.exit(1); + } + // the queryplan xml file is in args[0] + log.info("Reading query plan from file " + args[0]); + DocumentBuilder builder = DocumentBuilderFactory.newInstance().newDocumentBuilder(); + Document queryPlan = builder.parse(new File(args[0])); + + XPath xpath = XPathFactory.newInstance().newXPath(); + + XPathExpression nameExpr = xpath.compile("/Query/ClusterName"); + //TODO - the cluster name might not be useful + String clusterName = (String) nameExpr.evaluate(queryPlan, XPathConstants.STRING); + + XPathExpression resExpr = xpath.compile("/Query/Resources/Resource"); + NodeList resourceList = (NodeList) resExpr.evaluate(queryPlan, XPathConstants.NODESET); + + String[] localResourcePaths = new String[resourceList.getLength() + 1]; + for (int i = 0; i < resourceList.getLength(); i++) + { + 
localResourcePaths[i] = resourceList.item(i).getTextContent(); + //System.out.println(localResourcePaths[i]); + } + XPathExpression appNameExpr = xpath.compile("/Query/QueryName"); + String queryName = (String) appNameExpr.evaluate(queryPlan, XPathConstants.STRING); + + XPathExpression minNodesExpr = xpath.compile("/Query/MinimumComputeNodes"); + int minComputeNodes = Integer.parseInt((String) minNodesExpr.evaluate(queryPlan, XPathConstants.STRING)); + + XPathExpression maxNodesExpr = xpath.compile("/Query/MaximumComputeNodes"); + int maxComputeNodes = Integer.parseInt((String) maxNodesExpr.evaluate(queryPlan, XPathConstants.STRING)); + + + File queryPlanFile = new File(args[0]); + String queryPlanLeafName = queryPlanFile.getName(); + + File[] srcDirs = new File[2]; + srcDirs[0] = queryPlanFile.getParentFile(); + srcDirs[1] = new File(System.getProperty("user.dir")); + + // add the query plan to the resources + localResourcePaths[localResourcePaths.length - 1] = queryPlanLeafName; + + ClientRMProtocol applicationsManager; + YarnConfiguration yarnConf = new YarnConfiguration(); + + String dest = yarnConf.get(YarnConfiguration.RM_ADDRESS, YarnConfiguration.DEFAULT_RM_ADDRESS); + log.info("Connecting to dest " + dest); + InetSocketAddress rmAddress = NetUtils.createSocketAddr(dest); + YarnRPC rpc = YarnRPC.create(yarnConf); + applicationsManager = ((ClientRMProtocol) rpc.getProxy( + ClientRMProtocol.class, rmAddress, (Configuration)yarnConf)); + + if (maxComputeNodes == -1) { + // find the max number of nodes in the cluster + GetClusterMetricsResponse metricsResponse = applicationsManager.getClusterMetrics( + Records.newRecord(GetClusterMetricsRequest.class)); + maxComputeNodes = metricsResponse.getClusterMetrics().getNumNodeManagers(); + log.info("Set maxComputeNodes to " + maxComputeNodes); + } + + GetNewApplicationRequest request = + Records.newRecord(GetNewApplicationRequest.class); + GetNewApplicationResponse response = + 
applicationsManager.getNewApplication(request); + + ApplicationId appId = response.getApplicationId(); + log.info("Got new ApplicationId=" + appId); + log.info("Min Resource Capability: " + response.getMinimumResourceCapability().getMemory()); + log.info("Max Resource Capability: " + response.getMaximumResourceCapability().getMemory()); + + Map localResources = new HashMap(); + + // copy the files to hdfs under a job directory + // and add them to local resources + // TODO: Use content based hashing to avoid copying files + FileSystem fs = FileSystem.get(yarnConf); + Path homeDir = fs.getHomeDirectory(); + + Path resourceHdfsDir = new Path(homeDir, "dlbin/" + appId); + StringBuilder resourceString = new StringBuilder(); + resourceString.append(resourceHdfsDir); + + for(int i = 0; i < localResourcePaths.length; i++) + { + boolean sourceFound = false; + File resourceFile = new File(localResourcePaths[i]); + resourceString.append(','); + resourceString.append(localResourcePaths[i]); + if (!resourceFile.exists()) + { + for(int j = 0; j < srcDirs.length && !sourceFound; j++) + { + resourceFile = new File(srcDirs[j], localResourcePaths[i]); + if (resourceFile.exists()) + { + sourceFound = true; + } + } + } + else + { + sourceFound = true; + } + if (!sourceFound) + { + throw new FileNotFoundException("Unable to find local resource: " + localResourcePaths[i]); + } + Path srcPath = new Path(resourceFile.toURI()); + String leafName = new File(localResourcePaths[i]).getName(); + Path remotePath = new Path(resourceHdfsDir, leafName); + log.info("Copying file '" + leafName + "' to '" + remotePath + "'"); + fs.copyFromLocalFile(srcPath, remotePath); + + + FileStatus remoteStatus = fs.getFileStatus(remotePath); + LocalResource amResource = Records.newRecord(LocalResource.class); + amResource.setType(LocalResourceType.FILE); + amResource.setVisibility(LocalResourceVisibility.APPLICATION); + amResource.setResource(ConverterUtils.getYarnUrlFromPath(remotePath)); + 
amResource.setTimestamp(remoteStatus.getModificationTime()); + log.info("Set file modification time to " + + remoteStatus.getModificationTime()); + amResource.setSize(remoteStatus.getLen()); + log.info("Set file length to " + remoteStatus.getLen()); + localResources.put(remotePath.getName(), amResource); + } + + int amMemory = response.getMaximumResourceCapability().getMemory(); + // request the min amount of memory, which should schedule the am on its own node + //int amMemory = response.getMaximumResourceCapability().getMemory(); + log.info("Set amMemory=" + amMemory); + log.info("Creating the ApplicationSubmissionContext"); + // Create a new ApplicationSubmissionContext + ApplicationSubmissionContext appContext = Records.newRecord(ApplicationSubmissionContext.class); + log.info("Setting the ApplicationId"); + // set the ApplicationId + appContext.setApplicationId(response.getApplicationId()); + // set the application name + appContext.setApplicationName(queryName); + // set the queue to the default queue + appContext.setQueue("default"); + log.info("Getting a ContainerLaunchContext"); + // Create a new container launch context for the AM's container + ContainerLaunchContext amContainer = Records.newRecord(ContainerLaunchContext.class); + log.info("Got a ContainerLaunchContext"); + + // Set the local resources into the launch context + amContainer.setLocalResources(localResources); + + // get a copy of the local environment variables + Map envs = System.getenv(); + + // The environment for the am + Map env = new HashMap(); + env.put("CCP_JOBID", appId.getId() + ""); + env.put("CCP_DRYADPROCID", "1"); + env.put("XC_RESOURCEFILES", resourceString.toString()); + env.put("MINIMUM_COMPUTE_NODES", minComputeNodes + ""); + env.put("MAXIMUM_COMPUTE_NODES", maxComputeNodes + ""); + env.put("DRYAD_HOME", envs.get("DRYAD_HOME")); + env.put("JNI_CLASSPATH", envs.get("JNI_CLASSPATH")); + //log.info("DRYAD_HOME env variable is '" + envs.get("DRYAD_HOME") + "'"); + 
amContainer.setEnvironment(env); + + File jmExecutable = new File(envs.get("DRYAD_HOME"), "LinqToDryadJM_managed.exe"); + String jmCmdLine = jmExecutable.getAbsolutePath(); + + // Construct the command to be executed on the launched container + String command = jmCmdLine + " " + + queryPlanLeafName + " " + + //" --break " + + " 1>stdout.txt" + + " 2>stderr.txt"; + + List commands = new ArrayList(); + commands.add(command); + amContainer.setCommands(commands); + + Resource capability = Records.newRecord(Resource.class); + capability.setMemory(amMemory); + appContext.setResource(capability); + appContext.setAMContainerSpec(amContainer); + + // Create the request to send to the ApplicationsManager + SubmitApplicationRequest appRequest = + Records.newRecord(SubmitApplicationRequest.class); + appRequest.setApplicationSubmissionContext(appContext); + applicationsManager.submitApplication(appRequest); + + System.out.println(appId); + } + } diff --git a/Java/HdfsBridge.java b/Java/HdfsBridge.java new file mode 100644 index 0000000..72cd2c5 --- /dev/null +++ b/Java/HdfsBridge.java @@ -0,0 +1,604 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + + +package GSLHDFS; + +import java.io.*; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.BlockLocation; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.FSDataOutputStream; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.hdfs.DistributedFileSystem; + +//-------------------------------------------------------------------------------- + +public class HdfsBridge +{ + //-------------------------------------------------------------------------------- + + public static void main(String[] args) throws IOException + { + Instance i = new Instance(); + int ret = i.Connect("svc-d1-17",9000); + if (ret != SUCCESS) + { + System.out.println("Failed to connect"); + return; + } + + String fileName = "/data/inputPart0.txt"; + + FileStatus fs = i.OpenFileStatus(fileName, false); + System.out.println(fs.getPath().toUri().getPath()); + + int rc = i.IsFileExist(fileName); + System.out.println(rc); + + if (rc == 1) + { + Instance.Reader r = i.OpenReader(fileName); + + if (r != null) + { + int nRead = 0; + long offset = 0; + do + { + Instance.Reader.Block b = r.ReadBlock(offset, 64 * 1024); + nRead = b.ret; + if (nRead != -1) + { + System.out.println("Read " + nRead + " bytes at " + offset); + offset += nRead; + } + } while (nRead >= 0); + + ret = r.Close(); + if (ret != SUCCESS) + { + System.out.println("Failed to close"); + } + } + } + + ret = i.Disconnect(); + if (ret != SUCCESS) + { + System.out.println("Failed to disconnect"); + } + //String fileToRead = "/data/tpch/customer_1G_128MB.txt"; + //String content = HdfsBridge.ReadBlock(fileToRead, 0); + //System.out.println(rc); + } + + //-------------------------------------------------------------------------------- + + public static int SUCCESS = 1; + public static int FAILURE = 0; + + //-------------------------------------------------------------------------------- + + + 
//-------------------------------------------------------------------------------- + + public static class Instance + { + private DistributedFileSystem dfs = null; + public String exceptionMessage = null; + + //-------------------------------------------------------------------------------- + + public static class Reader + { + public static class Block + { + public int ret; + public byte[] buffer; + } + + private FSDataInputStream dis = null; + public String exceptionMessage = null; + + public void Open(DistributedFileSystem dfs, String fileName) throws IOException + { + Path path = new Path(fileName); + dis = dfs.open(path); + } + + public Block ReadBlock(long blockOffset, int bytesRequested) + { + Block block = new Block(); + block.buffer = null; + + if (dis == null) + { + exceptionMessage = "ReadBlock called on closed reader"; + block.ret = -2; + return block; + } + + block.buffer = new byte[bytesRequested]; + + int numBytesRead = -2; + try + { + numBytesRead = dis.read(blockOffset, block.buffer, 0, bytesRequested); + } + catch (IOException e1) + { + exceptionMessage = e1.getMessage(); + block.buffer = null; + block.ret = -2; + return block; + } + + block.ret = numBytesRead; + return block; + } + + public int Close() + { + int ret = SUCCESS; + + if (dis != null) + { + try { + dis.close(); + } + catch (IOException e1) + { + exceptionMessage = e1.getMessage(); + ret = FAILURE; + } + + dis = null; + } + + return ret; + } + } + + public static class Writer + { + private FSDataOutputStream dos = null; + public String exceptionMessage = null; + + public void Open(DistributedFileSystem dfs, String fileName) throws IOException + { + Path path = new Path(fileName); + dos = dfs.create(path); + } + + public int WriteBlock(byte[] buffer, boolean flushAfter) + { + if (dos == null) + { + exceptionMessage = "WriteBlock called on closed writer"; + return FAILURE; + } + + try + { + dos.write(buffer); + if (flushAfter) + { + dos.flush(); + } + } + catch (IOException e1) + { + 
exceptionMessage = e1.getMessage(); + return FAILURE; + } + + return SUCCESS; + } + + public int Close() + { + int ret = SUCCESS; + + if (dos != null) + { + try + { + dos.close(); + } + catch (IOException e1) + { + exceptionMessage = e1.getMessage(); + ret = FAILURE; + } + + dos = null; + } + + return ret; + } + } + + public static class BlockLocations + { + private BlockLocation[] bls = null; + private int[] fileIndex = null; + private String[] fileName = null; + public String exceptionMessage = null; + public long fileSize = -1; + + BlockLocations( + BlockLocation[] b, + int[] fIndex, + String[] fName, + long fSize) + { + bls = b; + fileIndex = fIndex; + fileName = fName; + fileSize = fSize; + } + + public int GetNumberOfFileNames() + { + return fileName.length; + } + + public String[] GetFileNames() + { + return fileName; + } + + public int GetNumberOfBlocks() + { + return bls.length; + } + + public long GetBlockOffset(int blockId) + { + return bls[blockId].getOffset(); + } + + public long GetBlockLength(int blockId) + { + return bls[blockId].getLength(); + } + + public String[] GetBlockHosts(int blockId) + { + BlockLocation bl = bls[blockId]; + String[] hosts = null; + + exceptionMessage = null; + try + { + hosts = bl.getHosts(); + } + catch (IOException e1) + { + exceptionMessage = e1.getMessage(); + return null; + } + return hosts; + } + + public String[] GetBlockNames(int blockId) + { + BlockLocation bl = bls[blockId]; + String[] names = null; + + exceptionMessage = null; + try + { + names = bl.getNames(); + } + catch (IOException e1) + { + exceptionMessage = e1.getMessage(); + return null; + } + return names; + } + + public int GetBlockFileId(int blockId) + { + return fileIndex[blockId]; + } + } + + public int Connect(String inputNameNode, long inputPortNumber) + { + Configuration config = new Configuration(); + config.set("fs.defaultFS", "hdfs://" + + inputNameNode + ":"+ inputPortNumber +""); + + exceptionMessage = null; + + try + { + dfs = 
(DistributedFileSystem)FileSystem.get(config); + } + catch (IOException e1) + { + exceptionMessage = e1.getMessage(); + return FAILURE; + } + + return SUCCESS; + } + + //-------------------------------------------------------------------------------- + + public int Disconnect() + { + int ret = SUCCESS; + + if (dfs != null) + { + exceptionMessage = null; + + try + { + dfs.close(); + } + catch (IOException e1) + { + exceptionMessage = e1.getMessage(); + ret = FAILURE; + } + + dfs = null; + } + + return ret; + } + + //-------------------------------------------------------------------------------- + + public int IsFileExist(String fileName) + { + if (dfs == null) + { + exceptionMessage = "IsFileExist called on disconnected instance"; + return -1; + } + + exceptionMessage = null; + + try + { + Path path = new Path(fileName); + return (dfs.exists(path)) ? 1 : 0; + } + catch (IOException e1) + { + exceptionMessage = e1.getMessage(); + return -1; + } + } + + public int DeleteFile(String fileName, boolean recursive) + { + if (dfs == null) + { + exceptionMessage = "DeleteFile called on disconnected instance"; + return -1; + } + + exceptionMessage = null; + + try + { + Path path = new Path(fileName); + return (dfs.delete(path, recursive)) ? 1 : 0; + } + catch (IOException e1) + { + exceptionMessage = e1.getMessage(); + return -1; + } + } + + public int RenameFile(String dstFileName, String srcFileName) + { + if (dfs == null) + { + exceptionMessage = "RenameFile called on disconnected instance"; + return -1; + } + + exceptionMessage = null; + + try + { + Path dstPath = new Path(dstFileName); + Path srcPath = new Path(srcFileName); + return (dfs.rename(srcPath, dstPath)) ? 
1 : 0; + } + catch (IOException e1) + { + exceptionMessage = e1.getMessage(); + return -1; + } + } + + public Reader OpenReader(String fileName) + { + if (dfs == null) + { + System.out.println("OpenReader called on disconnected instance\n"); + return null; + } + + Reader r = new Reader(); + + exceptionMessage = null; + + try + { + r.Open(dfs, fileName); + } + catch (IOException e1) + { + exceptionMessage = e1.getMessage(); + return null; + } + + return r; + } + + public Writer OpenWriter(String fileName) + { + if (dfs == null) + { + System.out.println("OpenWriter called on disconnected instance\n"); + return null; + } + + Writer w = new Writer(); + + exceptionMessage = null; + + try + { + w.Open(dfs, fileName); + } + catch (IOException e1) + { + exceptionMessage = e1.getMessage(); + return null; + } + + return w; + } + + public FileStatus OpenFileStatus(String fileOrDirectoryName, boolean getLocations) + { + if (dfs == null) + { + exceptionMessage = "OpenFileStatus called on disconnected instance"; + return null; + } + + exceptionMessage = null; + + try + { + Path path = new Path(fileOrDirectoryName); + + return dfs.getFileStatus(path); + } + catch (IOException e1) + { + exceptionMessage = e1.getMessage(); + return null; + } + } + + public BlockLocations OpenBlockLocations(FileStatus fileStatus, boolean getBlocks) + { + exceptionMessage = null; + try + { + FileStatus[] expanded; + if (fileStatus.isDirectory()) + { + expanded = dfs.listStatus(fileStatus.getPath()); + for (int i=0; i + /// The Resource attribute is used to specify the computation cost of a function. + /// IsStateful asserts that the function is stateful; IsExpensive asserts that + /// the function is expensive to compute. The information is useful in generating + /// better execution plan. For example, expensive associative aggregation + /// functions can use multiple aggregation layers. 
+ /// + [AttributeUsage(AttributeTargets.Method, AllowMultiple = false)] + internal sealed class ResourceAttribute : Attribute + { + private bool m_isStateful; + private bool m_isExpensive; + + public ResourceAttribute() + { + this.m_isStateful = true; + this.m_isExpensive = false; + } + + public bool IsStateful + { + get { return this.m_isStateful; } + set { this.m_isStateful = value; } + } + + public bool IsExpensive + { + get { return this.m_isExpensive; } + set { this.m_isExpensive = value; } + } + } + + [AttributeUsage(AttributeTargets.Method, AllowMultiple = false)] + public sealed class DecomposableAttribute : Attribute + { + Type m_decompositionType; + + public DecomposableAttribute(Type decompositionType) + { + m_decompositionType = decompositionType; + } + + public Type DecompositionType + { + get { + return m_decompositionType; + } + } + } + + /// + /// Indicates that a method can be used as an associative aggregation method. + /// The aggregation can either be via recursive calls to the tagged method, or + /// via top-level calls to the tagged method, followed by recursive calls to + /// a RecursiveAccumulate method. + /// + /// + /// If a recursive accumulator method is necessary, create type that implements + /// IAssociative and provide that to the ctor of this type. + /// + [AttributeUsage(AttributeTargets.Method, AllowMultiple = false)] + public sealed class AssociativeAttribute : Attribute + { + private Type m_associativeType; + + /// + /// Creates an instance of AssociativeAttribute + /// + public AssociativeAttribute() + { + } + + /// + /// Creates an instance of AssociativeAttribute, with an associated type that provides + /// a recursive-accumulator method. + /// + /// + /// During aggregation, the recursiveAccumulator will be used to aggregate items arising + /// from the main aggregation. + /// + /// A type that implements IAssociative{T,T} where T + /// is the output type of methods that are decorated with this attribute. 
+        public AssociativeAttribute(Type associativeType)
+        {
+            this.m_associativeType = associativeType;
+        }
+
+        /// <summary>
+        /// Type that implements IAssociative{T,T} where T is the output type of methods
+        /// that are decorated with this attribute.
+        /// </summary>
+        public Type AssociativeType
+        {
+            get { return this.m_associativeType; }
+        }
+    }
+
+    /// <summary>
+    /// Records, on a class or struct, the type that should be used to serialize it.
+    /// </summary>
+    [AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct, AllowMultiple = false, Inherited=false)]
+    public sealed class CustomHpcSerializerAttribute : Attribute
+    {
+        public CustomHpcSerializerAttribute(Type serializerType)
+        {
+            SerializerType = serializerType;
+
+            // We need to make sure serializerType implements IHpcSerializer<T>
+            // However we will defer that check until DryadCodeGen.FindCustomSerializerType(), because
+            //   1) we don't have access to <T> here but it's available at code gen time, and
+            //   2) because an exception coming from the attribute ctor leads to an obscure failure.
+
+        }
+
+        /// <summary>
+        /// The serializer type passed to the constructor (not validated here; see the ctor comment).
+        /// </summary>
+        public Type SerializerType { private set; get; }
+    }
+
+    /// <summary>
+    /// Looks up the attributes that apply to a lambda or expression. Two sources are combined:
+    /// attributes registered at run time through <see cref="Add"/>, and attributes declared on
+    /// the method or constructor that the expression invokes.
+    /// </summary>
+    internal static class AttributeSystem
+    {
+        // Attributes attached to lambdas at run time through Add().
+        private static Dictionary<LambdaExpression, Attribute[]> attribMap = new Dictionary<LambdaExpression, Attribute[]>();
+
+        /// <summary>
+        /// Attaches an attribute to a lambda, appending it to any attributes already registered.
+        /// </summary>
+        /// <param name="func">The lambda to annotate.</param>
+        /// <param name="attrib">The attribute to attach.</param>
+        internal static void Add(LambdaExpression func, Attribute attrib)
+        {
+            Attribute[] attribs;
+            if (attribMap.TryGetValue(func, out attribs))
+            {
+                // Grow by one; the stored arrays are never mutated in place.
+                Attribute[] oldAttribs = attribs;
+                attribs = new Attribute[oldAttribs.Length + 1];
+                Array.Copy(oldAttribs, attribs, oldAttribs.Length);
+                attribs[oldAttribs.Length] = attrib;
+            }
+            else
+            {
+                attribs = new Attribute[] { attrib };
+            }
+            // The indexer replaces an existing entry, so no Remove is needed first.
+            attribMap[func] = attribs;
+        }
+
+        /// <summary>
+        /// Returns the run-time registered attributes of <paramref name="func"/> whose type is
+        /// exactly <paramref name="attribType"/>, or null when nothing is registered for it.
+        /// The result array has element type <paramref name="attribType"/>.
+        /// </summary>
+        private static Attribute[] Get(LambdaExpression func, Type attribType)
+        {
+            Attribute[] registered;
+            if (!attribMap.TryGetValue(func, out registered) || registered == null)
+            {
+                return null;
+            }
+
+            ArrayList matches = new ArrayList();
+            foreach (Attribute x in registered)
+            {
+                if (x.GetType() == attribType)
+                {
+                    matches.Add(x);
+                }
+            }
+            return (Attribute[])matches.ToArray(attribType);
+        }
+
+        /// <summary>
+        /// Collects every attribute of type <paramref name="attribType"/> that applies to
+        /// <paramref name="func"/>: those registered through <see cref="Add"/> followed by those
+        /// declared on the method, constructor or operator method that forms the lambda body.
+        /// A body that invokes another lambda is followed recursively.
+        /// </summary>
+        /// <returns>
+        /// An array whose element type is <paramref name="attribType"/>, or null when neither
+        /// source yields a result.
+        /// </returns>
+        internal static Attribute[] GetAttribs(LambdaExpression func, Type attribType)
+        {
+            Attribute[] attribs1 = AttributeSystem.Get(func, attribType);
+            Attribute[] attribs2 = null;
+            if (func.Body is MethodCallExpression)
+            {
+                MethodCallExpression expr = (MethodCallExpression)func.Body;
+                attribs2 = Attribute.GetCustomAttributes(expr.Method, attribType);
+            }
+            else if (func.Body is NewExpression && ((NewExpression)func.Body).Constructor != null)
+            {
+                NewExpression expr = (NewExpression)func.Body;
+                attribs2 = Attribute.GetCustomAttributes(expr.Constructor, attribType);
+            }
+            else if (func.Body is BinaryExpression)
+            {
+                // Only operators implemented by a method carry a MethodInfo.
+                BinaryExpression expr = (BinaryExpression)func.Body;
+                if (expr.Method != null)
+                {
+                    attribs2 = Attribute.GetCustomAttributes(expr.Method, attribType);
+                }
+            }
+            else if (func.Body is InvocationExpression)
+            {
+                InvocationExpression expr = (InvocationExpression)func.Body;
+                if (expr.Expression is LambdaExpression)
+                {
+                    attribs2 = GetAttribs((LambdaExpression)expr.Expression, attribType);
+                }
+            }
+
+            if (attribs1 == null)
+            {
+                return attribs2;
+            }
+            if (attribs2 == null)
+            {
+                return attribs1;
+            }
+
+            // Both sources produced something: registered attributes first, declared ones after.
+            ArrayList alist = new ArrayList(attribs1.Length + attribs2.Length);
+            alist.AddRange(attribs1);
+            alist.AddRange(attribs2);
+            return (Attribute[])alist.ToArray(attribType);
+        }
+
+        /// <summary>
+        /// Returns the first attribute of type <paramref name="attribType"/> that applies to
+        /// <paramref name="expr"/>, or null when there is none. Method calls, constructor calls
+        /// and lambdas are handled; any other expression kind yields null.
+        /// </summary>
+        internal static Attribute GetAttrib(Expression expr, Type attribType)
+        {
+            Attribute[] attribs = null;
+            if (expr is MethodCallExpression)
+            {
+                attribs = Attribute.GetCustomAttributes(((MethodCallExpression)expr).Method, attribType);
+            }
+            else if (expr is NewExpression && ((NewExpression)expr).Constructor != null)
+            {
+                attribs = Attribute.GetCustomAttributes(((NewExpression)expr).Constructor, attribType);
+            }
+            else if (expr is LambdaExpression)
+            {
+                attribs = GetAttribs((LambdaExpression)expr, attribType);
+            }
+
+            if (attribs == null || attribs.Length == 0) return null;
+            return attribs[0];
+        }
+
+        internal static DecomposableAttribute GetDecomposableAttrib(Expression expr)
+        {
+            return (DecomposableAttribute)GetAttrib(expr, typeof(DecomposableAttribute));
+        }
+
+        internal static AssociativeAttribute GetAssociativeAttrib(Expression expr)
+        {
+            return (AssociativeAttribute)GetAttrib(expr, typeof(AssociativeAttribute));
+        }
+
+        internal static ResourceAttribute GetResourceAttrib(LambdaExpression func)
+        {
+            return (ResourceAttribute)GetAttrib(func, typeof(ResourceAttribute));
+        }
+
+        /// <summary>
+        /// Returns all field-mapping attributes that apply to <paramref name="func"/>, or null
+        /// when there are none.
+        /// </summary>
+        internal static FieldMappingAttribute[] GetFieldMappingAttribs(LambdaExpression func)
+        {
+            Attribute[] a = GetAttribs(func, typeof(FieldMappingAttribute));
+            if (a == null || a.Length == 0) return null;
+            // Safe: GetAttribs returns an array whose element type is the requested type.
+            return (FieldMappingAttribute[])a;
+        }
+
+        /// <summary>
+        /// True when auto type inference is enabled in StaticConfig and <paramref name="type"/>
+        /// (or a base type) carries AutoTypeInferenceAttribute. <paramref name="context"/> is not used.
+        /// </summary>
+        internal static bool DoAutoTypeInference(HpcLinqContext context, Type type)
+        {
+            if (!StaticConfig.AllowAutoTypeInference) return false;
+            object[] a = type.GetCustomAttributes(typeof(AutoTypeInferenceAttribute), true);
+            return (a.Length != 0);
+        }
+
+        internal static DistinctAttribute GetDistinctAttrib(LambdaExpression func)
+        {
+            return (DistinctAttribute)GetAttrib(func, typeof(DistinctAttribute));
+        }
+
+        /// <summary>
+        /// Whether a field may hold null. Value-typed (or missing) fields never can; otherwise an
+        /// explicit NullableAttribute decides, falling back to StaticConfig.AllowNullFields.
+        /// </summary>
+        internal static bool FieldCanBeNull(FieldInfo finfo)
+        {
+            if (finfo == null || finfo.FieldType.IsValueType) return false;
+
+            object[] attribs = finfo.GetCustomAttributes(typeof(NullableAttribute), true);
+            if (attribs.Length == 0)
+            {
+                return StaticConfig.AllowNullFields;
+            }
+            return ((NullableAttribute)attribs[0]).CanBeNull;
+        }
+
+        /// <summary>
+        /// Whether a record of <paramref name="type"/> may be null. Value types (or a null type)
+        /// never can; otherwise an explicit NullableAttribute decides, falling back to
+        /// StaticConfig.AllowNullRecords. <paramref name="context"/> is not used.
+        /// </summary>
+        internal static bool RecordCanBeNull(HpcLinqContext context, Type type)
+        {
+            if (type == null || type.IsValueType) return false;
+
+            object[] attribs = type.GetCustomAttributes(typeof(NullableAttribute), true);
+            if (attribs.Length == 0)
+            {
+                return StaticConfig.AllowNullRecords;
+            }
+            return ((NullableAttribute)attribs[0]).CanBeNull;
+        }
+    }
+}
diff --git a/LinqToDryad/BitVector.cs b/LinqToDryad/BitVector.cs
new file mode 100644
index 0000000..39059d6
--- /dev/null
+++ b/LinqToDryad/BitVector.cs
@@ -0,0 +1,114 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//
+// © Microsoft Corporation. All rights reserved.
+//
+using System;
+using Microsoft.Research.DryadLinq;
+
+namespace Microsoft.Research.DryadLinq.Internal
+{
+    /// <summary>
+    /// A set of bits packed eight to a byte (bit i lives in byte i / 8, at position i % 8).
+    /// The serialized form omits trailing all-zero bytes, so a vector that has been read back
+    /// may be backed by a shorter array than the one that was written; bits beyond the end of
+    /// the array read as false.
+    /// </summary>
+    /// <remarks>
+    /// This is a struct wrapping an array: copies share the same bytes until
+    /// <see cref="Set"/> has to grow the array, at which point only the instance it was
+    /// called on sees the larger array.
+    /// </remarks>
+    public struct BitVector
+    {
+        private byte[] m_array;
+
+        /// <summary>Creates a vector with room for <paramref name="length"/> bits, all false.</summary>
+        public BitVector(int length)
+        {
+            this.m_array = new byte[(length + 7) / 8];
+        }
+
+        /// <summary>Creates a vector whose bit i equals <c>values[i]</c>.</summary>
+        public BitVector(bool[] values)
+        {
+            this.m_array = new byte[(values.Length + 7) / 8];
+            for (int i = 0; i < values.Length; i++)
+            {
+                if (values[i])
+                {
+                    m_array[i / 8] |= (byte)(1 << (i % 8));
+                }
+            }
+        }
+
+        // Wraps an existing byte array without copying it (used by Read).
+        private BitVector(byte[] values)
+        {
+            this.m_array = values;
+        }
+
+        public bool this[int index]
+        {
+            get { return this.Get(index); }
+        }
+
+        /// <summary>
+        /// Returns the bit at <paramref name="index"/>. An index past the end of the backing
+        /// array (including any index on a default-constructed vector) yields false.
+        /// </summary>
+        public bool Get(int index)
+        {
+            int idx = index / 8;
+            return ((this.m_array != null) &&
+                    (idx < this.m_array.Length) &&
+                    (this.m_array[idx] & (1 << (index % 8))) != 0);
+        }
+
+        /// <summary>
+        /// Sets the bit at <paramref name="index"/> to true, growing the backing array when the
+        /// index lies past its end. Growth is needed because a vector produced by
+        /// <see cref="Read"/> only holds the bytes up to the last non-zero one, and such a
+        /// vector previously threw IndexOutOfRangeException here.
+        /// </summary>
+        public void Set(int index)
+        {
+            int idx = index / 8;
+            if (this.m_array == null || idx >= this.m_array.Length)
+            {
+                // Array.Resize allocates a fresh array when the current one is null.
+                Array.Resize(ref this.m_array, idx + 1);
+            }
+            this.m_array[idx] |= (byte)(1 << (index % 8));
+        }
+
+        /// <summary>Sets every byte of the backing array to all-ones or all-zeros.</summary>
+        public void SetAll(bool value)
+        {
+            if (this.m_array == null) return;
+
+            byte fillValue = 0;
+            if (value) fillValue = 0xff;
+            for (int i = 0; i < this.m_array.Length; i++)
+            {
+                this.m_array[i] = fillValue;
+            }
+        }
+
+        // Writes a compact byte count followed by the bytes up to and including the last
+        // non-zero byte; trailing zero bytes are not written.
+        private void WriteInner(HpcBinaryWriter writer)
+        {
+            int len = (this.m_array == null) ? 0 : this.m_array.Length;
+            while (len > 0 && this.m_array[len - 1] == 0)
+            {
+                len--;
+            }
+            writer.WriteCompact(len);
+            for (int i = 0; i < len; i++)
+            {
+                writer.Write(this.m_array[i]);
+            }
+        }
+
+        /// <summary>
+        /// Reads a vector in the format produced by <see cref="Write"/>: a compact Int32 byte
+        /// count followed by that many bytes.
+        /// </summary>
+        public static BitVector Read(HpcBinaryReader reader)
+        {
+            Int32 len = reader.ReadCompactInt32();
+            byte[] values = new byte[len];
+            for (int i = 0; i < len; i++)
+            {
+                values[i] = reader.ReadUByte();
+            }
+            return new BitVector(values);
+        }
+
+        /// <summary>Writes <paramref name="bv"/> to <paramref name="writer"/>.</summary>
+        public static void Write(HpcBinaryWriter writer, BitVector bv)
+        {
+            bv.WriteInner(writer);
+        }
+    }
+
+}
diff --git a/LinqToDryad/CodeGenHelper.cs b/LinqToDryad/CodeGenHelper.cs
new file mode 100644
index 0000000..596ba56
--- /dev/null
+++ b/LinqToDryad/CodeGenHelper.cs
@@ -0,0 +1,152 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//
+// © Microsoft Corporation. All rights reserved.
+//
+using System;
+using System.IO;
+using System.Reflection;
+using System.Reflection.Emit;
+using System.Linq;
+using System.Linq.Expressions;
+using System.Diagnostics;
+using Microsoft.Research.DryadLinq;
+
+namespace Microsoft.Research.DryadLinq.Internal
+{
+    //these are involved in generated code for member lookup
+    //eg for direct access to fields of IndexedValue. and probably other situations too.
+    public delegate S GetObjFieldDelegate<T, S>(T obj);
+    public delegate void SetObjFieldDelegate<T, S>(T obj, S value);
+
+    public delegate S GetStructFieldDelegate<T, S>(out T obj);
+    public delegate void SetStructFieldDelegate<T, S>(out T obj, S value);
+
+    //this class is internal-public for GetObjFieldDelegate etc.
+    /// <summary>
+    /// Builds delegates that read or write a named instance field (public or non-public) of
+    /// <c>T</c> by emitting a small dynamic method. The "Obj" variants require <c>T</c> to be a
+    /// reference type; the "Struct" variants require a value type and take it by reference.
+    /// </summary>
+    public static class CodeGenHelper
+    {
+        // Instance fields of any visibility.
+        private const BindingFlags InstanceFieldFlags =
+            BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;
+
+        // Checks that typeT is a struct or a class as the caller requires, then returns the
+        // instance field called fname; throws DryadLinqException when either check fails.
+        private static FieldInfo FindInstanceField(Type typeT, string fname, bool structExpected)
+        {
+            if (typeT.IsValueType != structExpected)
+            {
+                throw new DryadLinqException(HpcLinqErrorCode.Internal,
+                                             structExpected ? SR.CannotBeUsedForReferenceType
+                                                            : SR.Internal_CannotBeUsedForValueType);
+            }
+            FieldInfo finfo = typeT.GetField(fname, InstanceFieldFlags);
+            if (finfo == null)
+            {
+                throw new DryadLinqException(HpcLinqErrorCode.Internal,
+                                             String.Format(SR.TypeDoesNotContainRequestedField, typeT.Name, fname));
+            }
+            return finfo;
+        }
+
+        /// <summary>
+        /// Creates a getter for field <paramref name="fname"/> of the reference type <c>T</c>.
+        /// </summary>
+        /// <exception cref="DryadLinqException">
+        /// <c>T</c> is a value type, or has no instance field with that name.
+        /// </exception>
+        public static GetObjFieldDelegate<T, S> GetObjField<T, S>(string fname)
+        {
+            Type typeT = typeof(T);
+            FieldInfo finfo = FindInstanceField(typeT, fname, false);
+            DynamicMethod dm = new DynamicMethod("GetObjField",
+                                                 typeof(S),
+                                                 new Type[] { typeT },
+                                                 typeT);
+            ILGenerator ilgen = dm.GetILGenerator();
+            ilgen.Emit(OpCodes.Ldarg_0);          // obj
+            ilgen.Emit(OpCodes.Ldfld, finfo);     // obj.field
+            ilgen.Emit(OpCodes.Ret);
+
+            return (GetObjFieldDelegate<T, S>)dm.CreateDelegate(typeof(GetObjFieldDelegate<T, S>));
+        }
+
+        /// <summary>
+        /// Creates a setter for field <paramref name="fname"/> of the reference type <c>T</c>.
+        /// </summary>
+        /// <exception cref="DryadLinqException">
+        /// <c>T</c> is a value type, or has no instance field with that name.
+        /// </exception>
+        public static SetObjFieldDelegate<T, S> SetObjField<T, S>(string fname)
+        {
+            Type typeT = typeof(T);
+            FieldInfo finfo = FindInstanceField(typeT, fname, false);
+            DynamicMethod dm = new DynamicMethod("SetObjField",
+                                                 typeof(void),
+                                                 new Type[] { typeT, typeof(S) },
+                                                 typeT);
+            ILGenerator ilgen = dm.GetILGenerator();
+            ilgen.Emit(OpCodes.Ldarg_0);          // obj
+            ilgen.Emit(OpCodes.Ldarg_1);          // value
+            ilgen.Emit(OpCodes.Stfld, finfo);     // obj.field = value
+            ilgen.Emit(OpCodes.Ret);
+
+            return (SetObjFieldDelegate<T, S>)dm.CreateDelegate(typeof(SetObjFieldDelegate<T, S>));
+        }
+
+        /// <summary>
+        /// Creates a getter for field <paramref name="fname"/> of the value type <c>T</c>; the
+        /// struct is passed by reference.
+        /// </summary>
+        /// <exception cref="DryadLinqException">
+        /// <c>T</c> is a reference type, or has no instance field with that name.
+        /// </exception>
+        public static GetStructFieldDelegate<T, S> GetStructField<T, S>(string fname)
+        {
+            Type typeT = typeof(T);
+            FieldInfo finfo = FindInstanceField(typeT, fname, true);
+            DynamicMethod dm = new DynamicMethod("GetStructField",
+                                                 typeof(S),
+                                                 new Type[] { typeT.MakeByRefType() },
+                                                 typeT);
+            ILGenerator ilgen = dm.GetILGenerator();
+            ilgen.Emit(OpCodes.Ldarg_0);          // address of the struct
+            ilgen.Emit(OpCodes.Ldfld, finfo);
+            ilgen.Emit(OpCodes.Ret);
+
+            return (GetStructFieldDelegate<T, S>)dm.CreateDelegate(typeof(GetStructFieldDelegate<T, S>));
+        }
+
+        /// <summary>
+        /// Creates a setter for field <paramref name="fname"/> of the value type <c>T</c>; the
+        /// struct is passed by reference so the caller's copy is updated.
+        /// </summary>
+        /// <exception cref="DryadLinqException">
+        /// <c>T</c> is a reference type, or has no instance field with that name.
+        /// </exception>
+        public static SetStructFieldDelegate<T, S> SetStructField<T, S>(string fname)
+        {
+            Type typeT = typeof(T);
+            FieldInfo finfo = FindInstanceField(typeT, fname, true);
+            DynamicMethod dm = new DynamicMethod("SetStructField",
+                                                 typeof(void),
+                                                 new Type[] { typeT.MakeByRefType(), typeof(S) },
+                                                 typeT);
+            ILGenerator ilgen = dm.GetILGenerator();
+            ilgen.Emit(OpCodes.Ldarg_0);          // address of the struct
+            ilgen.Emit(OpCodes.Ldarg_1);          // value
+            ilgen.Emit(OpCodes.Stfld, finfo);
+            ilgen.Emit(OpCodes.Ret);
+
+            return (SetStructFieldDelegate<T, S>)dm.CreateDelegate(typeof(SetStructFieldDelegate<T, S>));
+        }
+    }
+}
diff --git a/LinqToDryad/Constants.cs b/LinqToDryad/Constants.cs
new file mode 100644
index 0000000..b788902
--- /dev/null
+++ b/LinqToDryad/Constants.cs
@@ -0,0 +1,163 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//------------------------------------------------------------------------------
+//
+// Constants used by managed code in Dryad
+//
+//------------------------------------------------------------------------------
+
+namespace Microsoft.Hpc.Dryad
+{
+    using System;
+
+    internal class Constants
+    {
+        //
+        // Settings shared by every WCF net.tcp binding
+        //
+        public const int MaxReceivedMessageSize = 16 * 1024 * 1024;
+        public const int MaxBufferPoolSize = 16 * 1024 * 1024;
+        public const int MaxConnections = 1024;
+        public const int ListenBacklog = 256;
+        public static readonly TimeSpan SendTimeout = TimeSpan.FromMinutes(2);
+        public static readonly TimeSpan ReceiveTimeout = TimeSpan.FromMinutes(10);
+        public static readonly TimeSpan VertexSendTimeout = TimeSpan.FromMinutes(1);
+
+        // Seal and Delete Node get a 6 minute WCF timeout because the underlying DB call has a
+        // 5 minute SQL timeout. Every other DSC operation uses 2 minutes.
+        // TODO: Post-SP3, re-examine longest running operations, as they slow down service failure and failover time
+        public static readonly TimeSpan DscOperationTimeout = TimeSpan.FromMinutes(2);
+        public static readonly TimeSpan DscExtendedOperationTimeout = TimeSpan.FromMinutes(6);
+
+        public static readonly string DryadConnectionString = string.Empty;
+
+        public const string CommonRegistryPath = @"SOFTWARE\Microsoft\HPC";
+        public const string HpcSchedulerNameString = "ClusterName";
+        public const string HpcInstallPath = "BinDir";
+        public const string DscServerName = "DscServiceNodeName";
+        public const string DscConnectionFormat = @"net.tcp://{0}:{1}/HpcDsc/Service/DscService";
+        public const string DscServiceDefaultScheme = "hpcdsc";
+        public const uint DscServiceDefaultPort = 6498;
+
+        public const string ServiceLocationString = @"ServiceLocation";
+
+        public const string jobManager = "XCJOBMANAGER";
+
+        // Vertex service endpoints: net.tcp://<host>:<port>/<path>/
+        public const string vertexAddrFormat = "net.tcp://{0}:8050/{1}/";
+        public const string vertexCallbackAddrFormat = "net.tcp://{0}:8051/{1}/";
+        public const string vertexCallbackServiceName = "DryadVertexCallback";
+        public const string vertexServiceName = "DryadVertexService";
+        public const string vertexFileServiceName = "DryadVertexFileService";
+        public const int vertexFileChunkSize = 1024 * 16;
+
+        public const string vertexCountEnvVar = "HPC_VERTEXCOUNT";
+        public const string vertexEnvVarFormat = "HPC_VERTEX{0}";
+        public const string vertexSvcInstanceEnvVar = "HPC_VERTEXSVCINST";
+        public const string vertexSvcLocalAddrEnvVar = "CCP_DRYADVERTEXLOCALADDRESS";
+
+        public const string schedulerTypeEnvVar = "CCP_SCHEDULERTYPE";
+        public const string schedulerTypeLocal = "LOCAL";
+        public const string schedulerTypeCluster = "CLUSTER";
+        public const string schedulerTypeAzure = "AZURE";
+        public const string debugAzure = "DEBUG_AZURE";
+
+        // Trace level names accepted in the environment variable:
+        // OFF, CRITICAL, ERROR, WARN, INFO, VERBOSE
+        public const string traceLevelEnvVar = "CCP_DRYADTRACELEVEL";
+        public const string traceOff = "OFF";
+        public const string traceCritical = "CRITICAL";
+        public const string traceError = "ERROR";
+        public const string traceWarning = "WARN";
+        public const string traceInfo = "INFO";
+        public const string traceVerbose = "VERBOSE";
+
+        // Numeric trace levels; each value has all lower bits set (0, 1, 3, 7, 15, 31).
+        public const int traceOffNum = 0;
+        public const int traceCriticalNum = 1;
+        public const int traceErrorNum = 3;
+        public const int traceWarningNum = 7;
+        public const int traceInfoNum = 15;
+        public const int traceVerboseNum = 31;
+
+        public const string VertexSecurityEnvVar = "HPC_VERTEX_SECURITY";
+
+        // Environment variables read by SchedulerHelper
+        public const string clusterNameEnvVar = "CCP_CLUSTER_NAME";
+        public const string jobIdEnvVar = "CCP_JOBID";
+        public const string taskIdEnvVar = "CCP_TASKID";
+        public const string nodesEnvVar = "CCP_NODES";
+        public const string jobNameEnvVar = "CCP_JOBNAME";
+        public const string requiredNodesEnvVar = "CCP_REQUIREDNODES";
+        public const string localProcessComputeNodesEnvVar = "CCP_LOCALPROCESSCOMPUTENODES";
+
+        // Error codes mirrored from DrError.h for use in managed code.
+        // Keep this section in sync with drerror.h changes.
+        public const uint DrError_VertexReceivedTermination = 0x830A0003;
+        public const uint DrError_VertexCompleted = 0x830A0016;
+        public const uint DrError_VertexError = 0x830A0017;
+        public const uint DrError_VertexInitialization = 0x830A0019;
+        public const uint DrError_ProcessingInterrupted = 0x830A001A;
+        public const uint DrError_VertexHostLostCommunication = 0x830A0FFF;
+
+        // DSC share names
+        public const string DscTempShare = "HpcTemp";
+        public const string DscDataShare = "HpcData";
+        public const string RuntimeShareConfig = "HPC_RUNTIMESHARE";
+
+        // Cluster name
+        public const string ClusterNameConfig = "CCP_CLUSTER_NAME";
+
+        // NodeAdmin constants
+        // Retain time is one day.
+        // todo: this should be configurable
+        public static readonly TimeSpan RetainTime = TimeSpan.FromDays(1);
+        public static readonly TimeSpan FileTimeStampMarginForGC = TimeSpan.FromMinutes(5);
+        public const string runningJobEnvVar = "CCP_RUNNING_JOBS";
+        public const string replicaPathFormat = @"\\{0}\HpcData\{1}.data";
+        public const string nodeAdminMutexName = "A19A8AC1-4129-46e2-BB81-ED7EE3265B05";
+        public const string nodeAdminUsage = "Syntax:\n\t" +
+                                              "HpcDscNodeAdmin [/r] [/g] [/wd] [/e] [/v] [/u]\n\n" +
+                                              "Parameters:\n\t" +
+                                              "/? \t- Display this help message.\n\t" +
+                                              "/g \t- Delete files not managed by DSC from the HpcData share.\n\t" +
+                                              "/wd\t- Delete old job working directories from the HpcTemp share.\n\t" +
+                                              "/r \t- Replicate DSC files onto this node.\n\t" +
+                                              "/e \t- Print full error traces.\n\t" +
+                                              "/u \t- Resets HpcReplication account password.\n\t" +
+                                              "/v \t- Print verbose activity traces.\n";
+
+        // HpcReplication user account
+        internal const string HpcReplicationUserName = "HpcReplication";
+
+        // Client retries start at 1 second and back off to at most 12 seconds, for 30 seconds
+        // in total. They are meant to ride through transient network failures.
+        internal const int StartRetryPeriod = 1000;
+        internal const int MaxRetryPeriod = 12000;
+        internal const int TotalRetryPeriod = 30000;
+        internal const int ClientRetryCount = 4;
+
+        // Runtime retries start at 10 seconds and back off to at most 60 seconds, for 6 minutes
+        // in total. They are meant to ride through a failover and more severe network
+        // disruptions, with the goal of keeping running jobs alive.
+        internal const int RuntimeStartRetryPeriod = 10000;
+        internal const int RuntimeMaxRetryPeriod = 60000;
+        internal const int RuntimeTotalRetryPeriod = 360000;
+        internal const int RuntimeClientRetryCount = 7;
+    }
+}
diff --git a/LinqToDryad/DataPath.cs b/LinqToDryad/DataPath.cs
new file mode 100644
index 0000000..6d7c7ef
--- /dev/null
+++ b/LinqToDryad/DataPath.cs
@@ -0,0 +1,239 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//
+// © Microsoft Corporation. All rights reserved.
+//
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Text;
+using System.IO;
+using System.Linq;
+using System.Linq.Expressions;
+using System.Reflection;
+using System.Runtime.CompilerServices;
+using System.Diagnostics;
+using System.Globalization;
+
+namespace Microsoft.Research.DryadLinq
+{
+    /// <summary>
+    /// String helpers for data paths of the form <c>prefix://path?key=value&amp;key=value</c>:
+    /// recognizing the store, splitting off the prefix, path and arguments, combining and
+    /// splitting paths, and building DSC / HDFS stream URIs.
+    /// </summary>
+    internal static class DataPath
+    {
+        const string TEMPORARY_STREAM_NAME_PREFIX = "/__DryadLinq_Temp/";
+        internal const string DSC_URI_PREFIX = "hpcdsc://";
+        internal const string HDFS_URI_PREFIX = "hpchdfs://";
+
+        internal const string PrefixSeparator = "://";
+        internal const string QuerySeparator = "?";
+        internal const char ArgumentSeparator = '&';
+        internal const char KeyValueSeparator = '=';
+
+        internal const string DSC_STORAGE_SET_TYPE = "Dsc";    // used to be enum StorageSetType {Dsc}
+        internal const string HDFS_STORAGE_SET_TYPE = "Hdfs";  // used to be enum StorageSetType {Hdfs}
+
+        /// <summary>True when <paramref name="uri"/> starts with the DSC prefix (case-insensitive).</summary>
+        public static bool IsDsc(string uri)
+        {
+            return uri.StartsWith(DSC_URI_PREFIX, StringComparison.OrdinalIgnoreCase);
+        }
+
+        /// <summary>
+        /// True when <paramref name="dataPath"/> starts with the HDFS prefix. The comparison is
+        /// ordinal and case-insensitive, matching <see cref="IsDsc"/>; the previous
+        /// culture-sensitive comparison could give different answers under different cultures.
+        /// </summary>
+        public static bool IsHdfs(string dataPath)
+        {
+            return dataPath.StartsWith(HDFS_URI_PREFIX, StringComparison.OrdinalIgnoreCase);
+        }
+
+        /// <summary>
+        /// Returns <paramref name="tableUri"/> without its query part (everything from the
+        /// first '?'). A URI with no '?', or one that starts with '?', is returned unchanged.
+        /// </summary>
+        internal static string GetDataPath(string tableUri)
+        {
+            int idx = tableUri.IndexOf(QuerySeparator, StringComparison.Ordinal);
+            string dataPath = tableUri;
+            if (idx > 0)
+            {
+                dataPath = tableUri.Substring(0, idx);
+            }
+            return dataPath;
+        }
+
+        /// <summary>
+        /// Returns the prefix of <paramref name="dataPath"/> up to and including "://", or the
+        /// empty string when there is no non-empty prefix.
+        /// </summary>
+        internal static string GetPrefix(string dataPath)
+        {
+            int idx = dataPath.IndexOf(PrefixSeparator, StringComparison.Ordinal);
+            if (idx <= 0) return "";
+            return dataPath.Substring(0, idx + PrefixSeparator.Length);
+        }
+
+        /// <summary>
+        /// Returns the part of <paramref name="dataPath"/> after "://", or the whole string when
+        /// there is no non-empty prefix.
+        /// </summary>
+        internal static string GetPath(string dataPath)
+        {
+            int idx = dataPath.IndexOf(PrefixSeparator, StringComparison.Ordinal);
+            if (idx <= 0) return dataPath;
+            return dataPath.Substring(idx + PrefixSeparator.Length);
+        }
+
+        /// <summary>
+        /// Parses the <c>key=value</c> pairs that follow the first '?' in a URI, separated by '&amp;'.
+        /// Keys are lower-cased with the invariant culture; values are kept as written.
+        /// </summary>
+        /// <param name="dscFileSetUri">The URI to parse.</param>
+        /// <returns>The arguments; empty when the URI has no '?'.</returns>
+        /// <exception cref="ArgumentException">
+        /// An argument has no '=' (this includes a URI that ends with '?'), or a key is repeated.
+        /// </exception>
+        internal static Dictionary<string, string> GetArguments(string dscFileSetUri)
+        {
+            Dictionary<string, string> args = new Dictionary<string, string>();
+            int idx = dscFileSetUri.IndexOf(QuerySeparator, StringComparison.Ordinal);
+            while (idx >= 0)
+            {
+                // idx sits on the '?' or '&' that introduces the next argument.
+                idx++;
+                int idx1 = dscFileSetUri.IndexOf(KeyValueSeparator, idx);
+                if (idx1 < 0)
+                {
+                    throw new ArgumentException(String.Format(SR.IllFormedUri, dscFileSetUri), "dscFileSetUri");
+                }
+                string key = dscFileSetUri.Substring(idx, idx1 - idx).ToLower(CultureInfo.InvariantCulture);
+                idx1++;
+                idx = dscFileSetUri.IndexOf(ArgumentSeparator, idx1);
+                string value;
+                if (idx < 0)
+                {
+                    // Last argument: the value runs to the end of the string.
+                    value = dscFileSetUri.Substring(idx1);
+                }
+                else
+                {
+                    value = dscFileSetUri.Substring(idx1, idx - idx1);
+                }
+                args.Add(key, value);
+            }
+            return args;
+        }
+
+        /// <summary>
+        /// Joins two path fragments. Paths without a prefix go through <see cref="Path.Combine(string, string)"/>;
+        /// a bare prefix is concatenated directly; otherwise the fragments are joined with
+        /// exactly one separator of the data provider registered for the prefix.
+        /// </summary>
+        internal static string PathCombine(string dataPath1, string dataPath2)
+        {
+            string prefix = GetPrefix(dataPath1);
+            if (prefix == "")
+            {
+                return Path.Combine(dataPath1, dataPath2);
+            }
+            if (prefix == dataPath1)
+            {
+                return dataPath1 + dataPath2;
+            }
+            DataProvider dp = DataProvider.GetDataProvider(prefix);
+            dataPath1 = dataPath1.TrimEnd(dp.PathSeparator);
+            dataPath2 = dataPath2.TrimStart(dp.PathSeparator);
+            return dataPath1 + dp.PathSeparator + dataPath2;
+        }
+
+        /// <summary>
+        /// Split a path into a directory name and a filename.
+        /// </summary>
+        /// <param name="dataPath">Path to split.</param>
+        /// <param name="dir">Directory part of pathname. May be empty if there are no slashes.</param>
+        /// <param name="file">File part of pathname.</param>
+        internal static void PathSplit(string dataPath, out string dir, out string file)
+        {
+            string prefix = GetPrefix(dataPath);
+            int slash;
+            if (prefix == "")
+            {
+                slash = dataPath.LastIndexOf(Path.DirectorySeparatorChar);
+            }
+            else
+            {
+                DataProvider dp = DataProvider.GetDataProvider(prefix);
+                slash = dataPath.LastIndexOf(dp.PathSeparator);
+                if (slash < 0) slash = dataPath.LastIndexOf('/');
+            }
+
+            if (slash >= 0)
+            {
+                file = dataPath.Substring(slash + 1);
+                dir = dataPath.Substring(0, slash);
+            }
+            else
+            {
+                file = dataPath;
+                dir = "";
+            }
+        }
+
+        /// <summary>
+        /// Extract the directory part of a path.
+        /// </summary>
+        /// <param name="dataPath">Path to split.</param>
+        /// <returns>Directory name (may be empty).</returns>
+        internal static string GetDir(string dataPath)
+        {
+            string dir, file;
+            PathSplit(dataPath, out dir, out file);
+            return dir;
+        }
+
+        /// <summary>
+        /// Extract just the filename from a path.
+        /// </summary>
+        /// <param name="dataPath">Path to split.</param>
+        /// <returns>Just the filename part.</returns>
+        public static string GetFile(string dataPath)
+        {
+            string dir, file;
+            PathSplit(dataPath, out dir, out file);
+            return file;
+        }
+
+        internal static string MakeDscStreamUri(DscService dsc, string streamName)
+        {
+            string serviceNodeName = dsc.HostName;
+            return MakeDscStreamUri(serviceNodeName, streamName);
+        }
+
+        /// <summary>
+        /// Builds the absolute DSC URI of <paramref name="streamName"/> on
+        /// <paramref name="serviceNodeName"/>, port 6498.
+        /// </summary>
+        internal static string MakeDscStreamUri(string serviceNodeName, string streamName)
+        {
+            Uri uri = new Uri(new Uri(DSC_URI_PREFIX + serviceNodeName + ":6498/"), streamName); // use the Uri class to do combining/escaping.
+            return uri.AbsoluteUri;
+        }
+
+        /// <summary>
+        /// Builds the HDFS URI of <paramref name="streamName"/> on
+        /// <paramref name="serviceNodeName"/>, port 9000, by plain concatenation (no escaping).
+        /// </summary>
+        internal static string MakeHdfsStreamUri(string serviceNodeName, string streamName)
+        {
+            // NOTE(review): a streamName that begins with '/' (as the temporary names built below
+            // do) yields "...:9000//name". Confirm whether consumers rely on that before normalizing.
+            return HDFS_URI_PREFIX + serviceNodeName + ":9000/" + streamName; // TODO: Parse config files.
+        }
+
+        //@@TODO[P2]: some overlap in how unique DSC stream names are made. Cleanup.
+        // 1. string tableUri = DryadLinqUtil.MakeUniqueName();
+        //    string fullTableUri = DataPath.GetFullUri(tableUri);
+        //
+        // 2. string tableName = DataPath.MakeUniqueDscStreamUri();
+        //
+        // Note: both base their root path on DryadLinq.DryadOutputDir.
+
+
+        internal static string MakeUniqueTemporaryDscFileSetName()
+        {
+            return TEMPORARY_STREAM_NAME_PREFIX + HpcLinqUtil.MakeUniqueName();
+        }
+
+        internal static string MakeUniqueTemporaryDscFileSetUri(HpcLinqContext context)
+        {
+            string uri = DataPath.MakeDscStreamUri(context.DscService.HostName, MakeUniqueTemporaryDscFileSetName());
+            return uri;
+        }
+
+        internal static string MakeUniqueTemporaryHdfsFileSetUri(HpcLinqContext context)
+        {
+            string uri = DataPath.MakeHdfsStreamUri(context.Configuration.HdfsNameNode, MakeUniqueTemporaryDscFileSetName());
+            return uri;
+        }
+
+        /// <summary>
+        /// Returns the path component of <paramref name="uriString"/> without its leading slashes.
+        /// </summary>
+        internal static string GetFilesetNameFromUri(string uriString)
+        {
+            Uri uri = new Uri(uriString);
+            return uri.AbsolutePath.TrimStart('/');
+        }
+    }
+}
diff --git a/LinqToDryad/DataProvider.cs b/LinqToDryad/DataProvider.cs
new file mode 100644
index 0000000..e5a8852
--- /dev/null
+++ b/LinqToDryad/DataProvider.cs
@@ -0,0 +1,234 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//
+// © Microsoft Corporation. All rights reserved.
+//
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Text;
+using System.IO;
+using System.Linq;
+using System.Linq.Expressions;
+using System.Xml;
+using System.Data.Linq;
+using System.Data.Linq.Mapping;
+using System.Reflection;
+using System.Runtime.CompilerServices;
+using System.Diagnostics;
+
+using Microsoft.Research.DryadLinq.Internal;
+using System.IO.Compression;
+
+namespace Microsoft.Research.DryadLinq
+{
+    //DataProvider is an abstraction for different data backends.
+    //Currently this class isn't strictly necessary (DSC is only data-store) but other data sources could come back.
+    internal class DataProvider
+    {
+        // Registered providers, keyed by URI prefix. Only the DSC prefix is registered.
+        private static Dictionary<string, DataProvider> s_providers;
+
+        static DataProvider()
+        {
+            s_providers = new Dictionary<string, DataProvider>();
+            s_providers.Add(DataPath.DSC_URI_PREFIX, new DataProvider());
+        }
+
+        // dead code.
+        ///// <summary>
+        ///// Register a data provider so that it can be used.
+        ///// </summary>
+        ///// <param name="prefix">The prefix name of the data provider</param>
+        ///// <param name="dp">the data provider to be registered</param>
+        //internal static void Register(string prefix, DataProvider dp)
+        //{
+        //    if (s_providers.ContainsKey(prefix))
+        //    {
+        //        throw new HpcLinqException(HpcLinqErrorCode.Internal_PrefixAlreadyUsedForOtherProvider,
+        //                                   String.Format(SR.PrefixAlreadyUsedForOtherProvider, prefix));
+        //    }
+        //    s_providers.Add(prefix, dp);
+        //}
+
+        /// <summary>
+        /// Get the data provider associated with a prefix.
+        /// </summary>
+        /// <param name="prefix">The data provider prefix</param>
+        /// <returns>The data provider</returns>
+        /// <exception cref="DryadLinqException">No provider is registered for the prefix.</exception>
+        internal static DataProvider GetDataProvider(string prefix)
+        {
+            DataProvider provider;
+            if (!s_providers.TryGetValue(prefix, out provider))
+            {
+                throw new DryadLinqException(HpcLinqErrorCode.Internal,
+                                             String.Format(SR.UnknownProvier, prefix));
+            }
+            return provider;
+        }
+
+        /// <summary>
+        /// The prefix of this data provider.
+        /// </summary>
+        internal string Prefix
+        {
+            get { return DataPath.DSC_URI_PREFIX; }
+        }
+
+        /// <summary>
+        /// The path separator of this data provider.
+        /// </summary>
+        internal char PathSeparator
+        {
+            get { return '/'; }
+        }
+
+        /// <summary>
+        /// Get the DSC file set specified by a URI.
+        /// </summary>
+        /// <typeparam name="T">The record type of the table.</typeparam>
+        /// <param name="context">The context the returned query is bound to.</param>
+        /// <param name="dscFileSetUri">The URI of a DscFileSet.</param>
+        /// <returns>A query object representing the dsc file set data.</returns>
+        /// <exception cref="ArgumentException">The argument part of the URI is ill-formed.</exception>
+        internal static DryadLinqQuery<T> GetPartitionedTable<T>(HpcLinqContext context, string dscFileSetUri)
+        {
+            // The parsed arguments are not used; the call is kept because it rejects an
+            // ill-formed URI before a query object is created.
+            DataPath.GetArguments(dscFileSetUri);
+
+            DataProvider dataProvider = new DataProvider();
+            DryadLinqProvider queryProvider = new DryadLinqProvider(context);
+            return new DryadLinqQuery<T>(null, queryProvider, dataProvider, dscFileSetUri);
+        }
+
+        // ingresses data, and also sets the temporary-length lease.
+        internal static DryadLinqQuery<T> IngressTemporaryDataDirectlyToDsc<T>(HpcLinqContext context,
+                                                                              IEnumerable<T> source,
+                                                                              string dscFileSetName,
+                                                                              DryadLinqMetaData metaData,
+                                                                              DscCompressionScheme outputScheme)
+        {
+            DryadLinqQuery<T> result = IngressDataDirectlyToDsc(context, source, dscFileSetName, metaData, outputScheme);
+
+            // try to set a temporary lease on the resulting fileset
+            try
+            {
+                DscFileSet fs = context.DscService.GetFileSet(dscFileSetName);
+                fs.SetLeaseEndTime(DateTime.Now.Add(StaticConfig.LeaseDurationForTempFiles));
+            }
+            catch (DscException)
+            {
+                // suppress: the lease is best-effort, the data has already been written.
+            }
+
+            return result;
+        }
+
+        /// <summary>
+        /// Streams plain enumerable data directly to a DSC file set and returns a query over it.
+        /// An empty source still produces a sealed file set containing one empty file. When
+        /// <paramref name="metaData"/> is given, the record type name is stored on the file set.
+        /// If anything fails, the file set is deleted (best effort) and the original exception
+        /// is rethrown.
+        /// </summary>
+        /// <remarks>
+        /// <paramref name="source"/> is enumerated twice (once to test for emptiness, once to
+        /// write), so it must be repeatable.
+        /// </remarks>
+        internal static DryadLinqQuery<T> IngressDataDirectlyToDsc<T>(HpcLinqContext context,
+                                                                     IEnumerable<T> source,
+                                                                     string dscFileSetName,
+                                                                     DryadLinqMetaData metaData,
+                                                                     DscCompressionScheme outputScheme)
+        {
+            try
+            {
+                string dscFileSetUri = DataPath.MakeDscStreamUri(context.DscService.HostName, dscFileSetName);
+                if (!source.Any())
+                {
+                    //there is no data.. we must create a FileSet with an empty file
+                    //(the factory/stream approach opens files lazily and thus never opens a file if there is no data)
+
+                    if (context.DscService.FileSetExists(dscFileSetName))
+                    {
+                        context.DscService.DeleteFileSet(dscFileSetName);
+                    }
+                    DscFileSet fileSet = context.DscService.CreateFileSet(dscFileSetName, outputScheme);
+                    DscFile file = fileSet.AddNewFile(0);
+                    string writePath = file.WritePath;
+
+                    if (outputScheme == DscCompressionScheme.Gzip)
+                    {
+                        //even zero-byte file must go through the gzip-compressor (for headers etc).
+                        using (Stream s = new FileStream(writePath, FileMode.Create))
+                        using (GZipStream gzipStream = new GZipStream(s, CompressionMode.Compress))
+                        {
+                            // Nothing to write; disposing the stream completes the gzip output.
+                        }
+                    }
+                    else
+                    {
+                        // Create (or truncate) the file and close it straight away.
+                        using (StreamWriter sw = new StreamWriter(writePath, false))
+                        {
+                        }
+                    }
+                    fileSet.Seal();
+                }
+                else
+                {
+                    HpcLinqFactory<T> factory = (HpcLinqFactory<T>)HpcLinqCodeGen.GetFactory(context, typeof(T));
+
+                    // new DscBlockStream(uri,Create,Write,compress) provides a DSC stream with one partition.
+                    NativeBlockStream nativeStream = new DscBlockStream(dscFileSetUri, FileMode.Create, FileAccess.Write, outputScheme);
+                    HpcRecordWriter<T> writer = factory.MakeWriter(nativeStream);
+                    try
+                    {
+                        if (context.Configuration.AllowConcurrentUserDelegatesInSingleProcess)
+                        {
+                            foreach (T item in source)
+                            {
+                                writer.WriteRecordAsync(item);
+                            }
+                        }
+                        else
+                        {
+                            foreach (T item in source)
+                            {
+                                writer.WriteRecordSync(item);
+                            }
+                        }
+                    }
+                    finally
+                    {
+                        writer.Close(); // closes the NativeBlockStream, which seals the dsc stream.
+                    }
+                }
+
+                if (metaData != null)
+                {
+                    DscFileSet fileSet = context.DscService.GetFileSet(dscFileSetName);
+                    fileSet.SetMetadata(DryadLinqMetaData.RECORD_TYPE_NAME, Encoding.UTF8.GetBytes(metaData.ElemType.AssemblyQualifiedName));
+                }
+
+                return DataProvider.GetPartitionedTable<T>(context, dscFileSetUri);
+            }
+            catch
+            {
+                // if we had a problem creating the fileset, try to delete it to avoid cruft being left in DSC.
+                try
+                {
+                    context.DscService.DeleteFileSet(dscFileSetName);
+                }
+                catch
+                {
+                    // suppress error during delete
+                }
+
+                throw; // rethrow the original exception.
+            }
+        }
+    }
+}
diff --git a/LinqToDryad/DataSetInfo.cs b/LinqToDryad/DataSetInfo.cs
new file mode 100644
index 0000000..7b8000e
--- /dev/null
+++ b/LinqToDryad/DataSetInfo.cs
@@ -0,0 +1,800 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//
+// © Microsoft Corporation. All rights reserved.
+//
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Runtime.Serialization;
+using System.Runtime.Serialization.Formatters.Binary;
+using System.Linq;
+using System.Linq.Expressions;
+using System.Reflection;
+using System.Diagnostics;
+using Microsoft.Research.DryadLinq.Internal;
+
+namespace Microsoft.Research.DryadLinq
+{
+    // The information we know about the dataset at each stage of the
+    // computation. For each operator, we try to compute this from the
+    // DataSetInfo of its input datasets and the semantics of the operator.
+    [Serializable]
+    internal class DataSetInfo
+    {
+        // Shared defaults used by the parameterless constructor:
+        // a single random partition, no ordering and no distinctness.
+        internal static PartitionInfo OnePartition = new RandomPartition(1);
+        internal static OrderByInfo NoOrderBy = new OrderByInfo();
+        internal static DistinctInfo NoDistinct = new DistinctInfo();
+
+        internal PartitionInfo partitionInfo;
+        internal OrderByInfo orderByInfo;
+        internal DistinctInfo distinctInfo;
+
+        // Creates the default description (OnePartition / NoOrderBy / NoDistinct).
+        internal DataSetInfo()
+        {
+            this.partitionInfo = OnePartition;
+            this.orderByInfo = NoOrderBy;
+            this.distinctInfo = NoDistinct;
+        }
+
+        internal DataSetInfo(PartitionInfo pinfo, OrderByInfo oinfo, DistinctInfo dinfo)
+        {
+            this.partitionInfo = pinfo;
+            this.orderByInfo = oinfo;
+            this.distinctInfo = dinfo;
+        }
+
+        // Shallow copy: the three info objects are shared with 'info', not cloned.
+        internal DataSetInfo(DataSetInfo info)
+        {
+            this.partitionInfo = info.partitionInfo;
+            this.orderByInfo = info.orderByInfo;
+            this.distinctInfo = info.distinctInfo;
+        }
+
+        // Return true iff the entire dataset is ordered: it must be range partitioned
+        // and ordered within partitions by the same key/comparer, and both orders
+        // must run in the same direction.
+        internal bool IsOrderedBy(LambdaExpression keySel, object comparer)
+        {
+            return (this.partitionInfo.ParType == PartitionType.Range &&
+                    this.partitionInfo.IsPartitionedBy(keySel, comparer) &&
+                    this.orderByInfo.IsOrderedBy(keySel, comparer) &&
+                    this.orderByInfo.IsSameMonotoncity(this.partitionInfo));
+        }
+
+        // NOTE(review): BinaryFormatter deserialization is unsafe on untrusted input;
+        // confirm that 'fstream' only ever comes from a trusted source.
+        internal static DataSetInfo Read(Stream fstream)
+        {
+            BinaryFormatter bfm = new BinaryFormatter();
+            return (DataSetInfo)bfm.Deserialize(fstream);
+        }
+
+        internal static void Write(DataSetInfo dsInfo, Stream fstream)
+        {
+            BinaryFormatter bfm = new BinaryFormatter();
+            bfm.Serialize(fstream, dsInfo);
+        }
+    }
+
+    internal enum PartitionType
+    {
+        Random      = 0x0000,
+        Hash        = 0x0001,
+        Range       = 0x0002,
+        HashOrRange = 0x0003
+    }
+
+    // Base class describing how a dataset is partitioned. The concrete kinds in this
+    // file are RandomPartition, RangePartition<TKey> and HashPartition<TKey>.
+    internal abstract class PartitionInfo
+    {
+        private PartitionType m_partitionType;
+
+        protected PartitionInfo(PartitionType parType)
+        {
+            this.m_partitionType = parType;
+        }
+
+        internal PartitionType ParType
+        {
+            get { return this.m_partitionType; }
+        }
+
+        // Only meaningful for range partitions; the base implementation throws.
+        internal virtual bool IsDescending
+        {
+            get {
+                throw new InvalidOperationException();
+            }
+        }
+
+        // Only meaningful for range partitions; the base implementation throws.
+        internal virtual bool HasKeys
+        {
+            get {
+                throw new InvalidOperationException();
+            }
+        }
+
+        // Concatenation yields a random partition with the combined partition count.
+        internal virtual PartitionInfo Concat(PartitionInfo p)
+        {
+            return new RandomPartition(this.Count + p.Count);
+        }
+
+        internal abstract int Count { get; set; }
+        internal abstract bool IsPartitionedBy(LambdaExpression keySel);
+        internal abstract bool IsPartitionedBy(LambdaExpression keySel, object comparer);
+        internal abstract bool IsPartitionedBy(LambdaExpression keySel, object comparer, bool isDescending);
+        internal abstract bool IsSamePartition(PartitionInfo p);
+        internal abstract DryadQueryNode CreatePartitionNode(LambdaExpression keySelector, DryadQueryNode child);
+        internal abstract PartitionInfo Create(LambdaExpression keySel);
+        internal abstract PartitionInfo Rewrite(LambdaExpression keySel, ParameterExpression param);
+
+        // Instantiates HashPartition<keyType> through reflection, because the key type
+        // is only known at run time. Binds to the non-public
+        // (keySelector, count, eqComparer) constructor.
+        internal static PartitionInfo CreateHash(LambdaExpression keySel, int count, object comparer, Type keyType)
+        {
+            Type hashType = typeof(HashPartition<>).MakeGenericType(keyType);
+            object[] args = new object[] { keySel, count, comparer };
+            return (PartitionInfo)Activator.CreateInstance(hashType, BindingFlags.NonPublic | BindingFlags.Instance, null, args, null);
+        }
+
+        // Instantiates RangePartition<keyType> through reflection. Binds to the
+        // five-argument constructor; exceptions thrown by that constructor are
+        // unwrapped from the TargetInvocationException before being rethrown.
+        internal static PartitionInfo CreateRange(LambdaExpression keySel, object keys, object comparer, bool? isDescending, Int32 parCnt, Type keyType)
+        {
+            Type parType = typeof(RangePartition<>).MakeGenericType(keyType);
+            object[] args = new object[] { keySel, keys, comparer, isDescending, parCnt };
+            try
+            {
+                return (PartitionInfo)Activator.CreateInstance(parType, BindingFlags.NonPublic | BindingFlags.Instance, null, args, null);
+            }
+            catch (TargetInvocationException tie)
+            {
+                // The ctor for RangePartition<> can throw.. we trap and rethrow the useful exception here.
+                // (Rethrowing the inner exception resets its stack trace.)
+                if (tie.InnerException != null)
+                    throw tie.InnerException;
+                else
+                    throw;
+            }
+        }
+
+        // Returns the queryable operator (method + arguments) that reproduces this
+        // partitioning. Not supported by the base class.
+        internal virtual Pair<MethodInfo, Expression[]> GetOperator()
+        {
+            throw new InvalidOperationException();
+        }
+    }
+
+    // A partitioning with no known key: every "is partitioned by" query answers false.
+    internal class RandomPartition : PartitionInfo
+    {
+        private int m_count;
+
+        internal RandomPartition(int count)
+            : base(PartitionType.Random)
+        {
+            this.m_count = count;
+        }
+
+        internal override int Count
+        {
+            get { return this.m_count; }
+            set { this.m_count = value; }
+        }
+
+        internal override bool IsPartitionedBy(LambdaExpression keySel)
+        {
+            return false;
+        }
+
+        internal override bool IsPartitionedBy(LambdaExpression keySel, object comparer)
+        {
+            return false;
+        }
+
+        internal override bool IsPartitionedBy(LambdaExpression keySel, object comparer, bool isDescending)
+        {
+            return false;
+        }
+
+        internal override bool IsSamePartition(PartitionInfo p)
+        {
+            return false;
+        }
+
+        internal override DryadQueryNode CreatePartitionNode(LambdaExpression keySel, DryadQueryNode child)
+        {
+            throw new DryadLinqException(HpcLinqErrorCode.CannotCreatePartitionNodeRandom,
+                                         SR.CannotCreatePartitionNodeRandom);
+        }
+
+        internal override PartitionInfo Create(LambdaExpression keySel)
+        {
+            return this;
+        }
+
+        internal override PartitionInfo Rewrite(LambdaExpression resultSel, ParameterExpression param)
+        {
+            return this;
+        }
+    }
+
+    // Range partitioning on a key of type TKey. Created reflectively by
+    // PartitionInfo.CreateRange, hence the non-public constructors.
+    internal class RangePartition<TKey> : PartitionInfo
+    {
+        private int m_count;
+        private LambdaExpression m_keySelector;
+        private TKey[] m_partitionKeys;
+        private IComparer<TKey> m_comparer;
+        private bool m_isDescending;
+
+        internal RangePartition(LambdaExpression keySelector, TKey[] partitionKeys, IComparer<TKey> comparer)
+            : this(keySelector, partitionKeys, comparer, null, 1)
+        {
+        }
+
+        // When partitionKeys is given, the partition count is keys.Length + 1 and
+        // parCnt is ignored. When isDescending is null, the direction is detected
+        // from the keys (which are then mandatory); otherwise the supplied keys, if
+        // any, must be ordered consistently with the requested direction.
+        internal RangePartition(LambdaExpression keySelector,
+                                TKey[] partitionKeys,
+                                IComparer<TKey> comparer,
+                                bool? isDescending,
+                                Int32 parCnt)
+            : base(PartitionType.Range)
+        {
+            this.m_count = (partitionKeys == null) ? parCnt : (partitionKeys.Length + 1);
+            this.m_keySelector = keySelector;
+            this.m_partitionKeys = partitionKeys;
+            this.m_comparer = TypeSystem.GetComparer<TKey>(comparer);
+            if (isDescending == null)
+            {
+                if (partitionKeys == null)
+                {
+                    throw new DryadLinqException(HpcLinqErrorCode.PartitionKeysNotProvided,
+                                                 SR.PartitionKeysNotProvided);
+                }
+
+                bool? detectedIsDescending;
+
+                if (!HpcLinqUtil.ComputeIsDescending(partitionKeys, m_comparer, out detectedIsDescending))
+                {
+                    throw new DryadLinqException(HpcLinqErrorCode.PartitionKeysAreNotConsistentlyOrdered,
+                                                 SR.PartitionKeysAreNotConsistentlyOrdered);
+                }
+
+                // A null detection result falls back to ascending.
+                this.m_isDescending = detectedIsDescending ?? false;
+            }
+            else
+            {
+                this.m_isDescending = isDescending.GetValueOrDefault();
+                if (partitionKeys != null &&
+                    !HpcLinqUtil.IsOrdered(partitionKeys, this.m_comparer, this.m_isDescending))
+                {
+                    throw new DryadLinqException(HpcLinqErrorCode.IsDescendingIsInconsistent,
+                                                 SR.IsDescendingIsInconsistent);
+                }
+            }
+        }
+
+        // Range partition without explicit keys.
+        internal RangePartition(LambdaExpression keySelector,
+                                IComparer<TKey> comparer,
+                                bool isDescending,
+                                Int32 parCnt)
+            : base(PartitionType.Range)
+        {
+            this.m_count = parCnt;
+            this.m_keySelector = keySelector;
+            this.m_partitionKeys = null;
+            this.m_comparer = TypeSystem.GetComparer<TKey>(comparer);
+            // Previously the argument was dropped, so the partition always
+            // reported itself as ascending.
+            this.m_isDescending = isDescending;
+        }
+
+        internal TKey[] Keys
+        {
+            get { return this.m_partitionKeys; }
+        }
+
+        internal Expression KeysExpression
+        {
+            get {
+                return Expression.Constant(this.m_partitionKeys);
+            }
+        }
+
+        internal Expression KeySelector
+        {
+            get { return this.m_keySelector; }
+        }
+
+        internal IComparer<TKey> Comparer
+        {
+            get { return this.m_comparer; }
+        }
+
+        internal override bool IsDescending
+        {
+            get { return this.m_isDescending; }
+        }
+
+        internal override bool HasKeys
+        {
+            get { return this.m_partitionKeys != null; }
+        }
+
+        internal override int Count
+        {
+            get { return this.m_count; }
+            set { this.m_count = value; }
+        }
+
+        internal override bool IsPartitionedBy(LambdaExpression keySel)
+        {
+            // Match the key selector functions:
+            if (this.m_keySelector == null)
+            {
+                return (keySel == null);
+            }
+            if (keySel == null) return false;
+            return ExpressionMatcher.Match(this.m_keySelector, keySel);
+        }
+
+        internal override bool IsPartitionedBy(LambdaExpression keySel, object comp)
+        {
+            // Match the key selector functions:
+            if (!this.IsPartitionedBy(keySel))
+            {
+                return false;
+            }
+
+            // Check the comparers:
+            IComparer<TKey> comp1 = TypeSystem.GetComparer<TKey>(comp);
+            if (comp1 == null) return false;
+            return this.m_comparer.Equals(comp1);
+        }
+
+        internal override bool IsPartitionedBy(LambdaExpression keySel, object comp, bool isDescending)
+        {
+            // Match the key selector functions:
+            if (!this.IsPartitionedBy(keySel))
+            {
+                return false;
+            }
+
+            // Check the comparers; an opposite direction is expressed by
+            // wrapping the candidate comparer in a MinusComparer:
+            IComparer<TKey> comp1 = TypeSystem.GetComparer<TKey>(comp);
+            if (comp1 == null) return false;
+            if (this.m_isDescending != isDescending)
+            {
+                comp1 = new MinusComparer<TKey>(comp1);
+            }
+            return this.m_comparer.Equals(comp1);
+        }
+
+        internal override bool IsSamePartition(PartitionInfo p)
+        {
+            RangePartition<TKey> p1 = p as RangePartition<TKey>;
+            if (p1 == null) return false;
+
+            // Check the keys:
+            if (this.Keys == null ||
+                p1.Keys == null ||
+                this.Keys.Length != p1.Keys.Length)
+            {
+                return false;
+            }
+
+            IComparer<TKey> comp1 = TypeSystem.GetComparer<TKey>(p1.m_comparer);
+            if (comp1 == null) return false;
+            if (this.IsDescending != p1.IsDescending)
+            {
+                comp1 = new MinusComparer<TKey>(comp1);
+            }
+            for (int i = 0; i < this.Keys.Length; i++)
+            {
+                if (this.m_comparer.Compare(this.Keys[i], p1.Keys[i]) != 0)
+                {
+                    return false;
+                }
+            }
+
+            // Check the comparers. Use the direction-adjusted comparer computed
+            // above (it used to be ignored), consistent with IsPartitionedBy.
+            return this.m_comparer.Equals(comp1);
+        }
+
+        internal override DryadQueryNode CreatePartitionNode(LambdaExpression keySel, DryadQueryNode child)
+        {
+            Expression keysExpr = Expression.Constant(this.m_partitionKeys);
+            Expression comparerExpr = Expression.Constant(this.m_comparer, typeof(IComparer<TKey>));
+            Expression isDescendingExpr = Expression.Constant(this.m_isDescending);
+            return new DryadRangePartitionNode(keySel, null, keysExpr, comparerExpr, isDescendingExpr, null, child.QueryExpression, child);
+        }
+
+        // Same keys/comparer/direction/count, re-targeted at a new key selector
+        // (whose body type becomes the new key type).
+        internal override PartitionInfo Create(LambdaExpression keySel)
+        {
+            Type keyType = keySel.Body.Type;
+            return PartitionInfo.CreateRange(keySel, this.Keys, this.m_comparer, this.m_isDescending, this.Count, keyType);
+        }
+
+        // Re-expresses the key selector in terms of 'resultSel'; when that is not
+        // possible the partitioning degrades to a random partition of the same size.
+        internal override PartitionInfo Rewrite(LambdaExpression resultSel, ParameterExpression param)
+        {
+            ParameterExpression a = this.m_keySelector.Parameters[0];
+            Substitution pSubst = Substitution.Empty.Cons(a, param);
+            LambdaExpression newKeySel = HpcLinqExpression.Rewrite(this.m_keySelector, resultSel, pSubst);
+            if (newKeySel == null)
+            {
+                return new RandomPartition(this.m_count);
+            }
+            return this.Create(newKeySel);
+        }
+
+        // NOTE(review): dereferences m_partitionKeys, so this presumably must only
+        // be called when HasKeys is true -- confirm against callers.
+        internal override Pair<MethodInfo, Expression[]> GetOperator()
+        {
+            Type sourceType = this.m_keySelector.Parameters[0].Type;
+            MethodInfo operation = TypeSystem.FindStaticMethod(
+                                       typeof(Microsoft.Research.DryadLinq.HpcLinqQueryable), "RangePartition",
+                                       new Type[] { typeof(IQueryable<>).MakeGenericType(sourceType),
+                                                    m_keySelector.GetType(),
+                                                    m_partitionKeys.GetType(),
+                                                    m_comparer.GetType(),
+                                                    typeof(bool) },
+                                       new Type[] { sourceType, typeof(TKey) });
+            Expression[] arguments = new Expression[] {
+                    this.m_keySelector,
+                    Expression.Constant(this.m_partitionKeys),
+                    Expression.Constant(this.m_comparer, typeof(IComparer<TKey>)),
+                    Expression.Constant(this.m_isDescending) };
+
+            return new Pair<MethodInfo, Expression[]>(operation, arguments);
+        }
+    }
+
+    // Hash partitioning on a key of type TKey. Created reflectively by
+    // PartitionInfo.CreateHash, hence the non-public constructors.
+    internal class HashPartition<TKey> : PartitionInfo
+    {
+        private int m_count;
+        private LambdaExpression m_keySelector;
+        private IEqualityComparer<TKey> m_comparer;
+
+        internal HashPartition(LambdaExpression keySelector, int count)
+            : this(keySelector, count, null)
+        {
+        }
+
+        // A null eqComparer selects EqualityComparer<TKey>.Default.
+        internal HashPartition(LambdaExpression keySelector, int count, IEqualityComparer<TKey> eqComparer)
+            : base(PartitionType.Hash)
+        {
+            this.m_count = count;
+            this.m_keySelector = keySelector;
+            this.m_comparer = (eqComparer == null) ? EqualityComparer<TKey>.Default : eqComparer;
+        }
+
+        internal Expression KeySelector
+        {
+            get { return this.m_keySelector; }
+        }
+
+        internal IEqualityComparer<TKey> EqualityComparer
+        {
+            get { return this.m_comparer; }
+        }
+
+        internal override int Count
+        {
+            get { return this.m_count; }
+            set { this.m_count = value; }
+        }
+
+        internal override bool IsPartitionedBy(LambdaExpression keySel)
+        {
+            // Match the key selector functions:
+            if (this.m_keySelector == null)
+            {
+                return (keySel == null);
+            }
+            if (keySel == null) return false;
+            return ExpressionMatcher.Match(this.m_keySelector, keySel);
+        }
+
+        internal override bool IsPartitionedBy(LambdaExpression keySel, object comp)
+        {
+            // Match the key selector functions:
+            if (!this.IsPartitionedBy(keySel))
+            {
+                return false;
+            }
+
+            // Check the comparers:
+            IEqualityComparer<TKey> comp1 = TypeSystem.GetEqualityComparer<TKey>(comp);
+            if (comp1 == null) return false;
+            return this.m_comparer.Equals(comp1);
+        }
+
+        // Direction is irrelevant for hash partitioning.
+        internal override bool IsPartitionedBy(LambdaExpression keySel, object comparer, bool isDescending)
+        {
+            return this.IsPartitionedBy(keySel, comparer);
+        }
+
+        internal override bool IsSamePartition(PartitionInfo p)
+        {
+            HashPartition<TKey> p1 = p as HashPartition<TKey>;
+            if (p1 == null || this.Count != p1.Count)
+            {
+                return false;
+            }
+            // Check the comparers:
+            return this.m_comparer.Equals(p1.m_comparer);
+        }
+
+        internal override DryadQueryNode CreatePartitionNode(LambdaExpression keySel, DryadQueryNode child)
+        {
+            Expression comparerExpr = Expression.Constant(this.m_comparer, typeof(IEqualityComparer<TKey>));
+            return new DryadHashPartitionNode(keySel, comparerExpr, this.Count, child.QueryExpression, child);
+        }
+
+        internal override PartitionInfo Create(LambdaExpression keySel)
+        {
+            Type keyType = keySel.Body.Type;
+            return PartitionInfo.CreateHash(keySel, this.Count, this.m_comparer, keyType);
+        }
+
+        // Re-expresses the key selector in terms of 'resultSel'; when that is not
+        // possible the partitioning degrades to a random partition of the same size.
+        internal override PartitionInfo Rewrite(LambdaExpression resultSel, ParameterExpression param)
+        {
+            ParameterExpression a = this.m_keySelector.Parameters[0];
+            Substitution pSubst = Substitution.Empty.Cons(a, param);
+            LambdaExpression newKeySel = HpcLinqExpression.Rewrite(this.m_keySelector, resultSel, pSubst);
+            if (newKeySel == null)
+            {
+                return new RandomPartition(this.m_count);
+            }
+            return this.Create(newKeySel);
+        }
+
+        internal override Pair<MethodInfo, Expression[]> GetOperator()
+        {
+            Type sourceType = this.m_keySelector.Parameters[0].Type;
+            MethodInfo operation = TypeSystem.FindStaticMethod(
+                                       typeof(Microsoft.Research.DryadLinq.HpcLinqQueryable), "HashPartition",
+                                       new Type[] { typeof(IQueryable<>).MakeGenericType(sourceType),
+                                                    m_keySelector.GetType(),
+                                                    m_comparer.GetType(),
+                                                    typeof(int) },
+                                       new Type[] { sourceType, typeof(TKey) });
+
+            Expression[] arguments = new Expression[] {
+                    m_keySelector,
+                    Expression.Constant(this.m_comparer, typeof(IEqualityComparer<TKey>)),
+                    Expression.Constant(this.Count) };
+
+            return new Pair<MethodInfo, Expression[]>(operation, arguments);
+        }
+    }
+
+    // "Not ordered" description; also the base class of OrderByInfo<TKey>.
+    // Every query answers false/null, and Create/Rewrite return DataSetInfo.NoOrderBy.
+    internal class OrderByInfo
+    {
+        internal virtual bool IsOrdered
+        {
+            get { return false; }
+        }
+
+        internal virtual LambdaExpression KeySelector
+        {
+            get { return null; }
+        }
+
+        internal virtual Expression Comparer
+        {
+            get { return null; }
+        }
+
+        internal virtual bool IsDescending
+        {
+            get { return false; }
+        }
+
+        internal virtual bool IsOrderedBy(LambdaExpression keySel)
+        {
+            return false;
+        }
+
+        internal virtual bool IsOrderedBy(LambdaExpression keySel, object comparer)
+        {
+            return false;
+        }
+
+        internal virtual bool IsOrderedBy(LambdaExpression keySel, object comparer, bool isDescending)
+        {
+            return false;
+        }
+
+        internal virtual bool IsSameMonotoncity(PartitionInfo pinfo)
+        {
+            return false;
+        }
+
+        // Instantiates OrderByInfo<keyType> through reflection, because the key type
+        // is only known at run time.
+        internal static OrderByInfo Create(Expression keySel, object comparer, bool isDescending, Type keyType)
+        {
+            Type infoType = typeof(OrderByInfo<>).MakeGenericType(keyType);
+            object[] args = new object[] { keySel, comparer, isDescending };
+            return (OrderByInfo)Activator.CreateInstance(infoType, BindingFlags.NonPublic | BindingFlags.Instance, null, args, null);
+        }
+
+        internal virtual OrderByInfo Create(LambdaExpression keySel)
+        {
+            return DataSetInfo.NoOrderBy;
+        }
+
+        internal virtual OrderByInfo Rewrite(LambdaExpression resultSel, ParameterExpression param)
+        {
+            return DataSetInfo.NoOrderBy;
+        }
+    }
+
+    // Ordering on a key of type TKey, with a comparer and a direction.
+    internal class OrderByInfo<TKey> : OrderByInfo
+    {
+        private LambdaExpression m_keySelector;
+        private IComparer<TKey> m_comparer;
+        private bool m_isDescending;
+
+        internal OrderByInfo(LambdaExpression keySelector, IComparer<TKey> comparer, bool isDescending)
+        {
+            this.m_keySelector = keySelector;
+            this.m_comparer = TypeSystem.GetComparer<TKey>(comparer);
+            this.m_isDescending = isDescending;
+        }
+
+        internal override LambdaExpression KeySelector
+        {
+            get { return this.m_keySelector; }
+        }
+
+        internal override Expression Comparer
+        {
+            get {
+                return Expression.Constant(this.m_comparer, typeof(IComparer<TKey>));
+            }
+        }
+
+        internal override bool IsDescending
+        {
+            get { return this.m_isDescending; }
+        }
+
+        internal override bool IsOrdered
+        {
+            get { return true; }
+        }
+
+        internal override bool IsOrderedBy(LambdaExpression keySel)
+        {
+            if (this.m_keySelector == null)
+            {
+                return (keySel == null);
+            }
+            if (keySel == null) return false;
+            return ExpressionMatcher.Match(this.m_keySelector, keySel);
+        }
+
+        internal override bool IsOrderedBy(LambdaExpression keySel, object comp)
+        {
+            // Match the key selector functions:
+            if (!this.IsOrderedBy(keySel))
+            {
+                return false;
+            }
+
+            // Check the comparers:
+            IComparer<TKey> comp1 = TypeSystem.GetComparer<TKey>(comp);
+            if (comp1 == null) return false;
+            return this.m_comparer.Equals(comp1);
+        }
+
+        internal override bool IsOrderedBy(LambdaExpression keySel, object comp, bool isDescending)
+        {
+            // Match the key selector functions:
+            if (!this.IsOrderedBy(keySel))
+            {
+                return false;
+            }
+
+            // Check the comparers; an opposite direction is expressed by
+            // wrapping the candidate comparer in a MinusComparer:
+            IComparer<TKey> comp1 = TypeSystem.GetComparer<TKey>(comp);
+            if (comp1 == null) return false;
+            if (this.IsDescending != isDescending)
+            {
+                comp1 = new MinusComparer<TKey>(comp1);
+            }
+            return this.m_comparer.Equals(comp1);
+        }
+
+        // True iff 'pinfo' is a range partition over the same key type whose
+        // (direction-adjusted) comparer equals this ordering's comparer.
+        internal override bool IsSameMonotoncity(PartitionInfo pinfo)
+        {
+            RangePartition<TKey> pinfo1 = pinfo as RangePartition<TKey>;
+            if (pinfo1 == null) return false;
+
+            IComparer<TKey> comp1 = pinfo1.Comparer;
+            if (this.m_isDescending != pinfo1.IsDescending)
+            {
+                comp1 = new MinusComparer<TKey>(comp1);
+            }
+            return this.m_comparer.Equals(comp1);
+        }
+
+        internal override OrderByInfo Create(LambdaExpression keySel)
+        {
+            Type keyType = keySel.Body.Type;
+            return OrderByInfo.Create(keySel, this.m_comparer, this.m_isDescending, keyType);
+        }
+
+        // Re-expresses the key selector in terms of 'resultSel'; when that is not
+        // possible the ordering is dropped (DataSetInfo.NoOrderBy).
+        internal override OrderByInfo Rewrite(LambdaExpression resultSel, ParameterExpression param)
+        {
+            ParameterExpression a = this.m_keySelector.Parameters[0];
+            Substitution pSubst = Substitution.Empty.Cons(a, param);
+            LambdaExpression newKeySel = HpcLinqExpression.Rewrite(this.m_keySelector, resultSel, pSubst);
+            if (newKeySel == null)
+            {
+                return DataSetInfo.NoOrderBy;
+            }
+            return this.Create(newKeySel);
+        }
+    }
+
+    // "Not distinct" description; also the base class of DistinctInfo<TKey>.
+    internal class DistinctInfo
+    {
+        internal virtual bool IsDistinct()
+        {
+            return false;
+        }
+
+        internal virtual bool IsDistinct(object comp)
+        {
+            return false;
+        }
+
+        internal virtual bool IsSameDistinct(DistinctInfo dist)
+        {
+            return false;
+        }
+
+        // Instantiates DistinctInfo<type> through reflection, because the element
+        // type is only known at run time.
+        internal static DistinctInfo Create(object comparer, Type type)
+        {
+            Type infoType = typeof(DistinctInfo<>).MakeGenericType(type);
+            object[] args = new object[] { comparer };
+            return (DistinctInfo)Activator.CreateInstance(infoType, BindingFlags.NonPublic | BindingFlags.Instance, null, args, null);
+        }
+    }
+
+    // Records that the dataset is distinct under a given equality comparer.
+    internal class DistinctInfo<TKey> : DistinctInfo
+    {
+        private IEqualityComparer<TKey> m_comparer;
+
+        internal Expression Comparer
+        {
+            get { return Expression.Constant(this.m_comparer, typeof(IEqualityComparer<TKey>)); }
+        }
+
+        // A null comparer selects EqualityComparer<TKey>.Default.
+        internal DistinctInfo(IEqualityComparer<TKey> comparer)
+        {
+            this.m_comparer = (comparer == null) ? EqualityComparer<TKey>.Default : comparer;
+        }
+
+        internal override bool IsDistinct()
+        {
+            return true;
+        }
+
+        internal override bool IsDistinct(object comp)
+        {
+            IEqualityComparer<TKey> comp1 = TypeSystem.GetEqualityComparer<TKey>(comp);
+            if (comp1 == null) return false;
+            return this.m_comparer.Equals(comp1);
+        }
+
+        internal override bool IsSameDistinct(DistinctInfo dist)
+        {
+            DistinctInfo<TKey> info = dist as DistinctInfo<TKey>;
+            if (info == null) return false;
+            // Compare against the comparer object itself. The Comparer property
+            // returns an Expression wrapper, which is not an equality comparer.
+            return IsDistinct(info.m_comparer);
+        }
+    }
+}
diff --git a/LinqToDryad/DryadBinaryReader.cs b/LinqToDryad/DryadBinaryReader.cs
new file mode 100644
index 0000000..7a90a9f
--- /dev/null
+++ b/LinqToDryad/DryadBinaryReader.cs
@@ -0,0 +1,742 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//
+// © Microsoft Corporation. All rights reserved.
+//
+using System;
+using System.Collections.Generic;
+using System.Text;
+using System.IO;
+using System.Runtime.InteropServices;
+using System.Data.SqlTypes;
+using System.Diagnostics;
+using System.Runtime.Serialization.Formatters.Binary;
+using System.Runtime.Serialization;
+using Microsoft.Research.DryadLinq.Internal;
+
+namespace Microsoft.Research.DryadLinq
+{
+    /// 
+    /// HpcBinaryReader is the main interface for user provided custom serializers
+    /// or DL-internal autoserialization codepaths to read primitive types from a partition file.
+    /// 
+    public unsafe sealed class HpcBinaryReader
+    {
+        private NativeBlockStream m_nativeStream;    // Source of data blocks; closed by Close().
+        private Encoding m_encoding;                 // Text encoding supplied at construction (UTF8 by default).
+        private Decoder m_decoder;                   // Decoder obtained from m_encoding.
+
+        private DataBlockInfo m_curDataBlockInfo;    // Descriptor of the block currently held (pointer, size, release handle).
+        private byte* m_curDataBlock;                // The current read buffer. This is requested from the native stream as it depletes,
+                                                     // and individual ReadXXX methods deserialize primitives out of this buffer
+        private Int32 m_curBlockSize;                // Size of the current read buffer. -1 until the first block has been fetched.
+        private Int32 m_curBlockPos;                 // Current position on the read buffer. This is updated by individual ReadXXX methods as they pull bytes for primitives they are reading
+
+        private bool m_isClosed;                     // Set by Close() so the stream is released only once.
+
+        internal HpcBinaryReader(NativeBlockStream stream)
+            : this(stream, Encoding.UTF8)
+        {
+        }
+
+        // Starts with no block loaded: size and position are both -1, so the first
+        // read (which tests position == size) or EndOfStream() call fetches a block.
+        internal HpcBinaryReader(NativeBlockStream stream, Encoding encoding)
+        {
+            this.m_nativeStream = stream;
+            this.m_encoding = encoding;
+            this.m_decoder = encoding.GetDecoder();
+            this.m_curDataBlockInfo.dataBlock = null;
+            this.m_curDataBlockInfo.blockSize = -1;
+            this.m_curDataBlockInfo.itemHandle = IntPtr.Zero;
+            this.m_curDataBlock = this.m_curDataBlockInfo.dataBlock;
+            this.m_curBlockSize = this.m_curDataBlockInfo.blockSize;
+            this.m_curBlockPos = -1;
+            this.m_isClosed = false;
+        }
+
+        internal HpcBinaryReader(IntPtr vertexInfo, UInt32 portNum)
+            : this(new HpcLinqChannel(vertexInfo, portNum, true), Encoding.UTF8)
+        {
+        }
+
+        internal HpcBinaryReader(IntPtr vertexInfo, UInt32 portNum, Encoding encoding)
+            : this(new HpcLinqChannel(vertexInfo, portNum, true), encoding)
+        {
+        }
+
+        // NOTE(review): the finalizer calls into m_nativeStream, a managed object that
+        // may itself already have been finalized -- confirm this is safe for NativeBlockStream.
+        ~HpcBinaryReader()
+        {
+            this.Close();
+        }
+
+        ////////////////////////////////////////////////////////////////////////////////
+        //
+        // Internal methods
+        //
+
+        internal Int64 GetTotalLength()
+        {
+            return this.m_nativeStream.GetTotalLength();
+        }
+
+        // Size of the *current block*, not of the whole stream (see GetTotalLength()).
+        internal long Length
+        {
+            get { return this.m_curBlockSize; }
+        }
+
+        // Releases the current block and closes the native stream. Idempotent:
+        // the release/close happens only on the first call.
+        internal void Close()
+        {
+            if (!this.m_isClosed)
+            {
+                this.m_isClosed = true;
+                this.m_nativeStream.ReleaseDataBlock(this.m_curDataBlockInfo.itemHandle);
+                this.m_nativeStream.Close();
+            }
+
+            GC.SuppressFinalize(this);
+        }
+
+        /// <summary>
+        /// Private helper to request a new data block from the native stream.
+        /// - it releases the current data block back to the native stream code
+        ///   (which owns the lifecycle of read buffers),
+        /// - then requests a new buffer from the native stream
+        /// - and updates the internal read buffer pointer (m_curDataBlock), size (m_curBlockSize)
+        ///   and position, all of which are needed by subsequent Read*() calls coming from the user
+        /// </summary>
+        private unsafe void GetNextDataBlock()
+        {
+            this.m_nativeStream.ReleaseDataBlock(this.m_curDataBlockInfo.itemHandle);
+            this.m_curDataBlockInfo.itemHandle = IntPtr.Zero;
+            this.m_curDataBlockInfo = this.m_nativeStream.ReadDataBlock();
+            this.m_curDataBlock = this.m_curDataBlockInfo.dataBlock;
+            this.m_curBlockSize = this.m_curDataBlockInfo.blockSize;
+            this.m_curBlockPos = 0;
+        }
+
+        internal string GetChannelURI()
+        {
+            return this.m_nativeStream.GetURI();
+        }
+
+        ////////////////////////////////////////////////////////////////////////////////
+        //
+        // Public methods
+        //
+
+        /// <summary>
+        /// Helper used by HpcRecordReader and generated vertex code to check whether the
+        /// reader reached the end of stream.
+        /// Returns true if the reader is at the end of the stream, and false if there is
+        /// more data to read.
+        /// This check may cause a new data block to be fetched if the call happens while
+        /// we're at the end of the current read buffer.
+        /// </summary>
+        internal bool EndOfStream()
+        {
+            if (this.m_curBlockPos < this.m_curBlockSize)
+            {
+                return false;
+            }
+
+            // A zero-sized block was already returned: do not ask the stream again.
+            if (this.m_curBlockSize == 0)
+            {
+                return true;
+            }
+
+            this.GetNextDataBlock();
+            return (this.m_curBlockSize <= 0);
+        }
+
+        /// <summary>
+        /// Read a byte from the current reader and advances the current position of the
+        /// reader by one byte.
+        /// </summary>
+        /// <returns>The next byte read from the current reader.</returns>
+        public byte ReadUByte()
+        {
+            // Refill only when the current buffer has been consumed completely.
+            bool bufferExhausted = (this.m_curBlockPos == this.m_curBlockSize);
+            if (bufferExhausted)
+            {
+                this.GetNextDataBlock();
+                if (this.m_curBlockSize <= 0)
+                {
+                    string message = String.Format(SR.EndOfStreamEncountered, GetChannelURI());
+                    throw new DryadLinqException(HpcLinqErrorCode.EndOfStreamEncountered, message);
+                }
+            }
+
+            byte value = this.m_curDataBlock[this.m_curBlockPos];
+            this.m_curBlockPos++;
+            return value;
+        }
+
+        /// <summary>
+        /// Read a signed byte from the current reader and advances the current
+        /// position of the reader by one byte.
+        /// </summary>
+        /// <returns>The next signed byte read from the current reader.</returns>
+        public sbyte ReadSByte()
+        {
+            // Same refill/end-of-stream handling as ReadUByte; only the result type differs.
+            return (sbyte)this.ReadUByte();
+        }
+
+        /// <summary>
+        /// Read a boolean value from the current reader and advances the current
+        /// position of the reader by one byte.
+        /// </summary>
+        /// <returns>true iff the byte is nonzero.</returns>
+        public unsafe bool ReadBool()
+        {
+            return this.ReadUByte() != 0;
+        }
+
+        /// <summary>
+        /// Read a character from the current reader and advances the current position of the reader
+        /// according to the encoding and the character.
+        /// </summary>
+        /// <returns>A character read from the current reader.</returns>
+        public char ReadChar()
+        {
+            char ch;
+            char *pCh = &ch;
+
+            // Feed the decoder one byte at a time until it produces a character.
+            // NOTE(review): the output buffer holds a single char, so input that decodes
+            // to a surrogate pair presumably cannot be returned here -- confirm.
+            while (true)
+            {
+                // request a new buffer from the native stream if we're at the end of the current one
+                if (this.m_curBlockPos == this.m_curBlockSize)
+                {
+                    this.GetNextDataBlock();
+                    if (this.m_curBlockSize <= 0)
+                    {
+                        throw new DryadLinqException(HpcLinqErrorCode.EndOfStreamEncountered,
+                                                     String.Format(SR.EndOfStreamEncountered,
+                                                                   GetChannelURI()));
+                    }
+                }
+
+                // decode a character and update current position
+                int numChars = this.m_decoder.GetChars(this.m_curDataBlock + this.m_curBlockPos, 1, pCh, 1, false);
+                this.m_curBlockPos++;
+                if (numChars == 1) return ch;
+            }
+        }
+
+        /// <summary>
+        /// Reads a 16-bit signed integer stored low byte first.
+        /// </summary>
+        public short ReadInt16()
+        {
+            ushort low, high;
+            // Slow path: the value may straddle a block boundary, so go through
+            // ReadUByte(), which fetches the next block when needed.
+            if (this.m_curBlockSize < this.m_curBlockPos + 2)
+            {
+                low = this.ReadUByte();
+                high = this.ReadUByte();
+            }
+            else
+            {
+                low = this.m_curDataBlock[this.m_curBlockPos++];
+                high = this.m_curDataBlock[this.m_curBlockPos++];
+            }
+            return (short)(low | (high << 8));
+        }
+
+        /// <summary>
+        /// Reads a 16-bit unsigned integer stored low byte first.
+        /// </summary>
+        public ushort ReadUInt16()
+        {
+            ushort low, high;
+            if (this.m_curBlockSize < this.m_curBlockPos + 2)
+            {
+                low = this.ReadUByte();
+                high = this.ReadUByte();
+            }
+            else
+            {
+                low = this.m_curDataBlock[this.m_curBlockPos++];
+                high = this.m_curDataBlock[this.m_curBlockPos++];
+            }
+            return (ushort)(low | (high << 8));
+        }
+
+        /// <summary>
+        /// Reads a 32-bit signed integer stored least significant byte first.
+        /// </summary>
+        public int ReadInt32()
+        {
+            int b1, b2, b3, b4;
+            // Slow path when fewer than 4 bytes remain in the current block.
+            if (this.m_curBlockSize < this.m_curBlockPos + 4)
+            {
+                b1 = this.ReadUByte();
+                b2 = this.ReadUByte() << 8;
+                b3 = this.ReadUByte() << 16;
+                b4 = this.ReadUByte() << 24;
+            }
+            else
+            {
+                b1 = this.m_curDataBlock[this.m_curBlockPos++];
+                b2 = this.m_curDataBlock[this.m_curBlockPos++] << 8;
+                b3 = this.m_curDataBlock[this.m_curBlockPos++] << 16;
+                b4 = this.m_curDataBlock[this.m_curBlockPos++] << 24;
+            }
+            return (int)(b1 | b2 | b3 | b4);
+        }
+
+        /// <summary>
+        /// Reads a variable-length integer: a first byte below 0x80 is the value itself;
+        /// otherwise the first byte (with its top bit cleared) is the most significant
+        /// byte of a four-byte value whose remaining bytes follow, most significant first.
+        /// </summary>
+        public unsafe int ReadCompactInt32()
+        {
+            int b1, b2, b3, b4;
+            b1 = this.ReadUByte();
+            if (b1 < 0x80)
+            {
+                return b1;
+            }
+            else
+            {
+                b1 = (b1 & 0x7F) << 24;
+                if (this.m_curBlockSize < this.m_curBlockPos + 3)
+                {
+                    b2 = this.ReadUByte() << 16;
+                    b3 = this.ReadUByte() << 8;
+                    b4 = this.ReadUByte();
+                }
+                else
+                {
+                    b2 = this.m_curDataBlock[this.m_curBlockPos++] << 16;
+                    b3 = this.m_curDataBlock[this.m_curBlockPos++] << 8;
+                    b4 = this.m_curDataBlock[this.m_curBlockPos++];
+                }
+                return (int)(b1 | b2 | b3 | b4);
+            }
+        }
+
+        /// <summary>
+        /// Reads a 32-bit unsigned integer stored least significant byte first.
+        /// </summary>
+        public uint ReadUInt32()
+        {
+            int b1, b2, b3, b4;
+            if (this.m_curBlockSize < this.m_curBlockPos + 4)
+            {
+                b1 = this.ReadUByte();
+                b2 = this.ReadUByte() << 8;
+                b3 = this.ReadUByte() << 16;
+                b4 = this.ReadUByte() << 24;
+            }
+            else
+            {
+                b1 = this.m_curDataBlock[this.m_curBlockPos++];
+                b2 = this.m_curDataBlock[this.m_curBlockPos++] << 8;
+                b3 = this.m_curDataBlock[this.m_curBlockPos++] << 16;
+                b4 = this.m_curDataBlock[this.m_curBlockPos++] << 24;
+            }
+            return (uint)(b1 | b2 | b3 | b4);
+        }
+
+        /// <summary>
+        /// Reads a 64-bit signed integer stored as two 32-bit halves, low half first.
+        /// </summary>
+        public long ReadInt64()
+        {
+            uint lo, hi;
+            // Slow path when fewer than 8 bytes remain in the current block.
+            if (this.m_curBlockSize < this.m_curBlockPos + 8)
+            {
+                lo = this.ReadUInt32();
+                hi = this.ReadUInt32();
+            }
+            else
+            {
+                lo = (uint)(this.m_curDataBlock[this.m_curBlockPos++] |
+                            this.m_curDataBlock[this.m_curBlockPos++] << 8 |
+                            this.m_curDataBlock[this.m_curBlockPos++] << 16 |
+                            this.m_curDataBlock[this.m_curBlockPos++] << 24);
+                hi = (uint)(this.m_curDataBlock[this.m_curBlockPos++] |
+                            this.m_curDataBlock[this.m_curBlockPos++] << 8 |
+                            this.m_curDataBlock[this.m_curBlockPos++] << 16 |
+                            this.m_curDataBlock[this.m_curBlockPos++] << 24);
+            }
+            return (long)(((ulong)hi) << 32 | lo);
+        }
+
+        /// <summary>
+        /// Reads a 64-bit unsigned integer stored as two 32-bit halves, low half first.
+        /// </summary>
+        public ulong ReadUInt64()
+        {
+            uint lo, hi;
+            if (this.m_curBlockSize < this.m_curBlockPos + 8)
+            {
+                lo = this.ReadUInt32();
+                hi = this.ReadUInt32();
+            }
+            else
+            {
+                lo = (uint)(this.m_curDataBlock[this.m_curBlockPos++] |
+                            this.m_curDataBlock[this.m_curBlockPos++] << 8 |
+                            this.m_curDataBlock[this.m_curBlockPos++] << 16 |
+                            this.m_curDataBlock[this.m_curBlockPos++] << 24);
+                hi = (uint)(this.m_curDataBlock[this.m_curBlockPos++] |
+                            this.m_curDataBlock[this.m_curBlockPos++] << 8 |
+                            this.m_curDataBlock[this.m_curBlockPos++] << 16 |
+                            this.m_curDataBlock[this.m_curBlockPos++] << 24);
+            }
+            return ((ulong)hi) << 32 | lo;
+        }
+
+        /// <summary>
+        /// Reads four bytes via ReadInt32() and reinterprets their bits as a float.
+        /// </summary>
+        public float ReadSingle()
+        {
+            int tmp = this.ReadInt32();
+            return *((float*)&tmp);
+        }
+
+        /// <summary>
+        /// Reads sizeof(decimal) raw bytes directly into a decimal via ReadRawBytes.
+        /// </summary>
+        public decimal ReadDecimal()
+        {
+            decimal val;
+            this.ReadRawBytes((byte*)&val, sizeof(decimal));
+            return val;
+        }
+
+        /// <summary>
+        /// Reads eight bytes via ReadUInt64() and reinterprets their bits as a double.
+        /// </summary>
+        public double ReadDouble()
+        {
+            ulong tmp = this.ReadUInt64();
+            return *((double*)&tmp);
+        }
+
+        // Layout of a serialized DateTime as decoded by ReadDateTime():
+        // the low 62 bits hold the ticks, the top 2 bits hold the DateTimeKind.
+        private const Int64 TicksMask = 0x3FFFFFFFFFFFFFFF;
+        private const Int32 KindShift = 62;
+
+        /// <summary>
+        /// Reads a 64-bit value and splits it into ticks and kind (see TicksMask / KindShift).
+        /// </summary>
+        public DateTime ReadDateTime()
+        {
+            UInt64 value = this.ReadUInt64();
+            return new DateTime((Int64)(value & TicksMask), (DateTimeKind)(value >> KindShift));
+        }
+
+        /// <summary>
+        /// Reads two 32-bit integers (day ticks, then time ticks) and builds a SqlDateTime.
+        /// </summary>
+        public SqlDateTime ReadSqlDateTime()
+        {
+            int dayTicks = this.ReadInt32();
+            int timeTicks = this.ReadInt32();
+            return new SqlDateTime(dayTicks, timeTicks);
+        }
+
+        /// <summary>
+        /// Reads sizeof(Guid) raw bytes directly into a Guid via ReadRawBytes.
+        /// </summary>
+        public Guid ReadGuid()
+        {
+            Guid guid;
+            ReadRawBytes((byte*)&guid, sizeof(Guid));
+            return guid;
+        }
+
+
+        /// <summary>
+        /// Reads <paramref name="charCount"/> chars into <paramref name="destBuffer"/> starting at <paramref name="offset"/>.
+        /// </summary>
+        /// <param name="destBuffer">The pre-allocated char array to read data into.</param>
+        /// <param name="offset">The starting offset at which to begin reading chars into <paramref name="destBuffer"/>.</param>
+        /// <param name="charCount">The maximum number of chars to read. Must be smaller than or equal to (<paramref name="destBuffer"/>.Length - <paramref name="offset"/>).</param>
+        /// <returns>The number of chars that was actually read.</returns>
+ public unsafe int ReadChars(char[] destBuffer, int offset, int charCount) + { + if (destBuffer == null) + { + throw new ArgumentNullException("destBuffer"); + } + if (offset < 0) + { + throw new ArgumentOutOfRangeException("offset"); + } + if (charCount < 0) + { + throw new ArgumentOutOfRangeException("charCount"); + } + if (destBuffer.Length < (offset + charCount)) + { + throw new ArgumentOutOfRangeException("destBuffer", + String.Format(SR.ArrayLengthVsCountAndOffset, + "destBuffer", offset + charCount, + "offset", "charCount")); + } + + Int32 numChars = charCount; + Int32 numMaxBytes = m_encoding.GetMaxByteCount(charCount); // note numMaxBytes may not always equal the actual bytes consumed in the conversion + int numCharsDecoded = 0; + + while (numChars > 0) + { + // check if there's enough data in the read buffer to finish the conversion + // if so do it in a single step, adjust read buffer location and return + int numAvailBytes = this.m_curBlockSize - this.m_curBlockPos; + if (numAvailBytes >= numMaxBytes) + { + int bytesUsed; + int charsConverted; + bool completed; + + fixed (char *pChars = destBuffer) + { + this.m_decoder.Convert(this.m_curDataBlock + this.m_curBlockPos, + numMaxBytes, + pChars + offset + numCharsDecoded, + numChars, + false, + out bytesUsed, + out charsConverted, + out completed); + } + + this.m_curBlockPos += bytesUsed; + numCharsDecoded = charCount; + return numCharsDecoded; + } + + // The remaining bytes in the read buffer don't *seem to be* enough to satisfy + // the request, but attempt to convert all the remaining bytes, and adjust + // current pos etc. 
before requesting a new buffer + if (numAvailBytes != 0) + { + int bytesUsed; + int charsConverted; + bool completed; + + fixed (char *pChars = destBuffer) + { + this.m_decoder.Convert(this.m_curDataBlock + this.m_curBlockPos, + numAvailBytes, + pChars + offset + numCharsDecoded, + numChars, + false, + out bytesUsed, + out charsConverted, + out completed); + } + numChars -= charsConverted; // update the number of remaining chars to convert + numMaxBytes -= bytesUsed; // adjust the max bytes estimate + numAvailBytes -= bytesUsed; + numCharsDecoded += charsConverted; + + // Even though it seemed like the remaining # of bytes wouldn't be enough, + // there's still a chance we decoded all the chars we needed + // (this can happen if the decoding used less bytes / char than the max estimate) + // So we need to check for this and return if it's indeed the case + if (numChars == 0) + { + this.m_curBlockPos += bytesUsed; + break; + } + } + + Debug.Assert(numAvailBytes == 0); // if we've reached this line there must be 0 bytes remaining, if not there's a mismatch in the math above + + // if we are here it means we've depleted all the bytes + this.GetNextDataBlock(); + if (this.m_curBlockSize <= 0) + { + // this means we're at the end of the file. simply break so that we + // return the number of chars read so far + break; + } + } + + return numCharsDecoded; + } + + public string ReadString() + { + // First read the length of the string and the number of bytes needed + Int32 numChars = this.ReadCompactInt32(); + Int32 numBytes = this.ReadCompactInt32(); + + // allocate the string + string str = new String('a', numChars); + int numCharsDecoded = 0; + + while (numChars > 0) + { + int numAvailBytes = this.m_curBlockSize - this.m_curBlockPos; + + // Check whether current read buffer has enough data to fill the string + // buffer to the end. 
If so, invoke decoder to copy from bytes to the + // destination chars, update m_curBlockPos and return + if (numAvailBytes >= numBytes) + { + fixed (char *pChars = str) + { + this.m_decoder.GetChars(this.m_curDataBlock + this.m_curBlockPos, + numBytes, + pChars + numCharsDecoded, + numChars, + false); + } + this.m_curBlockPos += numBytes; + break; + } + + + // If there wasn't enough data in the read buffer convert the remaining bytes to chars + // and request a new data block from the stream. + if (numAvailBytes != 0) + { + Int32 num = 0; + fixed (char *pChars = str) + { + num = this.m_decoder.GetChars(this.m_curDataBlock + this.m_curBlockPos, + numAvailBytes, + pChars + numCharsDecoded, + numChars, + false); + } + numChars -= num; + numBytes -= numAvailBytes; + numCharsDecoded += num; + } + + this.GetNextDataBlock(); + if (this.m_curBlockSize <= 0) + { + throw new DryadLinqException(HpcLinqErrorCode.EndOfStreamEncountered, + String.Format(SR.EndOfStreamEncountered, + GetChannelURI())); + } + } + return str; + } + + /// + /// Reads bytes into starting at . + /// + /// The pre-allocated byte array to read data into. + /// The starting offset at which to begin reading bytes into . + /// The maximum number of bytes to read. Must be smaller than or equal to ( - ). + /// The number of bytes that was actually read. 
+ public int ReadBytes(byte[] destBuffer, int offset, int byteCount) + { + if (destBuffer == null) + { + throw new ArgumentNullException("destBuffer"); + } + if (offset < 0) + { + throw new ArgumentOutOfRangeException("offset"); + } + if (byteCount < 0) + { + throw new ArgumentOutOfRangeException("byteCount"); + } + if (destBuffer.Length < (offset + byteCount)) + { + throw new ArgumentOutOfRangeException("destBuffer", + String.Format(SR.ArrayLengthVsCountAndOffset, + "destBuffer", offset + byteCount, + "offset", "byteCount")); + } + + int numBytes = byteCount; + int numBytesRead = 0; + fixed (byte *pBytes = &destBuffer[offset]) + { + while (numBytes > 0) + { + int numAvailBytes = this.m_curBlockSize - this.m_curBlockPos; + + // Check if there are enough bytes in the read buffer to satisfy the + // caller's request. If so, do the copy, update m_curBlockPos and return + if (numAvailBytes >= numBytes) + { + HpcLinqUtil.memcpy(this.m_curDataBlock + this.m_curBlockPos, + pBytes + numBytesRead, + numBytes); + this.m_curBlockPos += numBytes; + numBytesRead = byteCount; + break; + } + + // The remaining data in the read buffer isn't enough to fill the user's request... + // Copy the all the remaining bytes to the destination buffer, and request a + // new read buffer from the stream. + // Note that we don't need to update m_curBlockPos here because the + // GetNextDataBlock call will reset it. + HpcLinqUtil.memcpy(this.m_curDataBlock + this.m_curBlockPos, + pBytes + numBytesRead, + numAvailBytes); + + // update numBytes/numBytesRead + numBytes -= numAvailBytes; + numBytesRead += numAvailBytes; + + this.GetNextDataBlock(); + if (this.m_curBlockSize <= 0) + { + // if the file stream returned an empty buffer it means we are at the + // end of the file. Just return the total number of bytes read, and the + // caller will decide how to handle it. 
+ break; + } + + // continue with the loop to keep filling the remaining parts of the + // destination buffer + } + } + return numBytesRead; + } + + /// + /// public helper to read into a byte*, mainly used to read preallocated fixed size, + /// non-integer types (Array, Guid, decimal etc) + /// + public void ReadRawBytes(byte* pBytes, int numBytes) + { + int numBytesRead = 0; + while (numBytes > 0) + { + int numAvailBytes = this.m_curBlockSize - this.m_curBlockPos; + + // if m_curDataBlock has enough bytes to fill the remainder of the user's request, + // simply copy and exit. + if (numAvailBytes >= numBytes) + { + HpcLinqUtil.memcpy(this.m_curDataBlock + this.m_curBlockPos, + pBytes + numBytesRead, + numBytes); + this.m_curBlockPos += numBytes; + break; + } + + // if m_curDataBlock has less data than required, copy all the remaining bytes + // to user's buffer, update BytesRead counter, and request a new data block from + // the native stream now that m_curDataBlock is depleted + // Note that we don't need to update m_curBlockPos after memcpy() becase the + // subsequent GetNextDataBlock() call will reset it + HpcLinqUtil.memcpy(this.m_curDataBlock + this.m_curBlockPos, + pBytes + numBytesRead, + numAvailBytes); + this.GetNextDataBlock(); + if (this.m_curBlockSize <= 0) + { + throw new DryadLinqException(HpcLinqErrorCode.EndOfStreamEncountered, + String.Format(SR.EndOfStreamEncountered, + GetChannelURI())); + } + numBytes -= numAvailBytes; + numBytesRead += numAvailBytes; + + // here we go on to do another loop, as we can only reach when the user's request + // isn't fulfilled. + } + } + } +} diff --git a/LinqToDryad/DryadBinaryWriter.cs b/LinqToDryad/DryadBinaryWriter.cs new file mode 100644 index 0000000..3e3565d --- /dev/null +++ b/LinqToDryad/DryadBinaryWriter.cs @@ -0,0 +1,646 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. +// +using System; +using System.Collections.Generic; +using System.Text; +using System.IO; +using System.Data.SqlTypes; +using System.Diagnostics; +using System.Runtime.Serialization.Formatters.Binary; +using System.Runtime.Serialization; +using Microsoft.Research.DryadLinq.Internal; +using Microsoft.Research.DryadLinq; + +namespace Microsoft.Research.DryadLinq +{ + /// + /// HpcBinaryWriter is the main interface for user provided custom serializers + /// or DL-internal autoserialization codepaths to write primitive types from a partition file. + /// + public unsafe sealed class HpcBinaryWriter + { + private const Int32 DefaultBlockSize = 256 * 1024; + + private NativeBlockStream m_nativeStream; + private Encoding m_encoding; + private Int32 m_nextBlockSize; + private Int32 m_bufferSizeHint; + + private DataBlockInfo m_curDataBlockInfo; + private byte* m_curDataBlock; // The current write buffer. This is allocated from the native stream, + // individual WriteXXX methods serialize primitives into this buffer, + // and it gets written out when it's full + private Int32 m_curBlockSize; // Size of the current write buffer. 
        private Int32 m_curRecordStart;     // Offset in the current block where the in-progress record begins.
        private Int32 m_curRecordEnd;       // Offset one past the last byte written into the current block.
        private Int32 m_charMaxByteCount;   // m_encoding.GetMaxByteCount(1), cached for Write(char).
        private bool m_isClosed;            // Set by Close() so that closing twice is a no-op.
        private Int64 m_numBytesWritten;    // Bytes already handed to the native stream via WriteDataBlock.
        private bool m_calcFP;              // Last value assigned through the CalcFP property.
        private BinaryFormatter m_bfm;

        // Creates a writer over the given stream using UTF-8 and the default block size.
        internal HpcBinaryWriter(NativeBlockStream stream)
            : this(stream, Encoding.UTF8)
        {
        }

        // Creates a writer over the given stream using the given encoding and the default block size.
        internal HpcBinaryWriter(NativeBlockStream stream, Encoding encoding)
            : this(stream, encoding, DefaultBlockSize)
        {
        }

        // Main constructor. No data block is allocated here: the writer starts with a null,
        // zero-sized block, and the first write triggers FlushDataBlock to obtain one.
        internal HpcBinaryWriter(NativeBlockStream stream, Encoding encoding, Int32 buffSize)
        {
            this.m_nativeStream = stream;
            this.m_encoding = encoding;
            // Size of the next block to allocate: at least DefaultBlockSize, otherwise half of the hint.
            this.m_nextBlockSize = Math.Max(DefaultBlockSize, buffSize / 2);
            this.m_bufferSizeHint = buffSize;
            this.m_curDataBlockInfo.dataBlock = null;
            this.m_curDataBlockInfo.blockSize = 0;
            this.m_curDataBlockInfo.itemHandle = IntPtr.Zero;
            this.m_curDataBlock = this.m_curDataBlockInfo.dataBlock;
            this.m_curBlockSize = this.m_curDataBlockInfo.blockSize;
            this.m_curRecordStart = 0;
            this.m_curRecordEnd = 0;
            this.m_charMaxByteCount = this.m_encoding.GetMaxByteCount(1);
            this.m_isClosed = false;
            this.m_numBytesWritten = 0;
            this.m_calcFP = false;
            this.m_bfm = new BinaryFormatter();
        }

        // Creates a writer over a new HpcLinqChannel for the given vertex port, using UTF-8.
        // NOTE(review): the meaning of the 'false' channel argument is not visible here - confirm.
        internal HpcBinaryWriter(IntPtr vertexInfo, UInt32 portNum, Int32 buffSize)
            : this(new HpcLinqChannel(vertexInfo, portNum, false), Encoding.UTF8, buffSize)
        {
        }

        // Same as above, with a caller-supplied encoding.
        internal HpcBinaryWriter(IntPtr vertexInfo, UInt32 portNum, Encoding encoding, Int32 buffSize)
            : this(new HpcLinqChannel(vertexInfo, portNum, false), encoding, buffSize)
        {
        }

        // Finalizer: closes the writer if the owner never called Close().
        // NOTE(review): this flushes through m_nativeStream from the finalizer thread; confirm the
        // native stream is still usable at that point.
        ~HpcBinaryWriter()
        {
            this.Close();
        }

        ////////////////////////////////////////////////////////////////////////////////
        //
        // Internal methods
        //

        // The buffer size passed to the constructor.
        internal Int32 BufferSizeHint
        {
            get { return this.m_bufferSizeHint; }
        }

        // Marks everything written so far as complete records: the next record starts at the
        // current write position.
        internal void CompleteWriteRecord()
        {
            this.m_curRecordStart = this.m_curRecordEnd;
        }

        // Fingerprint-calculation flag; GetFingerPrint throws unless it is true.
        // NOTE(review): the setter calls SetCalcFP() on the native stream regardless of the value
        // assigned, so assigning false still calls it - confirm this is intended.
        internal bool CalcFP
        {
            get { return this.m_calcFP; }
            set
            {
                this.m_nativeStream.SetCalcFP();
                this.m_calcFP = value;
            }
        }

        // Returns the URI reported by the underlying native stream.
        internal string GetChannelURI()
        {
            return this.m_nativeStream.GetURI();
        }

        // Returns the total length reported by the underlying native stream.
        internal Int64 GetTotalLength()
        {
            return this.m_nativeStream.GetTotalLength();
        }

        // Returns the native stream's fingerprint; throws DryadLinqException (FingerprintDisabled)
        // when CalcFP has not been set to true.
        internal UInt64 GetFingerPrint()
        {
            if (!this.m_calcFP)
            {
                throw new DryadLinqException(HpcLinqErrorCode.FingerprintDisabled,
                                             SR.FingerprintDisabled);
            }
            return this.m_nativeStream.GetFingerPrint();
        }

        /// 
        /// Writes out the current data buffer (equivalent of FlushDataBlock), and calls
        /// Flush on the native stream to ensure all the data makes its way to the disk
        /// 
        internal void Flush()
        {
            if (this.m_curRecordEnd > 0)
            {
                // Hand every byte written so far (including any incomplete record) to the
                // native stream, release the block, and start over with a fresh block of the
                // same size.
                this.m_nativeStream.WriteDataBlock(this.m_curDataBlockInfo.itemHandle, this.m_curRecordEnd);
                this.m_numBytesWritten += this.m_curRecordEnd;
                this.m_nativeStream.ReleaseDataBlock(this.m_curDataBlockInfo.itemHandle);
                this.m_curDataBlockInfo.itemHandle = IntPtr.Zero;
                this.m_curDataBlockInfo = this.m_nativeStream.AllocateDataBlock(this.m_curBlockSize);
                this.m_curDataBlock = this.m_curDataBlockInfo.dataBlock;
                this.m_curBlockSize = this.m_curDataBlockInfo.blockSize;
                this.m_curRecordStart = 0;
                this.m_curRecordEnd = 0;
            }
            this.m_nativeStream.Flush();
        }

        /// 
        /// Internal entry point to flush and close the writer. This is called by the record writer
        /// 
        internal void Close()
        {
            if (!this.m_isClosed)
            {
                // Mark as closed first so that a second call does nothing but suppress finalization.
                this.m_isClosed = true;
                this.Flush();
                // Release the block that is still held (Flush allocates a fresh one after writing).
                if (this.m_curBlockSize > 0)
                {
                    this.m_nativeStream.ReleaseDataBlock(this.m_curDataBlockInfo.itemHandle);
                }
                this.m_nativeStream.Close();
            }

            // The writer has been closed explicitly (or is being finalized); no finalizer run is needed.
            GC.SuppressFinalize(this);
        }


        /// 
        /// Private helper to write the current block out to the native stream.
+ /// - it writes out the current data buffer up to the point it was filled + /// - it releases the current data block back to the native stream code (which owns the lifecycle of read buffers), + /// - then allocated a new buffer from the native stream + /// - and updates the internal read buffer pointer and position members + /// + private void FlushDataBlock() + { + DataBlockInfo newDataBlockInfo; + if (this.m_curRecordStart <= 16) + { + // The current block is too small for a single record, augment it + if (this.m_curBlockSize == this.m_nextBlockSize) + { + throw new DryadLinqException(HpcLinqErrorCode.RecordSizeMax2GB, SR.RecordSizeMax2GB); + } + newDataBlockInfo = this.m_nativeStream.AllocateDataBlock(this.m_nextBlockSize); + this.m_nextBlockSize = this.m_nextBlockSize * 2; + if (this.m_nextBlockSize < 0) + { + this.m_nextBlockSize = 0x7FFFFFF8; + } + HpcLinqUtil.memcpy(this.m_curDataBlock, newDataBlockInfo.dataBlock, this.m_curRecordEnd); + } + else + { + // Write all the complete records in the block, put the partial record in the new block + newDataBlockInfo = this.m_nativeStream.AllocateDataBlock(this.m_curBlockSize); + HpcLinqUtil.memcpy(this.m_curDataBlock + this.m_curRecordStart, + newDataBlockInfo.dataBlock, + this.m_curRecordEnd - this.m_curRecordStart); + this.m_nativeStream.WriteDataBlock(this.m_curDataBlockInfo.itemHandle, this.m_curRecordStart); + this.m_numBytesWritten += this.m_curRecordStart; + this.m_curRecordEnd -= this.m_curRecordStart; + this.m_curRecordStart = 0; + } + this.m_nativeStream.ReleaseDataBlock(this.m_curDataBlockInfo.itemHandle); + this.m_curDataBlockInfo.itemHandle = IntPtr.Zero; + this.m_curDataBlockInfo = newDataBlockInfo; + this.m_curDataBlock = newDataBlockInfo.dataBlock; + this.m_curBlockSize = newDataBlockInfo.blockSize; + } + + internal Int64 Length + { + get + { + return this.m_numBytesWritten + this.m_curRecordEnd; + } + } + + //////////////////////////////////////////////////////////////////////////////// + // + // 
Public methods + // + + public void Write(byte b) + { + if (this.m_curRecordEnd == this.m_curBlockSize) + { + this.FlushDataBlock(); + } + this.m_curDataBlock[this.m_curRecordEnd++] = b; + } + + public void Write(sbyte b) + { + if (this.m_curRecordEnd == this.m_curBlockSize) + { + this.FlushDataBlock(); + } + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)b; + } + + public void Write(bool b) + { + if (this.m_curRecordEnd == this.m_curBlockSize) + { + this.FlushDataBlock(); + } + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(b ? 1 : 0); + } + + public void Write(char ch) + { + if (this.m_curBlockSize - this.m_curRecordEnd < this.m_charMaxByteCount) + { + this.FlushDataBlock(); + } + + int numBytes = this.m_encoding.GetBytes(&ch, 1, this.m_curDataBlock + this.m_curRecordEnd, this.m_charMaxByteCount); + this.m_curRecordEnd += numBytes; + } + + public void Write(short val) + { + if (this.m_curBlockSize - this.m_curRecordEnd < 2) + { + this.FlushDataBlock(); + } + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)val; + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 8); + } + + public void Write(ushort val) + { + if (this.m_curBlockSize - this.m_curRecordEnd < 2) + { + this.FlushDataBlock(); + } + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)val; + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 8); + } + + public void Write(int val) + { + if (this.m_curBlockSize - this.m_curRecordEnd < 4) + { + this.FlushDataBlock(); + } + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)val; + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 8); + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 16); + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 24); + } + + public void WriteCompact(int val) + { + if (this.m_curBlockSize - this.m_curRecordEnd < 4) + { + this.FlushDataBlock(); + } + if (val < 0x80) + { + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)val; + } + else + { + 
this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 24 | 0x80); + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 16); + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 8); + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)val; + } + } + + internal static int CompactSize(int val) + { + return (val < 0x80) ? 1 : 4; + } + + private void Write(int val, int loc) + { + this.m_curDataBlock[loc++] = (byte)val; + this.m_curDataBlock[loc++] = (byte)(val >> 8); + this.m_curDataBlock[loc++] = (byte)(val >> 16); + this.m_curDataBlock[loc++] = (byte)(val >> 24); + } + + private void WriteCompact(int val, int compactSize, int loc) + { + if (compactSize == 1) + { + this.m_curDataBlock[loc++] = (byte)val; + } + else + { + this.m_curDataBlock[loc++] = (byte)(val >> 24 | 0x80); + this.m_curDataBlock[loc++] = (byte)(val >> 16); + this.m_curDataBlock[loc++] = (byte)(val >> 8); + this.m_curDataBlock[loc++] = (byte)val; + } + } + + public void Write(uint val) + { + if (this.m_curBlockSize - this.m_curRecordEnd < 4) + { + this.FlushDataBlock(); + } + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)val; + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 8); + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 16); + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 24); + } + + public void Write(long val) + { + if (this.m_curBlockSize - this.m_curRecordEnd < 8) + { + this.FlushDataBlock(); + } + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)val; + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 8); + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 16); + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 24); + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 32); + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 40); + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 48); + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 56); + } + + public void 
Write(ulong val) + { + if (this.m_curBlockSize - this.m_curRecordEnd < 8) + { + this.FlushDataBlock(); + } + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)val; + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 8); + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 16); + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 24); + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 32); + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 40); + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 48); + this.m_curDataBlock[this.m_curRecordEnd++] = (byte)(val >> 56); + } + + public void Write(decimal val) + { + this.WriteRawBytes((byte*)&val, sizeof(decimal)); + } + + public void Write(float val) + { + uint tmpVal = *(uint*)&val; + this.Write(tmpVal); + } + + public void Write(double val) + { + ulong tmpVal = *(ulong*)&val; + this.Write(tmpVal); + } + + private const Int32 KindShift = 62; + + public void Write(DateTime val) + { + UInt64 tempVal = (UInt64)val.Ticks | (((UInt64)val.Kind) << KindShift); + this.Write(tempVal); + } + + public void Write(SqlDateTime val) + { + this.Write(val.DayTicks); + this.Write(val.TimeTicks); + } + + public void Write(Guid guid) + { + WriteRawBytes((byte*)&guid, sizeof(Guid)); + } + + public void Write(string val) + { + Int32 len = val.Length; + Int32 maxByteCount = this.m_encoding.GetMaxByteCount(len); + Int32 compactSize = CompactSize(maxByteCount); + + while (this.m_curBlockSize - this.m_curRecordEnd < (maxByteCount + 8)) + { + this.FlushDataBlock(); + } + this.WriteCompact(len); + int buffLoc = this.m_curRecordEnd; + this.m_curRecordEnd += compactSize; + int numBytes; + fixed (char* pVal = val) + { + numBytes = this.m_encoding.GetBytes(pVal, + len, + this.m_curDataBlock + this.m_curRecordEnd, + this.m_curBlockSize - this.m_curRecordEnd); + } + this.m_curRecordEnd += numBytes; + this.WriteCompact(numBytes, compactSize, buffLoc); + } + + public void WriteChars(char[] charBuffer, 
int offset, int charCount) + { + if (charBuffer == null) + { + throw new ArgumentNullException("charBuffer"); + } + if (offset < 0) + { + throw new ArgumentOutOfRangeException("offset"); + } + if (charCount < 0) + { + throw new ArgumentOutOfRangeException("charCount"); + } + if (charBuffer.Length < (offset + charCount)) + { + throw new ArgumentOutOfRangeException("charBuffer", + String.Format(SR.ArrayLengthVsCountAndOffset, + "charBuffer", offset + charCount, + "offset", "charCount")); + } + + Int32 maxByteCount = this.m_encoding.GetMaxByteCount(charCount); + + // if current block doesn't have enough space flush it and allocate a new one + while (this.m_curBlockSize - this.m_curRecordEnd < maxByteCount ) + { + this.FlushDataBlock(); + } + + int buffLoc = this.m_curRecordEnd; + int numBytes; + fixed (char* pVal = charBuffer ) + { + numBytes = this.m_encoding.GetBytes(pVal + offset, + charCount, + this.m_curDataBlock + this.m_curRecordEnd, + this.m_curBlockSize - this.m_curRecordEnd); + } + + this.m_curRecordEnd += numBytes; + } + + public void WriteBytes(byte[] byteBuffer, int offset, int byteCount) + { + if (byteBuffer == null) + { + throw new ArgumentNullException("byteBuffer"); + } + if (offset < 0) + { + throw new ArgumentOutOfRangeException("offset"); + } + if (byteCount < 0) + { + throw new ArgumentOutOfRangeException("byteCount"); + } + if (byteBuffer.Length < (offset + byteCount)) + { + throw new ArgumentOutOfRangeException("byteBuffer", + String.Format(SR.ArrayLengthVsCountAndOffset, + "byteBuffer", offset + byteCount, + "offset", "byteCount")); + } + while (this.m_curBlockSize - this.m_curRecordEnd < byteCount) + { + this.FlushDataBlock(); + } + fixed (byte* pBytes = byteBuffer) + { + HpcLinqUtil.memcpy(pBytes + offset, this.m_curDataBlock + this.m_curRecordEnd, byteCount); + } + this.m_curRecordEnd += byteCount; + } + + /// + /// Public helper to write from a caller provided byte* to the output stream. 
        /// This is mainly used to write preallocated fixed size, non-integer types (Guid, decimal etc).
        /// 
        public void WriteRawBytes(byte* pBytes, Int32 numBytes)
        {
            // Flush/grow until the current block has room for the whole value, so that the
            // bytes are never split across blocks.
            while (this.m_curBlockSize - this.m_curRecordEnd < numBytes)
            {
                this.FlushDataBlock();
            }
            // memcpy takes (source, destination, count), as the surrounding calls show.
            HpcLinqUtil.memcpy(pBytes, this.m_curDataBlock + this.m_curRecordEnd, numBytes);
            this.m_curRecordEnd += numBytes;
        }
    }
}

namespace Microsoft.Research.DryadLinq.Internal
{
    // internal adapter class to make a HpcBinaryWriter work as a Stream
    // this is needed to reuse Stream-based serialization code.
    // The stream is write-only and non-seekable: Read, Seek, SetLength and setting
    // Position all throw DryadLinqException.
    internal class HpcBinaryWriterToStreamAdapter : Stream
    {
        private HpcBinaryWriter m_dbw;      // the wrapped writer; all writes are forwarded to it

        internal HpcBinaryWriterToStreamAdapter(HpcBinaryWriter dbw)
        {
            m_dbw = dbw;
        }

        public override bool CanRead
        {
            get { return false; }
        }

        public override bool CanSeek
        {
            get { return false; }
        }

        public override bool CanWrite
        {
            get { return true; }
        }

        // Flushes the wrapped writer.
        public override void Flush()
        {
            m_dbw.Flush();
        }

        // Number of bytes written through the wrapped writer so far.
        public override long Length
        {
            get { return m_dbw.Length; }
        }

        // The position is always the end of the data written so far; it cannot be set.
        public override long Position
        {
            get { return m_dbw.Length; }
            set { throw new DryadLinqException(HpcLinqErrorCode.SettingPositionNotSupported,
                                               SR.SettingPositionNotSupported); }
        }

        public override int Read(byte[] buffer, int offset, int count)
        {
            throw new DryadLinqException(HpcLinqErrorCode.ReadNotAllowed, SR.ReadNotAllowed);
        }

        public override long Seek(long offset, SeekOrigin origin)
        {
            throw new DryadLinqException(HpcLinqErrorCode.SeekNotSupported, SR.SeekNotSupported);
        }

        public override void SetLength(long value)
        {
            throw new DryadLinqException(HpcLinqErrorCode.SetLengthNotSupported,
                                         SR.SetLengthNotSupported);
        }

        public override void Write(byte[] buffer, int offset, int count)
        {
            m_dbw.WriteBytes(buffer, offset, count);
        }

        public override void WriteByte(byte value)
        {
            m_dbw.Write(value);
        }

        // Closes the wrapped writer, then disposes the Stream base even if closing throws.
        public override void Close()
        {
            try
            {
                m_dbw.Close();
            }
            finally
            {
                base.Dispose(true);
            }
        }
    }
}
diff --git a/LinqToDryad/DryadCodeGen.cs b/LinqToDryad/DryadCodeGen.cs new file mode 100644 index 0000000..4a962cb --- /dev/null +++ b/LinqToDryad/DryadCodeGen.cs @@ -0,0 +1,2376 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. +// +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Text; +using System.IO; +using System.Reflection; +using System.Threading; +using System.Data.SqlTypes; +using System.Linq; +using System.Linq.Expressions; +using System.CodeDom; +using System.CodeDom.Compiler; +using Microsoft.CSharp; +using System.Diagnostics; +using Microsoft.Research.DryadLinq; + +//DryadCodeGen is public-internal so that DryadCodeGen.GetFactory is available. +namespace Microsoft.Research.DryadLinq.Internal +{ + using CodeStmtPair = Pair; + + // This class generates and compiles the managed code executed by HpcLinq. + // It creates a managed library (DLL) that gets used by HpcLinq vertices. + // + // For each type of HpcLinq vertex node, we need to call + // AddDryadCodeForType(node.OutputType); + // AddVertexMethod(node); + // This should generate all the code described in the note. 
+ public class HpcLinqCodeGen + { + private const BindingFlags FieldFlags = BindingFlags.Instance|BindingFlags.Public|BindingFlags.NonPublic; + private const string TargetNamespace = "Microsoft.Research.DryadLinq"; + private const string ExtensionClassName = "HpcLinq__Extension"; + private const string VertexClassName = "HpcLinq__Vertex"; + private const string TargetDllName = "Microsoft.Research.DryadLinq_.dll"; + private const string VertexSourceFile = "Microsoft.Research.DryadLinq_.cs"; + private const string DummyExtensionSourceFile = "HpcLinqDummyExtension.cs"; + private const string DryadVertexParamName = "dvertexparam"; + private const string DebugHelperMethodName = "CheckVertexDebugRequest"; + private const string CopyResourcesMethodName = "CopyResources"; + internal const string DryadEnvName = "denv"; + + private static int s_uniqueId = -1; + private static int s_DryadLinqDllVersion = -1; + private static Dictionary s_TypeToInternalName; + private static Dictionary s_BuiltinTypeToReaderName; + private static Dictionary s_BuiltinTypeToSerializerName; + private static Dictionary s_TypeToFactory; + private static object s_codeGenLock = new Object(); + + internal readonly static CodeExpression ZeroExpr = new CodePrimitiveExpression(0); + internal readonly static CodeExpression OneExpr = new CodePrimitiveExpression(1); + internal readonly static CodeExpression NullExpr = new CodePrimitiveExpression(null); + internal readonly static CodeExpression DLVTypeExpr = new CodeTypeReferenceExpression("HpcLinqVertex"); + + static HpcLinqCodeGen() + { + // Initialize the mapping from type to its internal name + s_TypeToInternalName = new Dictionary(20); + s_TypeToInternalName.Add(typeof(byte), "Byte"); + s_TypeToInternalName.Add(typeof(sbyte), "SByte"); + s_TypeToInternalName.Add(typeof(bool), "Bool"); + s_TypeToInternalName.Add(typeof(char), "Char"); + s_TypeToInternalName.Add(typeof(short), "Short"); + s_TypeToInternalName.Add(typeof(ushort), "UShort"); + 
s_TypeToInternalName.Add(typeof(int), "Int32"); + s_TypeToInternalName.Add(typeof(uint), "UInt32"); + s_TypeToInternalName.Add(typeof(long), "Int64"); + s_TypeToInternalName.Add(typeof(ulong), "UInt64"); + s_TypeToInternalName.Add(typeof(float), "Float"); + s_TypeToInternalName.Add(typeof(decimal), "Decimal"); + s_TypeToInternalName.Add(typeof(double), "Double"); + s_TypeToInternalName.Add(typeof(DateTime), "DateTime"); + s_TypeToInternalName.Add(typeof(string), "String"); + s_TypeToInternalName.Add(typeof(LineRecord), "LineRecord"); + s_TypeToInternalName.Add(typeof(SqlDateTime), "SqlDateTime"); + s_TypeToInternalName.Add(typeof(Guid), "Guid"); + + // Initialize the mapping from builtin type to its read method name + s_BuiltinTypeToReaderName = new Dictionary(20); + s_BuiltinTypeToReaderName.Add(typeof(bool), "ReadBool"); + s_BuiltinTypeToReaderName.Add(typeof(char), "ReadChar"); + s_BuiltinTypeToReaderName.Add(typeof(sbyte),"ReadSByte"); + s_BuiltinTypeToReaderName.Add(typeof(byte), "ReadUByte"); + s_BuiltinTypeToReaderName.Add(typeof(short), "ReadInt16"); + s_BuiltinTypeToReaderName.Add(typeof(ushort), "ReadUInt16"); + s_BuiltinTypeToReaderName.Add(typeof(int), "ReadInt32"); + s_BuiltinTypeToReaderName.Add(typeof(uint), "ReadUInt32"); + s_BuiltinTypeToReaderName.Add(typeof(long), "ReadInt64"); + s_BuiltinTypeToReaderName.Add(typeof(ulong), "ReadUInt64"); + s_BuiltinTypeToReaderName.Add(typeof(float), "ReadSingle"); + s_BuiltinTypeToReaderName.Add(typeof(double), "ReadDouble"); + s_BuiltinTypeToReaderName.Add(typeof(decimal), "ReadDecimal"); + s_BuiltinTypeToReaderName.Add(typeof(DateTime), "ReadDateTime"); + s_BuiltinTypeToReaderName.Add(typeof(string), "ReadString"); + s_BuiltinTypeToReaderName.Add(typeof(SqlDateTime), "ReadSqlDateTime"); + s_BuiltinTypeToReaderName.Add(typeof(Guid), "ReadGuid"); + + // Initialize the mapping from builtin type to its serializer class name + s_BuiltinTypeToSerializerName = new Dictionary(20); + 
s_BuiltinTypeToSerializerName.Add(typeof(byte), "ByteHpcSerializer"); + s_BuiltinTypeToSerializerName.Add(typeof(sbyte), "SByteHpcSerializer"); + s_BuiltinTypeToSerializerName.Add(typeof(bool), "BoolHpcSerializer"); + s_BuiltinTypeToSerializerName.Add(typeof(char), "CharHpcSerializer"); + s_BuiltinTypeToSerializerName.Add(typeof(short), "Int16HpcSerializer"); + s_BuiltinTypeToSerializerName.Add(typeof(ushort), "UInt16HpcSerializer"); + s_BuiltinTypeToSerializerName.Add(typeof(int), "Int32HpcSerializer"); + s_BuiltinTypeToSerializerName.Add(typeof(uint), "UInt32HpcSerializer"); + s_BuiltinTypeToSerializerName.Add(typeof(long), "Int64HpcSerializer"); + s_BuiltinTypeToSerializerName.Add(typeof(ulong), "UInt64HpcSerializer"); + s_BuiltinTypeToSerializerName.Add(typeof(float), "SingleHpcSerializer"); + s_BuiltinTypeToSerializerName.Add(typeof(double), "DoubleHpcSerializer"); + s_BuiltinTypeToSerializerName.Add(typeof(decimal), "DecimalHpcSerializer"); + s_BuiltinTypeToSerializerName.Add(typeof(DateTime), "DateTimeHpcSerializer"); + s_BuiltinTypeToSerializerName.Add(typeof(string), "StringHpcSerializer"); + s_BuiltinTypeToSerializerName.Add(typeof(Guid), "GuidHpcSerializer"); + s_BuiltinTypeToSerializerName.Add(typeof(SqlDateTime), "SqlDateTimeHpcSerializer"); + + // Initialize the mapping from type to its factory + s_TypeToFactory = new Dictionary(20); + s_TypeToFactory.Add(typeof(byte), new HpcLinqFactoryByte()); + s_TypeToFactory.Add(typeof(sbyte), new HpcLinqFactorySByte()); + s_TypeToFactory.Add(typeof(bool), new HpcLinqFactoryBool()); + s_TypeToFactory.Add(typeof(char), new HpcLinqFactoryChar()); + s_TypeToFactory.Add(typeof(short), new HpcLinqFactoryShort()); + s_TypeToFactory.Add(typeof(ushort), new HpcLinqFactoryUShort()); + s_TypeToFactory.Add(typeof(int), new HpcLinqFactoryInt32()); + s_TypeToFactory.Add(typeof(uint), new HpcLinqFactoryUInt32()); + s_TypeToFactory.Add(typeof(long), new HpcLinqFactoryInt64()); + s_TypeToFactory.Add(typeof(ulong), new 
HpcLinqFactoryUInt64());
+            s_TypeToFactory.Add(typeof(float), new HpcLinqFactoryFloat());
+            s_TypeToFactory.Add(typeof(decimal), new HpcLinqFactoryDecimal());
+            s_TypeToFactory.Add(typeof(double), new HpcLinqFactoryDouble());
+            s_TypeToFactory.Add(typeof(DateTime), new HpcLinqFactoryDateTime());
+            s_TypeToFactory.Add(typeof(string), new HpcLinqFactoryString());
+            s_TypeToFactory.Add(typeof(LineRecord), new HpcLinqFactoryLineRecord());
+            s_TypeToFactory.Add(typeof(SqlDateTime), new HpcLinqFactorySqlDateTime());
+            s_TypeToFactory.Add(typeof(Guid), new HpcLinqFactoryGuid());
+        }
+
+        private string m_generatedVertexDllPath; // only set if vertex code and assembly were both created successfully
+        private Assembly m_loadedVertexAssembly; // only set if the caller requests a load (the only case is GetFactory, which is used by enumeration)
+
+        // CodeDOM skeleton of the generated code; all five objects are created in the constructor.
+        private CodeCompileUnit m_dryadLinqUnit;            // compile unit holding the single generated namespace
+        private CodeNamespace m_dryadCodeSpace;             // namespace named TargetNamespace; receives every generated type
+        private CodeTypeDeclaration m_dryadExtensionClass;  // public partial class named ExtensionClassName
+        private CodeTypeConstructor m_extensionStaticCtor;  // static constructor of the extension class
+        private CodeTypeDeclaration m_dryadVertexClass;     // public sealed class named VertexClassName
+        private HashSet m_dryadDataTypes;                   // channel datatypes for which code already exists
+        private HashSet m_serializationDatatypes;           // datatypes that already have serialization methods
+        private Dictionary m_fieldToStaticName;             // FieldInfo -> unique name used for its get_/set_ delegate fields
+        private HashSet m_staticFieldDefined;               // names of static fields already added to the extension class
+        private Dictionary m_typeToSerializerName;          // presumably type -> serializer class name; not used in the code visible here
+        private Dictionary m_anonymousTypeToName;           // anonymous type -> name of its generated replacement class
+        private Dictionary m_nameToAlias;                   // type-name prefix -> alias created by MakeTypeNameAlias
+        private HpcLinqContext m_context;
+        private VertexCodeGen m_vertexCodeGen;
+
+        // Builds the empty CodeDOM skeleton: a compile unit with one namespace (and its imports),
+        // the extension class with a static constructor, and the vertex class with its
+        // copy-resources method. Also seeds the set of channel datatypes with the built-in types
+        // and creates the (initially empty) bookkeeping tables.
+        internal HpcLinqCodeGen(HpcLinqContext context, VertexCodeGen vertexCodeGen)
+        {
+            this.m_context = context;
+            this.m_vertexCodeGen = vertexCodeGen;
+            this.m_loadedVertexAssembly = null;
+            this.m_dryadLinqUnit = new CodeCompileUnit();
+
+            // Create a namespace
+            this.m_dryadCodeSpace = new CodeNamespace(TargetNamespace);
+            this.m_dryadCodeSpace.Imports.Add(new CodeNamespaceImport("System"));
+            this.m_dryadCodeSpace.Imports.Add(new CodeNamespaceImport("System.Collections"));
+            this.m_dryadCodeSpace.Imports.Add(new CodeNamespaceImport("System.Collections.Generic"));
+            this.m_dryadCodeSpace.Imports.Add(new CodeNamespaceImport("System.Text"));
+            this.m_dryadCodeSpace.Imports.Add(new CodeNamespaceImport("System.Linq"));
+            this.m_dryadCodeSpace.Imports.Add(new CodeNamespaceImport("System.Linq.Expressions"));
+            this.m_dryadCodeSpace.Imports.Add(new CodeNamespaceImport("System.Diagnostics"));
+            this.m_dryadCodeSpace.Imports.Add(new CodeNamespaceImport("System.Runtime.Serialization"));
+            this.m_dryadCodeSpace.Imports.Add(new CodeNamespaceImport("System.Data.SqlTypes"));
+            this.m_dryadCodeSpace.Imports.Add(new CodeNamespaceImport("System.Data.Linq"));
+            this.m_dryadCodeSpace.Imports.Add(new CodeNamespaceImport("System.Data.Linq.Mapping"));
+            this.m_dryadCodeSpace.Imports.Add(new CodeNamespaceImport("Microsoft.Research.DryadLinq"));
+            this.m_dryadCodeSpace.Imports.Add(new CodeNamespaceImport("Microsoft.Research.DryadLinq.Internal"));
+
+            this.m_dryadLinqUnit.Namespaces.Add(this.m_dryadCodeSpace);
+
+            // Create the class for all the Dryad extension methods
+            this.m_dryadExtensionClass = new CodeTypeDeclaration(ExtensionClassName);
+            this.m_dryadExtensionClass.IsClass = true;
+            this.m_dryadExtensionClass.IsPartial = true;
+            this.m_dryadExtensionClass.TypeAttributes = TypeAttributes.Public;
+            this.m_dryadCodeSpace.Types.Add(this.m_dryadExtensionClass);
+
+            // Create the static constructor for the vertex extension class
+            this.m_extensionStaticCtor = new CodeTypeConstructor();
+            this.m_dryadExtensionClass.Members.Add(this.m_extensionStaticCtor);
+
+            // Create the class for all the Dryad vertex methods
+            this.m_dryadVertexClass = new CodeTypeDeclaration(VertexClassName);
+            this.m_dryadVertexClass.IsClass = true;
+            this.m_dryadVertexClass.TypeAttributes = TypeAttributes.Public | TypeAttributes.Sealed;
+            this.m_dryadCodeSpace.Types.Add(this.m_dryadVertexClass);
+            // Must run after m_context and m_dryadVertexClass are set: the method reads both.
+            this.AddCopyResourcesMethod();
+
+            // The set of input/output channel datatypes
+            // (pre-populated with the built-in types so AddDryadCodeForType generates nothing for them)
+            this.m_dryadDataTypes = new HashSet();
+            this.m_dryadDataTypes.Add(typeof(byte));
+            this.m_dryadDataTypes.Add(typeof(sbyte));
+            this.m_dryadDataTypes.Add(typeof(bool));
+            this.m_dryadDataTypes.Add(typeof(char));
+            this.m_dryadDataTypes.Add(typeof(short));
+            this.m_dryadDataTypes.Add(typeof(ushort));
+            this.m_dryadDataTypes.Add(typeof(int));
+            this.m_dryadDataTypes.Add(typeof(uint));
+            this.m_dryadDataTypes.Add(typeof(long));
+            this.m_dryadDataTypes.Add(typeof(ulong));
+            this.m_dryadDataTypes.Add(typeof(float));
+            this.m_dryadDataTypes.Add(typeof(decimal));
+            this.m_dryadDataTypes.Add(typeof(double));
+            this.m_dryadDataTypes.Add(typeof(DateTime));
+            this.m_dryadDataTypes.Add(typeof(string));
+            this.m_dryadDataTypes.Add(typeof(LineRecord));
+            this.m_dryadDataTypes.Add(typeof(SqlDateTime));
+            this.m_dryadDataTypes.Add(typeof(Guid));
+
+            // The set of datatypes we have added serialization methods
+            this.m_serializationDatatypes = new HashSet();
+
+            this.m_fieldToStaticName = new Dictionary();
+            this.m_staticFieldDefined = new HashSet();
+            this.m_typeToSerializerName = new Dictionary();
+            this.m_anonymousTypeToName = new Dictionary();
+            this.m_nameToAlias = new Dictionary();
+        }
+
+        // Returns the reader method name registered for a built-in type in
+        // s_BuiltinTypeToReaderName, or null when the type has no entry.
+        private static string GetBuiltinReaderName(Type type)
+        {
+            string readerName = null;
+            s_BuiltinTypeToReaderName.TryGetValue(type, out readerName);
+            return readerName;
+        }
+
+        // The vertex code generator supplied to the constructor; may be replaced afterwards.
+        internal VertexCodeGen VertexCodeGen
+        {
+            get { return this.m_vertexCodeGen; }
+            set { this.m_vertexCodeGen = value; }
+        }
+
+        // Namespace-qualified name of the generated vertex class.
+        internal static string VertexClassFullName
+        {
+            get { return TargetNamespace + "."
+ VertexClassName; }
+        }
+
+        // Returns name + "__" + the next value of the shared counter s_uniqueId.
+        // The counter is advanced with Interlocked.Increment, so every call yields a distinct suffix.
+        internal static string MakeUniqueName(string name)
+        {
+            return name + "__" + Interlocked.Increment(ref s_uniqueId);
+        }
+
+        // Returns the internal name registered for a type in s_TypeToInternalName.
+        // A type seen for the first time is assigned a fresh "Type__N" name, which is
+        // recorded so that later calls return the same name.
+        internal static string MakeName(Type type)
+        {
+            string name;
+            if (!s_TypeToInternalName.TryGetValue(type, out name))
+            {
+                name = MakeUniqueName("Type");
+                s_TypeToInternalName.Add(type, name);
+            }
+            return name;
+        }
+
+        // Name of the generated record-reader class for a type.
+        internal static string DryadReaderClassName(Type type)
+        {
+            return "HpcRecordReader" + MakeName(type);
+        }
+
+        // Name of the generated record-writer class for a type.
+        internal static string DryadWriterClassName(Type type)
+        {
+            return "HpcRecordWriter" + MakeName(type);
+        }
+
+        // Name of the generated factory class for a type.
+        internal static string HpcLinqFactoryClassName(Type type)
+        {
+            return "HpcLinqFactory" + MakeName(type);
+        }
+
+        // Name of the generated class that stands in for an anonymous type.
+        internal static string AnonymousClassName(Type type)
+        {
+            return "Anonymous" + MakeName(type);
+        }
+
+        // Returns the serializer class name registered for a built-in type in
+        // s_BuiltinTypeToSerializerName, or null when the type has no entry.
+        internal static string GetBuiltInHpcSerializer(Type type)
+        {
+            string serializerName;
+            return s_BuiltinTypeToSerializerName.TryGetValue(type, out serializerName) ? serializerName : null;
+        }
+
+        // Map from anonymous types to the names of their generated replacement classes.
+        internal Dictionary AnonymousTypeToName
+        {
+            get { return this.m_anonymousTypeToName; }
+        }
+
+        // Converts long type names into an alias in order to make vertex code more readable.
+        //
+        // Names of at most 60 characters are returned unchanged. For longer names, any trailing
+        // "[]" pairs are split off, the remaining prefix is mapped to an alias (reusing the alias
+        // recorded in m_nameToAlias when the prefix was seen before), an "alias = prefix" import
+        // is added to the generated namespace, and the alias followed by the split-off "[]"
+        // suffix is returned.
+        private string MakeTypeNameAlias(string fullProcessedTypeName)
+        {
+            // no change necessary if the full type name is short enough
+            if (fullProcessedTypeName.Length <= 60)
+            {
+                return fullProcessedTypeName;
+            }
+
+            // Strip trailing "[]" pairs so that T, T[] and T[][] all share one alias for T.
+            int aliasLen = fullProcessedTypeName.Length;
+            while (aliasLen > 2 &&
+                   fullProcessedTypeName[aliasLen - 1] == ']' &&
+                   fullProcessedTypeName[aliasLen - 2] == '[')
+            {
+                aliasLen -= 2;
+            }
+            string typeNamePrefix = fullProcessedTypeName.Substring(0, aliasLen);
+            string arraySuffix = fullProcessedTypeName.Substring(aliasLen);
+
+            string typeNameAlias;
+            if (!this.m_nameToAlias.TryGetValue(typeNamePrefix, out typeNameAlias))
+            {
+                typeNameAlias = MakeUniqueName("Alias");
+                this.m_nameToAlias[typeNamePrefix] = typeNameAlias;
+            }
+            // As before, the import is (re-)added on every call, including for a reused alias.
+            this.m_dryadCodeSpace.Imports.Add(new CodeNamespaceImport(typeNameAlias + " = " + typeNamePrefix));
+
+            // Fix: the "[]" suffix used to be computed and then discarded, so the alias of the
+            // element type was returned in place of the array type.
+            return typeNameAlias + arraySuffix;
+        }
+
+        // Returns the qualified name of the static field on the extension class that holds the
+        // factory instance for a type. The field, initialized with a new instance of the factory
+        // class, is added the first time the type is requested.
+        internal string GetStaticFactoryName(Type type)
+        {
+            string fieldName = "Factory" + MakeName(type);
+            if (!this.m_staticFieldDefined.Contains(fieldName))
+            {
+                this.m_staticFieldDefined.Add(fieldName);
+                string factoryName = HpcLinqFactoryClassName(type);
+                CodeMemberField factoryField = new CodeMemberField(factoryName, fieldName);
+                factoryField.Attributes = MemberAttributes.Assembly | MemberAttributes.Static;
+                factoryField.InitExpression = new CodeObjectCreateExpression(factoryName);
+                this.m_dryadExtensionClass.Members.Add(factoryField);
+            }
+            return ExtensionClassName + "." + fieldName;
+        }
+
+        // Returns the qualified name of the static field on the extension class that holds the
+        // serializer instance for a type. On first request the serializer class is generated via
+        // AddSerializerClass and a field initialized with a new instance of it is added.
+        internal string GetStaticSerializerName(Type type)
+        {
+            string fieldName = "Serializer" + MakeName(type);
+            if (!this.m_staticFieldDefined.Contains(fieldName))
+            {
+                this.m_staticFieldDefined.Add(fieldName);
+                string serializerName = this.AddSerializerClass(type);
+                CodeMemberField serializerField = new CodeMemberField(serializerName, fieldName);
+                serializerField.Attributes = MemberAttributes.Assembly | MemberAttributes.Static;
+                serializerField.InitExpression = new CodeObjectCreateExpression(serializerName);
+                this.m_dryadExtensionClass.Members.Add(serializerField);
+            }
+            return ExtensionClassName + "."
+ fieldName;
+        }
+
+        // Name of the static getter-delegate field for a field; the delegate fields are
+        // generated first if they do not exist yet.
+        private string GetterFieldName(FieldInfo finfo)
+        {
+            this.AddFieldAccessDelegates(finfo);
+            return "get_" + this.m_fieldToStaticName[finfo];
+        }
+
+        // Name of the static setter-delegate field for a field; the delegate fields are
+        // generated first if they do not exist yet.
+        private string SetterFieldName(FieldInfo finfo)
+        {
+            this.AddFieldAccessDelegates(finfo);
+            return "set_" + this.m_fieldToStaticName[finfo];
+        }
+
+        // On the first call for a field, adds two static fields to the extension class: a getter
+        // delegate and a setter delegate. Each is initialized by calling the matching generic
+        // CodeGenHelper method (Get/SetStructField for fields of value types, Get/SetObjField
+        // otherwise) with the field's name. Later calls for the same field do nothing.
+        private void AddFieldAccessDelegates(FieldInfo finfo)
+        {
+            if (!this.m_fieldToStaticName.ContainsKey(finfo))
+            {
+                string fieldName = HpcLinqUtil.MakeValidId(TypeSystem.FieldName(finfo.Name));
+                // Registered before GetterFieldName/SetterFieldName are called below, so their
+                // nested calls back into this method find the entry and return immediately.
+                this.m_fieldToStaticName[finfo] = MakeUniqueName(fieldName);
+                CodeTypeReferenceExpression typeExpr = new CodeTypeReferenceExpression("CodeGenHelper");
+                CodeTypeReference cRef = new CodeTypeReference(finfo.DeclaringType);
+                CodeTypeReference fRef = new CodeTypeReference(finfo.FieldType);
+                string getterName, setterName;
+                Type getterType, setterType;
+
+                if (finfo.DeclaringType.IsValueType)
+                {
+                    getterName = "GetStructField";
+                    setterName = "SetStructField";
+                    getterType = typeof(GetStructFieldDelegate<,>);
+                    setterType = typeof(SetStructFieldDelegate<,>);
+                }
+                else
+                {
+                    getterName = "GetObjField";
+                    setterName = "SetObjField";
+                    getterType = typeof(GetObjFieldDelegate<,>);
+                    setterType = typeof(SetObjFieldDelegate<,>);
+                }
+                // Close the open delegate types over (declaring type, field type).
+                getterType = getterType.MakeGenericType(finfo.DeclaringType, finfo.FieldType);
+                setterType = setterType.MakeGenericType(finfo.DeclaringType, finfo.FieldType);
+
+                CodeMemberField getField = new CodeMemberField(getterType, this.GetterFieldName(finfo));
+                getField.Attributes = MemberAttributes.Assembly | MemberAttributes.Static;
+                getField.InitExpression = new CodeMethodInvokeExpression(
+                                                  new CodeMethodReferenceExpression(typeExpr, getterName, cRef, fRef),
+                                                  new CodePrimitiveExpression(finfo.Name));
+                this.m_dryadExtensionClass.Members.Add(getField);
+
+                CodeMemberField setField = new CodeMemberField(setterType, this.SetterFieldName(finfo));
+                setField.Attributes = MemberAttributes.Assembly | MemberAttributes.Static;
+                setField.InitExpression = new CodeMethodInvokeExpression(
+                                                  new CodeMethodReferenceExpression(typeExpr, setterName, cRef, fRef),
+                                                  new CodePrimitiveExpression(finfo.Name));
+                this.m_dryadExtensionClass.Members.Add(setField);
+            }
+        }
+
+        // Copy user resources to the vertex working directory
+        // Adds a public static method named CopyResourcesMethodName to the vertex class. Its body
+        // holds one System.IO.File.Copy call per entry of Configuration.ResourcesToAdd that is
+        // not also in Configuration.ResourcesToRemove, copying the file from the parent
+        // directory ("..") to the same file name in the current directory.
+        private void AddCopyResourcesMethod()
+        {
+            CodeMemberMethod copyResourcesMethod = new CodeMemberMethod();
+            copyResourcesMethod.Name = CopyResourcesMethodName;
+            copyResourcesMethod.Attributes = MemberAttributes.Public | MemberAttributes.Static;
+
+            IEnumerable resourcesToExclude = this.m_context.Configuration.ResourcesToRemove;
+            foreach (string res in this.m_context.Configuration.ResourcesToAdd)
+            {
+                if (!resourcesToExclude.Contains(res))
+                {
+                    string fname = Path.GetFileName(res);
+                    // Both paths are emitted as C# verbatim (@"...") string literals.
+                    string stmt = @"System.IO.File.Copy(@""" + Path.Combine("..", fname) + "\", @\"" + fname + "\")";
+                    CodeExpression stmtExpr = new CodeSnippetExpression(stmt);
+                    copyResourcesMethod.Statements.Add(new CodeExpressionStatement(stmtExpr));
+                }
+            }
+            this.m_dryadVertexClass.Members.Add(copyResourcesMethod);
+        }
+
+        // Adds a static field of the decomposer type (under a fresh unique name) to the extension
+        // class and appends two statements to the extension class's static constructor: one that
+        // assigns a new decomposer instance to the field, and one that calls its Initialize
+        // method with stateExpr. Returns the field name qualified with ExtensionClassName.
+        internal string AddDecompositionInitializer(Type decomposerType, Expression stateExpr)
+        {
+            string decomposerTypeName = TypeSystem.TypeName(decomposerType);
+            string decomposerFieldName = MakeUniqueName("decomposer");
+            CodeMemberField decomposerField = new CodeMemberField(decomposerTypeName, decomposerFieldName);
+            decomposerField.Attributes = MemberAttributes.Assembly | MemberAttributes.Static;
+            this.m_dryadExtensionClass.Members.Add(decomposerField);
+
+            CodeStatement initStmt1 = new CodeAssignStatement(
+                                              new CodeVariableReferenceExpression(decomposerFieldName),
+                                              new CodeObjectCreateExpression(decomposerTypeName));
+            this.m_extensionStaticCtor.Statements.Add(initStmt1);
+
+            MethodInfo initInfo = decomposerType.GetMethod("Initialize");
+            // The parameter expression carries the field's name, so the call built from it refers to the field.
+            ParameterExpression decomposer = Expression.Parameter(decomposerType, decomposerFieldName);
+            Expression initExpr = Expression.Call(decomposer,
initInfo, stateExpr);
+            CodeStatement initStmt2 = new CodeExpressionStatement(this.MakeExpression(initExpr));
+            this.m_extensionStaticCtor.Statements.Add(initStmt2);
+
+            return ExtensionClassName + "." + decomposerFieldName;
+        }
+
+        // Generates, once per type, all supporting classes for a channel datatype: the
+        // anonymous-type replacement (if applicable), serializer, reader, writer and factory.
+        // Types already in m_dryadDataTypes (which includes the built-in types) are skipped.
+        internal void AddDryadCodeForType(Type type)
+        {
+            if (!this.m_dryadDataTypes.Contains(type))
+            {
+                this.m_dryadDataTypes.Add(type);
+                this.AddAnonymousClass(type);
+                this.AddSerializerClass(type);
+                this.AddReaderClass(type);
+                this.AddWriterClass(type);
+                this.AddFactoryClass(type);
+            }
+        }
+
+        // Add a new HpcRecordReader subclass for a type
+        // The generated public sealed class derives from HpcRecordBinaryReader<type> and contains
+        // a constructor taking an HpcBinaryReader plus a ReadRecord override. Both members are
+        // emitted as source snippets.
+        internal void AddReaderClass(Type type)
+        {
+            Type baseClass = typeof(HpcRecordBinaryReader<>).MakeGenericType(type);
+            string baseClassName = TypeSystem.TypeName(baseClass, this.AnonymousTypeToName);
+            string className = DryadReaderClassName(type);
+            // The base class is folded into the declaration name as "Name : Base".
+            CodeTypeDeclaration readerClass = new CodeTypeDeclaration(className + " : " + baseClassName);
+            this.m_dryadCodeSpace.Types.Add(readerClass);
+            readerClass.IsClass = true;
+            readerClass.TypeAttributes = TypeAttributes.Public | TypeAttributes.Sealed;
+
+            // Add constructors:
+            string conString = " public " + className + "(HpcBinaryReader reader) : base(reader) { }";
+            CodeTypeMember con = new CodeSnippetTypeMember(conString);
+            readerClass.Members.Add(con);
+
+            // Add method ReadRecord:
+            // The generated method returns false at end of stream; otherwise it reads one record and returns true.
+            string serializerName = GetStaticSerializerName(type);
+            string typeName = TypeSystem.TypeName(type, this.AnonymousTypeToName);
+            StringBuilder methodBuilder = new StringBuilder();
+            methodBuilder.AppendLine(" protected override bool ReadRecord(ref " + typeName + " rec)");
+            methodBuilder.AppendLine(" {");
+            methodBuilder.AppendLine(" if (!this.IsReaderAtEndOfStream())");
+            methodBuilder.AppendLine(" {");
+            if (AttributeSystem.RecordCanBeNull(this.m_context, type))
+            {
+                // Nullable records are preceded by a bool flag; the record body is read only when the flag is false.
+                // NOTE(review): when the flag is true the generated code leaves 'rec' as passed in
+                // rather than assigning null - confirm that callers reset it.
+                methodBuilder.AppendLine(" if (!this.m_reader.ReadBool())");
+                methodBuilder.AppendLine(" {");
+                methodBuilder.AppendLine(" rec = " + serializerName + ".Read(this.m_reader);");
+                methodBuilder.AppendLine(" }");
+            }
+            else
+            {
+                methodBuilder.AppendLine(" rec = " + serializerName + ".Read(this.m_reader);");
+            }
+            methodBuilder.AppendLine(" return true;");
+            methodBuilder.AppendLine(" }");
+            methodBuilder.AppendLine(" return false;");
+            methodBuilder.AppendLine(" }");
+            CodeTypeMember readRecordMethod = new CodeSnippetTypeMember(methodBuilder.ToString());
+            readerClass.Members.Add(readRecordMethod);
+        }
+
+        // Add a new HpcRecordWriter subclass for a type
+        // The generated public sealed class derives from HpcRecordBinaryWriter<type> and contains
+        // a constructor taking an HpcBinaryWriter plus a WriteRecord override. Both members are
+        // emitted as source snippets.
+        internal void AddWriterClass(Type type)
+        {
+            Type baseClass = typeof(HpcRecordBinaryWriter<>).MakeGenericType(type);
+            string baseClassName = TypeSystem.TypeName(baseClass, this.AnonymousTypeToName);
+            string className = DryadWriterClassName(type);
+            CodeTypeDeclaration writerClass = new CodeTypeDeclaration(className + " : " + baseClassName);
+            this.m_dryadCodeSpace.Types.Add(writerClass);
+            writerClass.IsClass = true;
+            writerClass.TypeAttributes = TypeAttributes.Public | TypeAttributes.Sealed;
+
+            // Add constructors:
+            string conString = " public " + className + "(HpcBinaryWriter writer) : base(writer) { }";
+            CodeTypeMember con = new CodeSnippetTypeMember(conString);
+            writerClass.Members.Add(con);
+
+            // Add method WriteRecord:
+            string serializerName = GetStaticSerializerName(type);
+            string typeName = TypeSystem.TypeName(type, this.AnonymousTypeToName);
+            StringBuilder methodBuilder = new StringBuilder();
+            methodBuilder.AppendLine(" protected override void WriteRecord(" + typeName + " rec)");
+            methodBuilder.AppendLine(" {");
+            if (AttributeSystem.RecordCanBeNull(m_context, type))
+            {
+                // Nullable records: write an is-null flag first, then the record body only when it is non-null.
+                methodBuilder.AppendLine(" bool isNull = Object.ReferenceEquals(rec, null);");
+                methodBuilder.AppendLine(" this.m_writer.Write(isNull);");
+                methodBuilder.AppendLine(" if (!isNull)");
+                methodBuilder.AppendLine(" {");
+                methodBuilder.AppendLine(" " + serializerName + ".Write(this.m_writer, rec);");
+                methodBuilder.AppendLine(" }");
+            }
+            else
+            {
+                methodBuilder.AppendLine(" " + serializerName + ".Write(this.m_writer, rec);");
+            }
+            methodBuilder.AppendLine(" this.CompleteWriteRecord();");
+            methodBuilder.AppendLine(" }");
+            CodeTypeMember writeRecordMethod = new CodeSnippetTypeMember(methodBuilder.ToString());
+            writerClass.Members.Add(writeRecordMethod);
+        }
+
+        // Add a new HpcLinqFactory subclass for a type
+        // The generated public class derives from HpcLinqFactory<type> and overrides two
+        // MakeReader and two MakeWriter overloads. Each override wraps the binary reader/writer
+        // obtained from HpcLinqVertexEnv in the generated reader/writer class for the type.
+        internal void AddFactoryClass(Type type)
+        {
+            Type baseClass = typeof(HpcLinqFactory<>).MakeGenericType(type);
+            string baseClassName = TypeSystem.TypeName(baseClass, this.AnonymousTypeToName);
+            CodeTypeDeclaration factoryClass = new CodeTypeDeclaration(HpcLinqFactoryClassName(type) + " : " + baseClassName);
+            this.m_dryadCodeSpace.Types.Add(factoryClass);
+            factoryClass.IsClass = true;
+            factoryClass.TypeAttributes = TypeAttributes.Public;
+
+            // Add method MakeReader(IntPtr handle, UInt32 port):
+            Type returnType = typeof(HpcRecordReader<>).MakeGenericType(type);
+            string returnTypeName = TypeSystem.TypeName(returnType, this.AnonymousTypeToName);
+            StringBuilder mb1 = new StringBuilder();
+            mb1.AppendLine(" public override " + returnTypeName + " MakeReader(System.IntPtr handle, uint port)");
+            mb1.AppendLine(" {");
+            mb1.AppendLine(" return new " + DryadReaderClassName(type) + "(Microsoft.Research.DryadLinq.Internal.HpcLinqVertexEnv.MakeBinaryReader(handle, port));");
+            mb1.AppendLine(" }");
+            CodeTypeMember readerMethod1 = new CodeSnippetTypeMember(mb1.ToString());
+            factoryClass.Members.Add(readerMethod1);
+
+            // Add method MakeReader(NativeBlockStream stream):
+            StringBuilder mb3 = new StringBuilder();
+            mb3.AppendLine(" public override " + returnTypeName + " MakeReader(Microsoft.Research.DryadLinq.Internal.NativeBlockStream stream)");
+            mb3.AppendLine(" {");
+            mb3.AppendLine(" return new " + DryadReaderClassName(type) + "(Microsoft.Research.DryadLinq.Internal.HpcLinqVertexEnv.MakeBinaryReader(stream));");
+            mb3.AppendLine(" }");
+            CodeTypeMember readerMethod3 = new CodeSnippetTypeMember(mb3.ToString());
+
factoryClass.Members.Add(readerMethod3);
+
+            // Add method MakeWriter(IntPtr handle, UInt32 port, Int32 buffSize):
+            returnType = typeof(HpcRecordWriter<>).MakeGenericType(type);
+            returnTypeName = TypeSystem.TypeName(returnType, this.AnonymousTypeToName);
+            StringBuilder mb4 = new StringBuilder();
+            mb4.AppendLine(" public override " + returnTypeName + " MakeWriter(System.IntPtr handle, uint port, int buffSize)");
+            mb4.AppendLine(" {");
+            mb4.AppendLine(" return new " + DryadWriterClassName(type) + "(Microsoft.Research.DryadLinq.Internal.HpcLinqVertexEnv.MakeBinaryWriter(handle, port, buffSize));");
+            mb4.AppendLine(" }");
+            CodeTypeMember writerMethod1 = new CodeSnippetTypeMember(mb4.ToString());
+            factoryClass.Members.Add(writerMethod1);
+
+            // Add method MakeWriter(NativeBlockStream stream):
+            StringBuilder mb6 = new StringBuilder();
+            mb6.AppendLine(" public override " + returnTypeName + " MakeWriter(Microsoft.Research.DryadLinq.Internal.NativeBlockStream stream)");
+            mb6.AppendLine(" {");
+            mb6.AppendLine(" return new " + DryadWriterClassName(type) + "(Microsoft.Research.DryadLinq.Internal.HpcLinqVertexEnv.MakeBinaryWriter(stream));");
+            mb6.AppendLine(" }");
+            CodeTypeMember writerMethod3 = new CodeSnippetTypeMember(mb6.ToString());
+            factoryClass.Members.Add(writerMethod3);
+        }
+
+        // Add an anonymous class
+        // Returns false when 'type' is not an anonymous type, and true when a replacement class
+        // for it exists, either already registered or generated by this call. The generated
+        // public class has, per property of the anonymous type (ordered by metadata token):
+        // a public field "_Name", a constructor parameter, and a read-only property "Name".
+        // It also overrides Equals, GetHashCode and ToString. Anonymous property types are
+        // handled recursively and referenced through their own generated classes.
+        internal bool AddAnonymousClass(Type type)
+        {
+            if (!TypeSystem.IsAnonymousType(type)) return false;
+            if (this.m_anonymousTypeToName.ContainsKey(type)) return true;
+
+            string className = AnonymousClassName(type);
+            // Registered before the properties are visited, so a recursive request for the same type returns early.
+            this.m_anonymousTypeToName.Add(type, className);
+
+            CodeTypeDeclaration anonymousClass = new CodeTypeDeclaration(className);
+            anonymousClass.IsClass = true;
+            anonymousClass.TypeAttributes = TypeAttributes.Public;
+
+            // Add the fields, the constructor, and properties:
+            CodeConstructor con = new CodeConstructor();
+            con.Attributes = MemberAttributes.Public | MemberAttributes.Final;
+            PropertyInfo[] props = type.GetProperties();
+            // Sort by metadata token to obtain a deterministic member order.
+            System.Array.Sort(props, (x, y) => x.MetadataToken.CompareTo(y.MetadataToken));
+            string[] fieldNames = new string[props.Length];
+            for (int i = 0; i < props.Length; i++)
+            {
+                fieldNames[i] = "_" + props[i].Name;
+                CodeParameterDeclarationExpression paramExpr;
+                CodeMemberField memberField;
+                if (this.AddAnonymousClass(props[i].PropertyType))
+                {
+                    // Nested anonymous type: use the name of its generated class.
+                    string typeName = this.AnonymousTypeToName[props[i].PropertyType];
+                    memberField = new CodeMemberField(typeName, fieldNames[i]);
+                    paramExpr = new CodeParameterDeclarationExpression(typeName, props[i].Name);
+                }
+                else
+                {
+                    memberField = new CodeMemberField(props[i].PropertyType, fieldNames[i]);
+                    paramExpr = new CodeParameterDeclarationExpression(props[i].PropertyType, props[i].Name);
+                }
+                memberField.Attributes = MemberAttributes.Public;
+                anonymousClass.Members.Add(memberField);
+                con.Parameters.Add(paramExpr);
+                // Constructor body: this._Name = Name;
+                CodeExpression fieldExpr = new CodeFieldReferenceExpression(
+                                                   new CodeThisReferenceExpression(), fieldNames[i]);
+                con.Statements.Add(new CodeAssignStatement(
+                                           fieldExpr, new CodeVariableReferenceExpression(paramExpr.Name)));
+
+                // Getter-only property returning the backing field.
+                CodeMemberProperty p = new CodeMemberProperty();
+                p.Attributes = MemberAttributes.Public | MemberAttributes.Final;
+                p.Name = props[i].Name;
+                p.Type = paramExpr.Type;
+                p.GetStatements.Add(new CodeMethodReturnStatement(fieldExpr));
+                anonymousClass.Members.Add(p);
+            }
+            anonymousClass.Members.Add(con);
+
+            // Add Equals method:
+            // Generated body: cast obj with 'as', return false on null, otherwise AND together
+            // EqualityComparer<T>.Default.Equals over every property.
+            CodeMemberMethod equalsMethod = new CodeMemberMethod();
+            equalsMethod.Attributes = MemberAttributes.Public | MemberAttributes.Override;
+            equalsMethod.Name = "Equals";
+            equalsMethod.Parameters.Add(new CodeParameterDeclarationExpression("Object", "obj"));
+            equalsMethod.ReturnType = new CodeTypeReference(typeof(bool));
+
+            CodeExpression initExpr = new CodeSnippetExpression("obj as " + className);
+            equalsMethod.Statements.Add(
+                new CodeVariableDeclarationStatement(className, "myObj", initExpr));
+            CodeStatement ifStmt = new CodeConditionStatement(
+                                           new CodeSnippetExpression("myObj == null"),
+                                           new CodeMethodReturnStatement(new CodePrimitiveExpression(false)));
+            equalsMethod.Statements.Add(ifStmt);
+            string equalsCode = "";
+            for (int i = 0; i < props.Length; i++)
+            {
+                string fieldTypeName;
+                // we must use the proxy-type for anonymous-types.
+                if (m_anonymousTypeToName.ContainsKey(props[i].PropertyType))
+                {
+                    fieldTypeName = m_anonymousTypeToName[props[i].PropertyType];
+                }
+                else
+                {
+                    fieldTypeName = TypeSystem.TypeName(props[i].PropertyType);
+                }
+
+                if (i > 0) equalsCode += " && ";
+                equalsCode += String.Format("EqualityComparer<{0}>.Default.Equals(this.{1}, myObj.{1})",
+                                            fieldTypeName, props[i].Name);
+            }
+            CodeExpression returnExpr = new CodeSnippetExpression(equalsCode);
+            equalsMethod.Statements.Add(new CodeMethodReturnStatement(returnExpr));
+            anonymousClass.Members.Add(equalsMethod);
+
+            // Add GetHashCode method:
+            // Generated body: num = 0; then per property num = (-1521134295 * num) + hash(property); return num.
+            CodeMemberMethod getHashCodeMethod = new CodeMemberMethod();
+            getHashCodeMethod.Attributes = MemberAttributes.Public | MemberAttributes.Override;
+            getHashCodeMethod.Name = "GetHashCode";
+            getHashCodeMethod.ReturnType = new CodeTypeReference(typeof(int));
+
+            CodeVariableDeclarationStatement
+                hashDecl = new CodeVariableDeclarationStatement(typeof(int), "num", ZeroExpr);
+            getHashCodeMethod.Statements.Add(hashDecl);
+
+            CodeExpression numExpr = new CodeArgumentReferenceExpression(hashDecl.Name);
+            for (int i = 0; i < props.Length; i++)
+            {
+                if (props[i].PropertyType.IsValueType)
+                {
+                    CodeExpression hashExpr = new CodeSnippetExpression(
+                                                      "(-1521134295 * num) + this." + props[i].Name + ".GetHashCode()");
+                    getHashCodeMethod.Statements.Add(new CodeAssignStatement(numExpr, hashExpr));
+                }
+                else
+                {
+                    // Reference-typed properties contribute 0 when null.
+                    CodeExpression hashExpr = new CodeSnippetExpression(
+                                                      String.Format("(-1521134295 * num) + (this.{0} != null ? this.{0}.GetHashCode() : 0)",
+                                                                    props[i].Name));
+                    getHashCodeMethod.Statements.Add(new CodeAssignStatement(numExpr, hashExpr));
+                }
+            }
+            getHashCodeMethod.Statements.Add(new CodeMethodReturnStatement(numExpr));
+            anonymousClass.Members.Add(getHashCodeMethod);
+
+            // Add ToString method:
+            // Generated expression: "{ " + "A = " + this.A.ToString() + ", " + "B = " + this.B.ToString() + " }"
+            // NOTE(review): unlike GetHashCode, the generated ToString calls .ToString() on each
+            // property without a null check.
+            CodeMemberMethod toStringMethod = new CodeMemberMethod();
+            toStringMethod.Attributes = MemberAttributes.Public | MemberAttributes.Override;
+            toStringMethod.Name = "ToString";
+            toStringMethod.ReturnType = new CodeTypeReference(typeof(string));
+            StringBuilder toStringCode = new StringBuilder();
+            toStringCode.Append("\"{ \"");
+            for (int i = 0; i < props.Length; i++)
+            {
+                if (i > 0) toStringCode.Append(" + \", \"");
+                toStringCode.Append(" + \"");
+                toStringCode.Append(props[i].Name);
+                toStringCode.Append(" = \" + this.");
+                toStringCode.Append(props[i].Name);
+                toStringCode.Append(".ToString()");
+            }
+            toStringCode.Append(" + \" }\"");
+            returnExpr = new CodeSnippetExpression(toStringCode.ToString());
+            toStringMethod.Statements.Add(new CodeMethodReturnStatement(returnExpr));
+            anonymousClass.Members.Add(toStringMethod);
+
+            this.m_dryadCodeSpace.Types.Add(anonymousClass);
+            return true;
+        }
+
+        // Builds the statement list for a generated method that deserializes one value of
+        // 'type' from an argument named "reader" into a local named "obj" and returns it.
+        // Arrays: one length per dimension is read, the array is allocated, and the elements
+        // are read either as one raw byte block (primitive element types) or element by element.
+        // Other types: the object is created and each serialized field is read in
+        // metadata-token order.
+        // Throws DryadLinqException for an array whose element type is object, or for a
+        // serialized field of type object.
+        private CodeStatement[] MakeReadMethodBody(Type type)
+        {
+            List statements = new List();
+            CodeExpression objExpr = new CodeArgumentReferenceExpression("obj");
+            if (type.IsArray)
+            {
+                if (type.GetElementType() == typeof(object))
+                {
+                    throw new DryadLinqException(HpcLinqErrorCode.CannotHandleObjectFields,
+                                                 String.Format(SR.CannotHandleObjectFields, type.FullName));
+                }
+
+                // Generate obj = new MyType[reader.ReadInt32()]
+                int rank = type.GetArrayRank();
+                // Innermost non-array element type; used as the type name in the 'new' expression.
+                Type baseType = type.GetElementType();
+                while (baseType.IsArray)
+                {
+                    baseType = baseType.GetElementType();
+                }
+                // One uniquely named int local per dimension, each initialized from reader.ReadInt32().
+                string[] lenNames = new string[rank];
+                for (int i = 0; i < rank; i++)
+                {
+                    lenNames[i] = MakeUniqueName("len");
+                    CodeExpression lenExpr = new CodeSnippetExpression("reader.ReadInt32()");
+                    var lenStmt = new
CodeVariableDeclarationStatement(typeof(int), lenNames[i], lenExpr);
+                    statements.Add(lenStmt);
+                }
+                // Build "new Base[len0,len1,...]" for the outermost dimensions ...
+                string newCallString = "new " + TypeSystem.TypeName(baseType, this.AnonymousTypeToName);
+                newCallString += "[";
+                for (int i = 0; i < rank; i++)
+                {
+                    if (i != 0)
+                    {
+                        newCallString += ",";
+                    }
+                    newCallString += lenNames[i];
+                }
+                newCallString += "]";
+
+                // ... followed by one empty rank specifier ("[]", "[,]", ...) per nested array level.
+                Type elemType = type.GetElementType();
+                while (elemType.IsArray)
+                {
+                    int elemRank = elemType.GetArrayRank();
+                    newCallString += "[";
+                    for (int i = 1; i < elemRank; i++)
+                    {
+                        newCallString += ',';
+                    }
+                    newCallString += "]";
+                    elemType = elemType.GetElementType();
+                }
+                CodeExpression newCall = new CodeSnippetExpression(newCallString);
+                statements.Add(new CodeVariableDeclarationStatement(type, "obj", newCall));
+
+                // Generate reading code
+                if (type.GetElementType().IsPrimitive)
+                {
+                    // Use a single ReadRawBytes for primitive array
+                    // Byte count = sizeof(element) * product of all dimension lengths.
+                    string lenStr = "sizeof(" + type.GetElementType() + ")";
+                    for (int i = 0; i < rank; i++)
+                    {
+                        lenStr += "*obj.GetLength(" + i + ")";
+                    }
+                    string readBytes = " unsafe { fixed (void *p = obj) reader.ReadRawBytes((byte*)p, " + lenStr + "); }";
+                    statements.Add(new CodeSnippetStatement(readBytes));
+                }
+                else
+                {
+                    // When null elements are allowed for a reference element type, a BitVector
+                    // named "bv" is read before the elements.
+                    if (StaticConfig.AllowNullArrayElements && !type.GetElementType().IsValueType)
+                    {
+                        CodeExpression bvReadExpr = new CodeSnippetExpression("BitVector.Read(reader)");
+                        CodeStatement stmt = new CodeVariableDeclarationStatement(typeof(BitVector), "bv", bvReadExpr);
+                        statements.Add(stmt);
+                    }
+
+                    // Loop variables i0, i1, ... - one per dimension.
+                    CodeVariableReferenceExpression[] indexExprs = new CodeVariableReferenceExpression[lenNames.Length];
+                    for (int i = 0; i < lenNames.Length; i++)
+                    {
+                        indexExprs[i] = new CodeVariableReferenceExpression("i" + i);
+                    }
+                    CodeStatement[] readStmts = this.MakeReadFieldStatements(type.GetElementType(), objExpr, null, indexExprs);
+                    // Wrap the element-read statements in nested for-loops, innermost (last dimension) first.
+                    for (int i = lenNames.Length - 1; i >= 0; i--)
+                    {
+                        CodeVariableDeclarationStatement
+                            initStmt = new CodeVariableDeclarationStatement(
+                                               typeof(int), indexExprs[i].VariableName, ZeroExpr);
+                        CodeExpression
+                            testExpr = new CodeBinaryOperatorExpression(indexExprs[i],
+                                                                        CodeBinaryOperatorType.LessThan,
+                                                                        new CodeVariableReferenceExpression(lenNames[i]));
+                        CodeStatement
+                            incStmt = new CodeAssignStatement(
+                                              indexExprs[i],
+                                              new CodeBinaryOperatorExpression(indexExprs[i],
+                                                                               CodeBinaryOperatorType.Add,
+                                                                               OneExpr));
+                        readStmts = new CodeStatement[] { new CodeIterationStatement(initStmt, testExpr, incStmt, readStmts) };
+                    }
+                    statements.AddRange(readStmts);
+                }
+            }
+            else
+            {
+                CodeExpression newObjectCall;
+                if (type.IsValueType)
+                {
+                    // default(type)
+                    newObjectCall = new CodeObjectCreateExpression(type);
+                }
+                else
+                {
+                    // FormatterServices.GetUninitializedObject(type)
+                    newObjectCall = new CodeMethodInvokeExpression(new CodeTypeReferenceExpression("FormatterServices"),
+                                                                   "GetUninitializedObject",
+                                                                   new CodeTypeOfExpression(type));
+                    newObjectCall = new CodeCastExpression(type, newObjectCall);
+                }
+                statements.Add(new CodeVariableDeclarationStatement(type, "obj", newObjectCall));
+
+                // For each field of type, generate its deserialization code.
+                FieldInfo[] fields = TypeSystem.GetAllFields(type);
+                // Same metadata-token ordering as in MakeWriteMethodBody, so both sides agree on field order.
+                System.Array.Sort(fields, (x, y) => x.MetadataToken.CompareTo(y.MetadataToken));
+
+                // A BitVector named "bv" is read first when at least one reference-typed field may be null.
+                bool canBeNull = fields.Any(x => !x.FieldType.IsValueType && AttributeSystem.FieldCanBeNull(x));
+                if (canBeNull)
+                {
+                    CodeExpression bvReadExpr = new CodeSnippetExpression("BitVector.Read(reader)");
+                    CodeStatement stmt = new CodeVariableDeclarationStatement(typeof(BitVector), "bv", bvReadExpr);
+                    statements.Add(stmt);
+                }
+                for (int i = 0; i < fields.Length; i++)
+                {
+                    FieldInfo finfo = fields[i];
+                    if (TypeSystem.IsFieldSerialized(finfo))
+                    {
+                        if (finfo.FieldType == typeof(object))
+                        {
+                            throw new DryadLinqException(HpcLinqErrorCode.CannotHandleObjectFields,
+                                                         String.Format(SR.CannotHandleObjectFields, type.FullName));
+                        }
+
+                        // For fields, the single "index" expression is the field's position written as a literal.
+                        CodeVariableReferenceExpression[]
+                            indexExprs = new CodeVariableReferenceExpression[] { new CodeVariableReferenceExpression(i.ToString()) };
+                        CodeStatement[] stmts = this.MakeReadFieldStatements(finfo.FieldType, objExpr, finfo, indexExprs);
+                        statements.AddRange(stmts);
+                    }
+                }
+            }
+            statements.Add(new CodeMethodReturnStatement(objExpr));
+            return statements.ToArray();
+        }
+
+        // Builds the statement list for a generated method that serializes the argument "obj"
+        // of 'type' to the argument "writer".
+        // Arrays: each dimension length is written first, then the elements, either as one raw
+        // byte block (primitive element types) or element by element.
+        // Other types: each serialized field is written in metadata-token order.
+        // Where nulls are possible, a BitVector "bv" is declared, the first statement group
+        // returned by MakeWriteFieldStatements (pair.Key) is emitted, the vector is written,
+        // and only then the second group (pair.Value).
+        // Throws DryadLinqException for an array whose element type is object, or for a
+        // serialized field of type object.
+        private CodeStatement[] MakeWriteMethodBody(Type type)
+        {
+            List statements = new List();
+            CodeExpression objExpr = new CodeArgumentReferenceExpression("obj");
+            CodeExpression writerExpr = new CodeArgumentReferenceExpression("writer");
+
+            if (type.IsArray)
+            {
+                if (type.GetElementType() == typeof(object))
+                {
+                    throw new DryadLinqException(HpcLinqErrorCode.CannotHandleObjectFields,
+                                                 String.Format(SR.CannotHandleObjectFields, type.FullName));
+                }
+
+                // Emit writer.Write(obj.GetLength(d)) for every dimension d.
+                int rank = type.GetArrayRank();
+                for (int i = 0; i < rank; i++)
+                {
+                    CodeExpression lenExpr = new CodeMethodInvokeExpression(objExpr, "GetLength", new CodePrimitiveExpression(i));
+                    CodeExpression lenCall = new CodeMethodInvokeExpression(writerExpr, "Write", lenExpr);
+                    statements.Add(new CodeExpressionStatement(lenCall));
+                }
+
+                // Generate the writing code
+                if
(type.GetElementType().IsPrimitive)
+                {
+                    // Use a single WriteRawBytes for primitive array
+                    // Byte count = sizeof(element) * product of all dimension lengths.
+                    string lenStr = "sizeof(" + type.GetElementType() + ")";
+                    for (int i = 0; i < rank; i++)
+                    {
+                        lenStr += "*obj.GetLength(" + i + ")";
+                    }
+                    string writeBytes = " unsafe { fixed (void *p = obj) writer.WriteRawBytes((byte*)p, " + lenStr + "); }";
+                    statements.Add(new CodeSnippetStatement(writeBytes));
+                }
+                else
+                {
+                    // Loop variables i0, i1, ... - one per dimension.
+                    CodeVariableReferenceExpression[] indexExprs = new CodeVariableReferenceExpression[rank];
+                    for (int i = 0; i < rank; i++)
+                    {
+                        indexExprs[i] = new CodeVariableReferenceExpression("i" + i);
+                    }
+                    bool canBeNull = StaticConfig.AllowNullArrayElements && !type.GetElementType().IsValueType;
+                    if (canBeNull)
+                    {
+                        // Declare BitVector bv sized to the total element count (product of all dimension lengths).
+                        string lenString = "obj.GetLength(0)";
+                        for (int i = 1; i < rank; i++)
+                        {
+                            lenString += "*obj.GetLength(" + i + ")";
+                        }
+                        CodeExpression lenExpr = new CodeSnippetExpression(lenString);
+                        CodeExpression bvExpr = new CodeObjectCreateExpression(typeof(BitVector), lenExpr);
+                        CodeStatement bvStmt = new CodeVariableDeclarationStatement("BitVector", "bv", bvExpr);
+                        statements.Add(bvStmt);
+                    }
+                    CodeStmtPair pair = this.MakeWriteFieldStatements(type.GetElementType(), objExpr, null, indexExprs);
+
+                    // First pass over all elements (pair.Key): emitted before the BitVector is written; may be absent.
+                    CodeStatement[] writeStmts = pair.Key;
+                    if (writeStmts != null)
+                    {
+                        // Wrap in nested for-loops, innermost (last dimension) first.
+                        for (int i = rank - 1; i >= 0; i--)
+                        {
+                            CodeVariableDeclarationStatement
+                                initStmt = new CodeVariableDeclarationStatement(
+                                                   typeof(int), indexExprs[i].VariableName, ZeroExpr);
+                            CodeExpression lenExpr = new CodeMethodInvokeExpression(
+                                                             objExpr, "GetLength", new CodePrimitiveExpression(i));
+                            CodeExpression testExpr = new CodeBinaryOperatorExpression(
+                                                              indexExprs[i],
+                                                              CodeBinaryOperatorType.LessThan,
+                                                              lenExpr);
+                            CodeStatement incStmt = new CodeAssignStatement(
+                                                            indexExprs[i],
+                                                            new CodeBinaryOperatorExpression(
+                                                                    indexExprs[i], CodeBinaryOperatorType.Add, OneExpr));
+                            writeStmts = new CodeStatement[] { new CodeIterationStatement(initStmt, testExpr, incStmt, writeStmts) };
+                        }
+                        statements.AddRange(writeStmts);
+                    }
+
+                    if (canBeNull)
+                    {
+                        CodeExpression bvWriteExpr = new CodeSnippetExpression("BitVector.Write(writer, bv)");
+                        statements.Add(new CodeExpressionStatement(bvWriteExpr));
+                    }
+
+                    // Second pass over all elements (pair.Value): emitted after the BitVector.
+                    // Unlike pair.Key it is used without a null check.
+                    writeStmts = pair.Value;
+                    for (int i = rank - 1; i >= 0; i--)
+                    {
+                        CodeVariableDeclarationStatement
+                            initStmt = new CodeVariableDeclarationStatement(
+                                               typeof(int), indexExprs[i].VariableName, ZeroExpr);
+                        CodeExpression lenExpr = new CodeMethodInvokeExpression(
+                                                         objExpr, "GetLength", new CodePrimitiveExpression(i));
+                        CodeExpression testExpr = new CodeBinaryOperatorExpression(
+                                                          indexExprs[i], CodeBinaryOperatorType.LessThan, lenExpr);
+                        CodeStatement incStmt = new CodeAssignStatement(
+                                                        indexExprs[i],
+                                                        new CodeBinaryOperatorExpression(
+                                                                indexExprs[i], CodeBinaryOperatorType.Add, OneExpr));
+                        writeStmts = new CodeStatement[] { new CodeIterationStatement(initStmt, testExpr, incStmt, writeStmts) };
+                    }
+                    statements.AddRange(writeStmts);
+                }
+            }
+            else
+            {
+                FieldInfo[] fields = TypeSystem.GetAllFields(type);
+                // Same metadata-token ordering as in MakeReadMethodBody, so both sides agree on field order.
+                System.Array.Sort(fields, (x, y) => x.MetadataToken.CompareTo(y.MetadataToken));
+
+                bool canBeNull = fields.Any(x => !x.FieldType.IsValueType && AttributeSystem.FieldCanBeNull(x));
+                if (canBeNull)
+                {
+                    // Declare BitVector bv with one slot per field.
+                    CodeExpression lenExpr = new CodePrimitiveExpression(fields.Length);
+                    CodeExpression bvExpr = new CodeObjectCreateExpression(typeof(BitVector), lenExpr);
+                    CodeStatement bvStmt = new CodeVariableDeclarationStatement("BitVector", "bv", bvExpr);
+                    statements.Add(bvStmt);
+                }
+
+                // For each field of type, generate its serialization code
+                // pair.Key statements are emitted immediately (before the BitVector is written);
+                // pair.Value statements are kept per field and emitted afterwards.
+                CodeStatement[][] stmtArray = new CodeStatement[fields.Length][];
+                for (int i = 0; i < fields.Length; i++)
+                {
+                    FieldInfo finfo = fields[i];
+                    if (TypeSystem.IsFieldSerialized(finfo))
+                    {
+                        if (finfo.FieldType == typeof(object))
+                        {
+                            throw new DryadLinqException(HpcLinqErrorCode.CannotHandleObjectFields,
+                                                         String.Format(SR.CannotHandleObjectFields, type.FullName));
+                        }
+
+                        // For fields, the single "index" expression is the field's position written as a literal.
+                        CodeVariableReferenceExpression[]
+                            indexExprs = new CodeVariableReferenceExpression[] { new CodeVariableReferenceExpression(i.ToString()) };
+                        CodeStmtPair pair = this.MakeWriteFieldStatements(finfo.FieldType, objExpr, finfo, indexExprs);
+                        stmtArray[i] = pair.Value;
+                        if (pair.Key != null)
+                        {
+                            statements.AddRange(pair.Key);
+                        }
+                    }
+                }
+                if (canBeNull)
+                {
+                    CodeExpression bvWriteExpr = new CodeSnippetExpression("BitVector.Write(writer, bv)");
+                    statements.Add(new CodeExpressionStatement(bvWriteExpr));
+                }
+                // Entries stay null for fields that are not serialized.
+                for (int i = 0; i < stmtArray.Length; i++)
+                {
+                    if (stmtArray[i] != null)
+                    {
+                        statements.AddRange(stmtArray[i]);
+                    }
+                }
+            }
+
+            return statements.ToArray();
+        }
+
+        private CodeStatement[]
+            MakeReadFieldStatements(Type type,
+                                    CodeExpression objExpr,
+                                    FieldInfo finfo,
+                                    CodeVariableReferenceExpression[] indexExprs)
+        {
+            CodeStatement[] stmts;
+            CodeExpression readerExpr = new CodeArgumentReferenceExpression("reader");
+            string readerName = GetBuiltinReaderName(type);
+            if (readerName == null)
+            {
+                // For non-builtin types
+                string serializerName = GetStaticSerializerName(type);
+                CodeVariableReferenceExpression serializerExpr = new CodeVariableReferenceExpression(serializerName);
+                CodeVariableDeclarationStatement tempDecl = null;
+                CodeExpression setterExpr = null;
+
+                CodeExpression fieldExpr;
+                if (finfo == null)
+                {
+                    fieldExpr = new CodeArrayIndexerExpression(objExpr, indexExprs);
+                }
+                else if (finfo.IsPublic && !finfo.IsInitOnly)
+                {
+                    fieldExpr = new CodeFieldReferenceExpression(objExpr, finfo.Name);
+                }
+                else
+                {
+                    string fieldName = TypeSystem.FieldName(finfo.Name);
+                    if (!TypeSystem.IsBackingField(finfo.Name) ||
+                        finfo.DeclaringType.GetProperty(fieldName, FieldFlags).GetSetMethod() == null)
+                    {
+                        setterExpr = new CodeVariableReferenceExpression(ExtensionClassName + "."
+ this.SetterFieldName(finfo)); + fieldName = this.m_fieldToStaticName[finfo]; + } + tempDecl = new CodeVariableDeclarationStatement(type, fieldName); + fieldExpr = new CodeVariableReferenceExpression(tempDecl.Name); + } + + CodeExpression fieldValExpr = new CodeMethodInvokeExpression(serializerExpr, "Read", readerExpr); + CodeStatement readCall = new CodeAssignStatement(fieldExpr, fieldValExpr); + if (tempDecl == null) + { + stmts = new CodeStatement[] { readCall }; + } + else + { + CodeStatement setCall; + if (setterExpr == null) + { + CodeExpression propExpr = new CodePropertyReferenceExpression(objExpr, tempDecl.Name); + setCall = new CodeAssignStatement(propExpr, fieldExpr); + } + else + { + if (finfo.DeclaringType.IsValueType) + { + objExpr = new CodeDirectionExpression(FieldDirection.Out, objExpr); + } + CodeExpression setExpr = new CodeDelegateInvokeExpression(setterExpr, objExpr, fieldExpr); + setCall = new CodeExpressionStatement(setExpr); + } + stmts = new CodeStatement[] { tempDecl, readCall, setCall }; + } + } + else + { + // for builtin types + CodeExpression readCall = new CodeMethodInvokeExpression(readerExpr, readerName); + if (finfo == null) + { + CodeExpression fieldExpr = new CodeArrayIndexerExpression(objExpr, indexExprs); + stmts = new CodeStatement[] { new CodeAssignStatement(fieldExpr, readCall) }; + } + else + { + string fieldName = TypeSystem.FieldName(finfo.Name); + if ((finfo.IsPublic && !finfo.IsInitOnly) || + (TypeSystem.IsBackingField(finfo.Name) && + finfo.DeclaringType.GetProperty(fieldName, FieldFlags).GetSetMethod() != null)) + { + CodeExpression fieldExpr = new CodeFieldReferenceExpression(objExpr, fieldName); + stmts = new CodeStatement[] { new CodeAssignStatement(fieldExpr, readCall) }; + } + else + { + CodeExpression setterExpr = new CodeVariableReferenceExpression( + ExtensionClassName + "." 
+ this.SetterFieldName(finfo)); + if (finfo.DeclaringType.IsValueType) + { + objExpr = new CodeDirectionExpression(FieldDirection.Out, objExpr); + } + CodeExpression setExpr = new CodeDelegateInvokeExpression(setterExpr, objExpr, readCall); + stmts = new CodeStatement[] { new CodeExpressionStatement(setExpr) }; + } + } + } + + if (!type.IsValueType && + (finfo != null || StaticConfig.AllowNullArrayElements) && + (finfo == null || AttributeSystem.FieldCanBeNull(finfo))) + { + CodeExpression bvIndex = indexExprs[0]; + if (finfo == null) + { + string bvIndexString = indexExprs[0].VariableName; + for (int i = 1; i < indexExprs.Length; i++) + { + bvIndexString += "*" + indexExprs[i].VariableName; + } + bvIndex = new CodeSnippetExpression(bvIndexString); + } + CodeExpression bvExpr = new CodeArgumentReferenceExpression("bv"); + CodeExpression ifExpr = new CodeBinaryOperatorExpression( + new CodeIndexerExpression(bvExpr, bvIndex), + CodeBinaryOperatorType.IdentityEquality, + new CodePrimitiveExpression(false)); + CodeStatement stmt = new CodeConditionStatement(ifExpr, stmts); + stmts = new CodeStatement[] { stmt }; + } + return stmts; + } + + private CodeStmtPair MakeWriteFieldStatements(Type type, + CodeExpression objExpr, + FieldInfo finfo, + CodeVariableReferenceExpression[] indexExprs) + { + CodeExpression writerExpr = new CodeArgumentReferenceExpression("writer"); + CodeExpression fieldExpr; + if (finfo == null) + { + fieldExpr = new CodeArrayIndexerExpression(objExpr, indexExprs); + } + else + { + string fieldName = TypeSystem.FieldName(finfo.Name); + if (finfo.IsPublic || + (TypeSystem.IsBackingField(finfo.Name) && + finfo.DeclaringType.GetProperty(fieldName, FieldFlags).GetGetMethod() != null)) + { + fieldExpr = new CodeFieldReferenceExpression(objExpr, fieldName); + } + else + { + CodeExpression getterExpr = new CodeVariableReferenceExpression( + ExtensionClassName + "." 
+ this.GetterFieldName(finfo)); + if (finfo.DeclaringType.IsValueType) + { + objExpr = new CodeDirectionExpression(FieldDirection.Out, objExpr); + } + fieldExpr = new CodeDelegateInvokeExpression(getterExpr, objExpr); + } + } + + CodeExpression writeCall; + if (GetBuiltinReaderName(type) == null) + { + // for non-builtin types + string serializerName = GetStaticSerializerName(type); + CodeVariableReferenceExpression serializerExpr = new CodeVariableReferenceExpression(serializerName); + writeCall = new CodeMethodInvokeExpression(serializerExpr, "Write", writerExpr, fieldExpr); + } + else + { + // for builtin types + writeCall = new CodeMethodInvokeExpression(writerExpr, "Write", fieldExpr); + } + CodeStatement stmt1 = new CodeExpressionStatement(writeCall); + + if (type.IsValueType) + { + return new CodeStmtPair(null, new CodeStatement[] { stmt1 }); + } + else if (finfo == null) + { + if (StaticConfig.AllowNullArrayElements) + { + string bvIndexString = indexExprs[0].VariableName; + for (int i = 1; i < indexExprs.Length; i++) + { + bvIndexString += "*" + indexExprs[i].VariableName; + } + CodeExpression bvIndex = new CodeSnippetExpression(bvIndexString); + CodeExpression nullExpr = new CodeMethodInvokeExpression( + new CodeTypeReferenceExpression("Object"), + "ReferenceEquals", + fieldExpr, + NullExpr); + CodeExpression bvExpr = new CodeArgumentReferenceExpression("bv"); + CodeStatement stmt0 = new CodeExpressionStatement( + new CodeMethodInvokeExpression(bvExpr, "Set", bvIndex)); + stmt0 = new CodeConditionStatement(nullExpr, stmt0); + + CodeExpression notNullExpr = new CodeBinaryOperatorExpression( + new CodeIndexerExpression(bvExpr, bvIndex), + CodeBinaryOperatorType.IdentityEquality, + new CodePrimitiveExpression(false)); + stmt1 = new CodeConditionStatement(notNullExpr, stmt1); + return new CodeStmtPair(new CodeStatement[] { stmt0 }, new CodeStatement[] { stmt1 }); + } + else + { + return new CodeStmtPair(null, new CodeStatement[] { stmt1 }); + } + } + else + { 
+ CodeExpression nullExpr = new CodeMethodInvokeExpression( + new CodeTypeReferenceExpression("Object"), + "ReferenceEquals", + fieldExpr, + NullExpr); + if (AttributeSystem.FieldCanBeNull(finfo)) + { + CodeExpression bvExpr = new CodeArgumentReferenceExpression("bv"); + CodeStatement stmt0 = new CodeExpressionStatement( + new CodeMethodInvokeExpression(bvExpr, "Set", indexExprs[0])); + stmt0 = new CodeConditionStatement(nullExpr, stmt0); + + CodeExpression notNullExpr = new CodeBinaryOperatorExpression( + new CodeIndexerExpression(bvExpr, indexExprs[0]), + CodeBinaryOperatorType.IdentityEquality, + new CodePrimitiveExpression(false)); + stmt1 = new CodeConditionStatement(notNullExpr, stmt1); + return new CodeStmtPair(new CodeStatement[] { stmt0 }, new CodeStatement[] { stmt1 }); + } + else + { + // YY: For now we always check null + string msg = "Field " + finfo.DeclaringType.Name + "." + finfo.Name + " is null."; + CodeExpression msgExpr = new CodePrimitiveExpression(msg); + CodeExpression throwExpr = new CodeObjectCreateExpression(typeof(ArgumentNullException), msgExpr); + CodeStatement stmt0 = new CodeConditionStatement(nullExpr, new CodeThrowExceptionStatement(throwExpr)); + return new CodeStmtPair(null, new CodeStatement[] { stmt0, stmt1 }); + } + } + } + + private static Type FindCustomSerializerType(Type type) + { + // Look for [CustomHpcSerializer] on the UDT. + // Skip inheritance hieararchy, we don't want CustomHpcSerializer declarations + // on the UDT's parent types to take effect. 
// Only an attribute declared directly on the UDT is considered (inherit: false below).
    object[] attributes = type.GetCustomAttributes(typeof(CustomHpcSerializerAttribute), false);
    if (attributes.Length != 1)
    {
        // No custom serializer was declared for this type.
        return null;
    }

    CustomHpcSerializerAttribute attr = (CustomHpcSerializerAttribute)attributes[0];
    Type serializerType = attr.SerializerType;

    // The attribute must actually name a serializer type.
    if (serializerType == null)
    {
        throw new DryadLinqException(HpcLinqErrorCode.SerializerTypeMustBeNonNull,
                                     String.Format(SR.SerializerTypeMustBeNonNull, type.FullName));
    }

    // The serializer type named by the attribute must implement IHpcSerializer<> for this UDT.
    bool found = false;
    if (type.IsGenericType)
    {
        // For a generic UDT the match is made on generic type definitions
        // rather than on constructed types.
        Type typeDefinition = type.GetGenericTypeDefinition();
        foreach (Type intf in serializerType.GetInterfaces())
        {
            // Type.GetGenericTypeDefinition() throws InvalidOperationException on a
            // non-generic type, so non-generic interfaces (e.g. IDisposable) and
            // non-generic type arguments are skipped instead of aborting validation.
            if (!intf.IsGenericType ||
                intf.GetGenericTypeDefinition() != typeof(IHpcSerializer<>))
            {
                continue;
            }

            Type serializedType = intf.GetGenericArguments()[0];
            if (serializedType.IsGenericType &&
                serializedType.GetGenericTypeDefinition() == typeDefinition)
            {
                found = true;
                break;
            }
        }
    }
    else
    {
        Type expectedSerializerInterfaceType = typeof(IHpcSerializer<>).MakeGenericType(type);
        found = expectedSerializerInterfaceType.IsAssignableFrom(serializerType);
    }

    if (!found)
    {
        throw new DryadLinqException(HpcLinqErrorCode.SerializerTypeMustSupportIHpcSerializer,
                                     String.Format(SR.SerializerTypeMustSupportIHpcSerializer,
                                                   serializerType.FullName, type.FullName));
    }
    return serializerType;
}

// Returns the name of the built-in generic serialization class ("HpcSerialization"
// or "StructHpcSerialization") that defines static Read and Write methods for the
// given type, or null if no such class applies.
private static string GetGenericSerializationClassName(Type type)
{
    // Only types with exactly one or two generic arguments are handled below.
    // A type with no generic arguments must be rejected here as well, otherwise
    // the two-argument branch would index past the end of genericArgs.
    Type[] genericArgs = type.GetGenericArguments();
    if (genericArgs.Length == 0 || genericArgs.Length > 2)
    {
        return null;
    }

    // The static Read methods looked up below take the value as a by-ref parameter.
    Type refType = type.MakeByRefType();
    if (genericArgs.Length == 1)
    {
        // Type arguments are: the element type followed by its HpcSerializer<>.
        Type[] typeArgs = new Type[] { genericArgs[0],
                                       typeof(HpcSerializer<>).MakeGenericType(genericArgs[0]) };
        Type dsType = typeof(HpcSerialization<,>).MakeGenericType(typeArgs);
        MethodInfo readMethod = TypeSystem.FindStaticMethod(dsType, "Read", new Type[] { typeof(HpcBinaryReader), refType });
        MethodInfo writeMethod = TypeSystem.FindStaticMethod(dsType, "Write", new Type[] { typeof(HpcBinaryWriter), type });
        if (readMethod != null && writeMethod != null)
        {
            return "HpcSerialization";
        }
        if (typeArgs[0].IsValueType)
        {
            // StructHpcSerialization is only consulted when the argument is a value type.
            dsType = typeof(StructHpcSerialization<,>).MakeGenericType(typeArgs);
            readMethod = TypeSystem.FindStaticMethod(dsType, "Read", new Type[] { typeof(HpcBinaryReader), refType });
            writeMethod = TypeSystem.FindStaticMethod(dsType, "Write", new Type[] { typeof(HpcBinaryWriter), type });
            if (readMethod != null && writeMethod != null)
            {
                return "StructHpcSerialization";
            }
        }
    }
    else
    {
        // Type arguments are: both element types followed by their HpcSerializer<>s.
        Type[] typeArgs = new Type[] { genericArgs[0],
                                       genericArgs[1],
                                       typeof(HpcSerializer<>).MakeGenericType(genericArgs[0]),
                                       typeof(HpcSerializer<>).MakeGenericType(genericArgs[1]) };
        Type dsType = typeof(HpcSerialization<,,,>).MakeGenericType(typeArgs);
        MethodInfo readMethod = TypeSystem.FindStaticMethod(dsType, "Read", new Type[] { typeof(HpcBinaryReader), refType });
        MethodInfo writeMethod = TypeSystem.FindStaticMethod(dsType, "Write", new Type[] { typeof(HpcBinaryWriter), type });
        if (readMethod != null && writeMethod != null)
        {
            return "HpcSerialization";
        }
        if (typeArgs[0].IsValueType && typeArgs[1].IsValueType)
        {
            // StructHpcSerialization is only consulted when both arguments are value types.
            dsType = typeof(StructHpcSerialization<,,,>).MakeGenericType(typeArgs);
            readMethod = TypeSystem.FindStaticMethod(dsType, "Read", new Type[] {
typeof(HpcBinaryReader), refType }); + writeMethod = TypeSystem.FindStaticMethod(dsType, "Write", new Type[] { typeof(HpcBinaryWriter), type }); + if (readMethod != null && writeMethod != null) + { + return "StructHpcSerialization"; + } + } + } + return null; + } + + private static bool IsObject(Type type) + { + Type elemType = type; + while (elemType.IsArray) + { + elemType = elemType.GetElementType(); + } + return elemType == typeof(object); + } + + // Add the serializer class + internal string AddSerializerClass(Type type) + { + // Check if the serializer class is built-in + string serializerName = GetBuiltInHpcSerializer(type); + if (serializerName != null) + { + return serializerName; + } + + // Check if the serializer class is already generated + if (this.m_typeToSerializerName.TryGetValue(type, out serializerName)) + { + return serializerName; + } + + // Check for custom serialization + Type customSerializerType = FindCustomSerializerType(type); + if (customSerializerType != null) + { + serializerName = TypeSystem.TypeName(customSerializerType, this.AnonymousTypeToName); + if (type.IsGenericType) + { + Type[] argTypes = type.GetGenericArguments(); + int len = argTypes.Length; + for (int i = 0; i < len; i++) + { + this.AddAnonymousClass(argTypes[i]); + } + if (customSerializerType.IsGenericTypeDefinition) + { + if (customSerializerType.GetGenericArguments().Length != len * 2) + { + throw new DryadLinqException("The custom serializer " + customSerializerType + + " must have " + (len*2) + " generic type parameters."); + } + + int cnt = 1; + int matchIdx = serializerName.Length - 2; + while (matchIdx >= 0) + { + if (serializerName[matchIdx] == '>') cnt++; + if (serializerName[matchIdx] == '<') cnt--; + if (cnt == 0) break; + matchIdx--; + } + serializerName = serializerName.Substring(0, matchIdx); + serializerName += "<"; + for (int i = 0; i < len; i++) + { + serializerName += this.MakeTypeNameAlias( + TypeSystem.TypeName(argTypes[i], 
this.m_anonymousTypeToName)); + serializerName += ","; + } + for (int i = 0; i < len; i++) + { + serializerName += this.MakeTypeNameAlias(this.AddSerializerClass(argTypes[i])); + if (i < (len-1)) serializerName += ","; + } + serializerName += ">"; + } + } + return serializerName; + } + + if (!TypeSystem.IsAnonymousType(type)) + { + if (!type.IsPublic && !type.IsNestedPublic) + { + throw new DryadLinqException(HpcLinqErrorCode.TypeRequiredToBePublic, + String.Format(SR.TypeRequiredToBePublic, type)); + } + if (IsObject(type)) + { + throw new DryadLinqException(HpcLinqErrorCode.CannotHandleObjectFields, + String.Format(SR.CannotHandleObjectFields, type.FullName)); + } + + // The serializer has troubles if a data type has no data-members, so we outlaw these. + // Abstract classes don't admit such an easy test. + if (!type.IsAbstract && TypeSystem.GetSize(type) == 0) + { + throw new DryadLinqException(HpcLinqErrorCode.TypeMustHaveDataMembers, + String.Format(SR.TypeMustHaveDataMembers, type)); + } + } + + bool isReal = TypeSystem.IsRealType(type); + this.AddAnonymousClass(type); + bool isTypeSerializable = TypeSystem.IsTypeSerializable(type); + + // Check for builtin serialization + CodeExpression serializationTypeExpr = null; + if (type.IsGenericType) + { + string serializationClassName = GetGenericSerializationClassName(type); + if (serializationClassName != null) + { + // Add anonymous classes for type arguments + Type[] argTypes = type.GetGenericArguments(); + int len = argTypes.Length; + for (int i = 0; i < len; i++) + { + this.AddAnonymousClass(argTypes[i]); + } + CodeTypeReference[] argRefs = new CodeTypeReference[len * 2]; + for (int i = 0; i < len; i++) + { + argRefs[i] = new CodeTypeReference(this.MakeTypeNameAlias(TypeSystem.TypeName(argTypes[i], this.m_anonymousTypeToName))); + argRefs[len + i] = new CodeTypeReference(this.MakeTypeNameAlias(this.AddSerializerClass(argTypes[i]))); + } + CodeTypeReference typeRef = new 
CodeTypeReference(serializationClassName, argRefs); + serializationTypeExpr = new CodeTypeReferenceExpression(typeRef); + } + } + + // We now add the serializer class + serializerName = "HpcSerializer" + MakeName(type); + this.m_typeToSerializerName[type] = serializerName; + string typeName = TypeSystem.TypeName(type, this.m_anonymousTypeToName); + string baseClassName = "HpcSerializer<" + typeName + ">"; + CodeTypeDeclaration serializerClass = new CodeTypeDeclaration(serializerName + " : " + baseClassName); + this.m_dryadCodeSpace.Types.Add(serializerClass); + serializerClass.IsClass = true; + serializerClass.TypeAttributes = TypeAttributes.Public | TypeAttributes.Sealed; + + // Add the Read method + CodeMemberMethod readMethod = new CodeMemberMethod(); + serializerClass.Members.Add(readMethod); + readMethod.Attributes = MemberAttributes.Public | MemberAttributes.Override; + readMethod.Name = "Read"; + readMethod.Parameters.Add(new CodeParameterDeclarationExpression(typeof(HpcBinaryReader), "reader")); + typeName = this.MakeTypeNameAlias(typeName); + readMethod.ReturnType = new CodeTypeReference(typeName); + + CodeExpression objExpr = new CodeArgumentReferenceExpression("obj"); + CodeExpression readerExpr = new CodeArgumentReferenceExpression("reader"); + if (type.IsEnum) + { + string readerName = GetBuiltinReaderName(type.GetFields()[0].FieldType); + CodeExpression valExpr = new CodeMethodInvokeExpression(readerExpr, readerName); + valExpr = new CodeCastExpression(type, valExpr); + readMethod.Statements.Add(new CodeMethodReturnStatement(valExpr)); + } + else if (serializationTypeExpr != null) + { + CodeExpression outObjExpr = new CodeDirectionExpression(FieldDirection.Out, objExpr); + CodeExpression readCallExpr = new CodeMethodInvokeExpression( + serializationTypeExpr, "Read", readerExpr, outObjExpr); + readMethod.Statements.Add(new CodeVariableDeclarationStatement(typeName, "obj")); + readMethod.Statements.Add(new CodeExpressionStatement(readCallExpr)); + 
readMethod.Statements.Add(new CodeMethodReturnStatement(objExpr)); + } + else if (TypeSystem.IsAnonymousType(type)) + { + string className = this.m_anonymousTypeToName[type]; + CodeExpression newObjectCall = new CodeMethodInvokeExpression( + new CodeTypeReferenceExpression("FormatterServices"), + "GetUninitializedObject", + new CodeTypeOfExpression(className)); + newObjectCall = new CodeCastExpression(className, newObjectCall); + readMethod.Statements.Add(new CodeVariableDeclarationStatement(className, "obj", newObjectCall)); + + PropertyInfo[] props = type.GetProperties(); + System.Array.Sort(props, (x, y) => x.MetadataToken.CompareTo(y.MetadataToken)); + for (int i = 0; i < props.Length; i++) + { + string fieldName = "_" + props[i].Name; + CodeExpression fieldExpr = new CodeFieldReferenceExpression(objExpr, fieldName); + string readerName = GetBuiltinReaderName(props[i].PropertyType); + CodeStatement stmt; + if (readerName == null) + { + string fieldSerializerName = GetStaticSerializerName(props[i].PropertyType); + CodeVariableReferenceExpression + serializerExpr = new CodeVariableReferenceExpression(fieldSerializerName); + CodeExpression + readCallExpr = new CodeMethodInvokeExpression(serializerExpr, "Read", readerExpr); + stmt = new CodeAssignStatement(fieldExpr, readCallExpr); + } + else + { + CodeExpression readCallExpr = new CodeMethodInvokeExpression(readerExpr, readerName); + stmt = new CodeAssignStatement(fieldExpr, readCallExpr); + } + if (!props[i].PropertyType.IsValueType) + { + CodeExpression ifExpr = new CodeMethodInvokeExpression(readerExpr, "ReadBool"); + stmt = new CodeConditionStatement(ifExpr, stmt); + } + readMethod.Statements.Add(stmt); + } + readMethod.Statements.Add(new CodeMethodReturnStatement(objExpr)); + } + else if (!isReal) + { + throw new DryadLinqException(HpcLinqErrorCode.UDTMustBeConcreteType, + String.Format(SR.UDTMustBeConcreteType, type.FullName)); + } + else if (TypeSystem.HasFieldOfNonPublicType(type)) + { + throw new 
DryadLinqException(HpcLinqErrorCode.UDTHasFieldOfNonPublicType, + String.Format(SR.UDTHasFieldOfNonPublicType, type.FullName)); + } + else if (typeof(System.Delegate).IsAssignableFrom(type)) + { + throw new DryadLinqException(HpcLinqErrorCode.UDTIsDelegateType, + String.Format(SR.UDTIsDelegateType, type.FullName)); + } + else if (!type.IsSealed && TypeSystem.HasSubtypes(type)) + { + throw new DryadLinqException(HpcLinqErrorCode.CannotHandleSubtypes, + String.Format(SR.CannotHandleSubtypes, type.FullName)); + } + else if (isTypeSerializable) // The only choice we have left is to add the auto generated Read method body. + { + // make sure we aren't trying to auto-serialize a circular type + if (TypeSystem.IsCircularType(type)) + { + throw new DryadLinqException(HpcLinqErrorCode.CannotHandleCircularTypes, + String.Format(SR.CannotHandleCircularTypes, type.FullName)); + } + readMethod.Statements.AddRange(this.MakeReadMethodBody(type)); + } + else + { + // tell the user we could do this automatically for them, but they just need to ask explicitly + throw new DryadLinqException(HpcLinqErrorCode.TypeNotSerializable, + String.Format(SR.TypeNotSerializable, type.FullName)); + } + + // Add the Write method + CodeMemberMethod writeMethod = new CodeMemberMethod(); + serializerClass.Members.Add(writeMethod); + writeMethod.Attributes = MemberAttributes.Public | MemberAttributes.Override; + writeMethod.Name = "Write"; + writeMethod.Parameters.Add(new CodeParameterDeclarationExpression(typeof(HpcBinaryWriter), "writer")); + writeMethod.Parameters.Add(new CodeParameterDeclarationExpression(typeName, "obj")); + writeMethod.ReturnType = new CodeTypeReference(typeof(void)); + + CodeExpression writerExpr = new CodeArgumentReferenceExpression("writer"); + if (type.IsEnum) + { + Type intType = type.GetFields()[0].FieldType; + CodeExpression valExpr = new CodeCastExpression(intType, objExpr); + CodeExpression writeCallExpr = new CodeMethodInvokeExpression(writerExpr, "Write", valExpr); + 
// Enum case (continued): add the writer.Write call built just above to the Write method.
        writeMethod.Statements.Add(new CodeExpressionStatement(writeCallExpr));
    }
    else if (serializationTypeExpr != null)
    {
        // A built-in generic serialization class exists: delegate to its static Write.
        CodeExpression writeCallExpr = new CodeMethodInvokeExpression(
                                               serializationTypeExpr, "Write", writerExpr, objExpr);
        writeMethod.Statements.Add(new CodeExpressionStatement(writeCallExpr));
    }
    else if (TypeSystem.IsAnonymousType(type))
    {
        // Write each property's "_Name" field, visiting properties in
        // metadata-token order.
        PropertyInfo[] props = type.GetProperties();
        System.Array.Sort(props, (x, y) => x.MetadataToken.CompareTo(y.MetadataToken));
        for (int i = 0; i < props.Length; i++)
        {
            Type fieldType = props[i].PropertyType;
            string fieldName = "_" + props[i].Name;
            CodeExpression fieldExpr = new CodeFieldReferenceExpression(objExpr, fieldName);
            CodeExpression writeCall;

            // The built-in check must be made on the property's type, not on the
            // enclosing anonymous type (which is never built-in); otherwise the
            // writer.Write(...) path below could never be taken.
            if (GetBuiltinReaderName(fieldType) == null)
            {
                // Non-builtin type: go through the static serializer for the field type.
                string fieldSerializerName = GetStaticSerializerName(fieldType);
                CodeVariableReferenceExpression
                    serializerExpr = new CodeVariableReferenceExpression(fieldSerializerName);
                writeCall = new CodeMethodInvokeExpression(serializerExpr, "Write", writerExpr, fieldExpr);
            }
            else
            {
                // Builtin type: the writer handles it directly.
                writeCall = new CodeMethodInvokeExpression(writerExpr, "Write", fieldExpr);
            }
            CodeStatement stmt = new CodeExpressionStatement(writeCall);
            if (!fieldType.IsValueType)
            {
                // Reference-typed fields are preceded by a bool "is not null" flag,
                // and the value itself is only written when the flag is true.
                CodeExpression nullExpr = new CodeMethodInvokeExpression(
                                                  new CodeTypeReferenceExpression("Object"),
                                                  "ReferenceEquals",
                                                  fieldExpr,
                                                  NullExpr);
                CodeExpression notNullExpr = new CodeBinaryOperatorExpression(
                                                     nullExpr,
                                                     CodeBinaryOperatorType.IdentityEquality,
                                                     new CodePrimitiveExpression(false));
                writeCall = new CodeMethodInvokeExpression(writerExpr, "Write", notNullExpr);
                writeMethod.Statements.Add(new CodeExpressionStatement(writeCall));
                stmt = new CodeConditionStatement(notNullExpr, stmt);
            }
            writeMethod.Statements.Add(stmt);
        }
    }
    else
    {
        // Otherwise use the auto-generated field-by-field Write body.
        writeMethod.Statements.AddRange(this.MakeWriteMethodBody(type));
    }

    return serializerName;
}

private CodeMemberField AddCustomSerializerStaticField(Type type,
Type customSerializerType) + { + // create unique name for the static instance + string customSerializerInstanceName = String.Format("customSerializer_{0}", MakeName(type)); + CodeMemberField customSerializerField = new CodeMemberField(customSerializerType, customSerializerInstanceName); + customSerializerField.Attributes = MemberAttributes.Assembly | MemberAttributes.Static; + + // Now we need to add the init expression for the serializer instance + if (customSerializerType.IsClass && !customSerializerType.IsByRef) + { + // if the serializer type is a CLASS, this expression will be the default ctor of the custom serializer type + // i.e. "internal static CustomSerializerType customSerializerInstance = new CustomSerializerType();" + customSerializerField.InitExpression = new CodeObjectCreateExpression(customSerializerType); + + // make sure the custom serialzier type has a default constructor because we need to instantiate a static copy + var ctorInfo = customSerializerType.GetConstructor(Type.EmptyTypes); + if (ctorInfo == null) + { + throw new DryadLinqException(HpcLinqErrorCode.CustomSerializerMustSupportDefaultCtor, + String.Format(SR.CustomSerializerMustSupportDefaultCtor, customSerializerType.FullName)); + } + } + else if (customSerializerType.IsValueType && !customSerializerType.IsByRef) + { + // if the serializer type is a VALUE TYPE, this expression will be the default value of the custom serializer type + // i.e. 
"internal static CustomSerializerType customSerializerInstance = default(CustomSerializerType);" + customSerializerField.InitExpression = new CodeDefaultValueExpression(new CodeTypeReference(customSerializerType)); + } + else + { + // neither class, nor value type means they either passed in an interface or a byref type, none of which we support + throw new DryadLinqException(HpcLinqErrorCode.CustomSerializerMustBeClassOrStruct, + String.Format(SR.CustomSerializerMustBeClassOrStruct, customSerializerType.FullName, type.FullName)); + } + + + // We don't need to ensure uniqueness of "customSerializer_Type_XX" fields + // here because the caller of this method runs only once per UDT + m_dryadExtensionClass.Members.Add(customSerializerField); + + return customSerializerField; + } + + internal CodeVariableDeclarationStatement + MakeVarDeclStatement(string typeName, string varName, CodeExpression expr) + { + return new CodeVariableDeclarationStatement(typeName, MakeUniqueName(varName), expr); + } + + internal CodeVariableDeclarationStatement + MakeVarDeclStatement(Type type, string varName, CodeExpression expr) + { + string typeName = TypeSystem.TypeName(type, this.AnonymousTypeToName); + return new CodeVariableDeclarationStatement(typeName, MakeUniqueName(varName), expr); + } + + internal CodeExpression MakeExpression(Expression expr) + { + string exprString = HpcLinqExpression.ToCSharpString(expr, this.AnonymousTypeToName); + return new CodeSnippetExpression(exprString); + } + + internal CodeVariableDeclarationStatement MakeFactoryDecl(Type type) + { + CodeExpression factoryInitExpr = new CodeObjectCreateExpression(HpcLinqFactoryClassName(type)); + return this.MakeVarDeclStatement(HpcLinqFactoryClassName(type), "factory", factoryInitExpr); + } + + internal CodeVariableDeclarationStatement MakeDryadReaderDecl(Type type, string factoryName) + { + CodeExpression readerInitExpr = new CodeMethodInvokeExpression( + new CodeArgumentReferenceExpression(DryadEnvName), + 
"MakeReader", + new CodeArgumentReferenceExpression(factoryName)); + return this.MakeVarDeclStatement("var", "dreader", readerInitExpr); + } + + internal CodeVariableDeclarationStatement MakeDryadWriterDecl(Type type, string factoryName) + { + CodeExpression writerInitExpr = new CodeMethodInvokeExpression( + new CodeArgumentReferenceExpression(DryadEnvName), + "MakeWriter", + new CodeArgumentReferenceExpression(factoryName)); + return this.MakeVarDeclStatement("var", "dwriter", writerInitExpr); + } + + internal CodeVariableDeclarationStatement MakeSourceDecl(string methodName, string denvName) + { + CodeExpression sourceInitExpr = new CodeMethodInvokeExpression( + new CodeVariableReferenceExpression(denvName), + methodName, + new CodePrimitiveExpression(true)); + return this.MakeVarDeclStatement("var", "source", sourceInitExpr); + } + + internal static CodeVariableDeclarationStatement MakeDryadVertexParamsDecl(DryadQueryNode node) + { + int inputArity = node.InputArity + node.GetReferencedQueries().Count; + int outputArity = node.OutputArity; + + CodeExpression arg1 = new CodePrimitiveExpression(inputArity); + CodeExpression arg2 = new CodePrimitiveExpression(outputArity); + + CodeExpression dVertexParamsInitExpr = new CodeObjectCreateExpression("HpcLinqVertexParams", arg1, arg2); + CodeVariableDeclarationStatement + dVertexParamsDecl = new CodeVariableDeclarationStatement("HpcLinqVertexParams", + DryadVertexParamName, + dVertexParamsInitExpr); + return dVertexParamsDecl; + } + + internal static CodeAssignStatement SetDryadVertexParamField(string fieldName, object value) + { + CodeExpression vertexParam = new CodeArgumentReferenceExpression(DryadVertexParamName); + CodeExpression left = new CodeFieldReferenceExpression(vertexParam, fieldName); + CodeExpression right = new CodePrimitiveExpression(value); + return new CodeAssignStatement(left, right); + } + + internal static CodeVariableDeclarationStatement MakeDryadEnvDecl(DryadQueryNode node) + { + CodeExpression 
arg1 = new CodeArgumentReferenceExpression("args"); + CodeExpression arg2 = new CodeArgumentReferenceExpression(DryadVertexParamName); + + CodeExpression + denvInitExpr = new CodeObjectCreateExpression("HpcLinqVertexEnv", arg1, arg2); + return new CodeVariableDeclarationStatement("HpcLinqVertexEnv", DryadEnvName, denvInitExpr); + } + + // Emits a static helper method that checks an environment variable to decide whether to + // launch the debugger, wait for a manual attach or simply skip straight into vertex code. + private bool m_debugHelperEmitted = false; + private void EnsureDebuggerHelperMethodEmitted() + { + if (this.m_debugHelperEmitted) return; + + CodeMemberMethod debugHelperMethod = new CodeMemberMethod(); + debugHelperMethod.Name = DebugHelperMethodName; + debugHelperMethod.Attributes = MemberAttributes.Public | MemberAttributes.Static; + + debugHelperMethod.Statements.Add(new CodeSnippetExpression("string debugEnvVar = Environment.GetEnvironmentVariable(\"LINQTOHPC_DEBUGVERTEX\")")); + debugHelperMethod.Statements.Add(new CodeSnippetExpression("if (debugEnvVar == null) return")); + + CodeConditionStatement conditionalStatement = new CodeConditionStatement( + new CodeSnippetExpression("String.Compare(debugEnvVar, \"LAUNCH\", StringComparison.OrdinalIgnoreCase) == 0"), // The condition to test. 
+ new CodeStatement[] { new CodeSnippetStatement(" System.Diagnostics.Debugger.Launch();") }, // if clause + new CodeStatement[] { new CodeSnippetStatement(" DryadLinqLog.Add(\"Waiting for debugger to attach...\");"), // else clause + new CodeSnippetStatement(" while (!Debugger.IsAttached) System.Threading.Thread.Sleep(1000);"), + new CodeSnippetStatement(" Debugger.Break();") + }); + + debugHelperMethod.Statements.Add(conditionalStatement); + this.m_dryadVertexClass.Members.Add(debugHelperMethod); + this.m_debugHelperEmitted = true; + } + + // Add a new vertex method to the Dryad vertex class + internal CodeMemberMethod AddVertexMethod(DryadQueryNode node) + { + CodeMemberMethod vertexMethod = new CodeMemberMethod(); + vertexMethod.Attributes = MemberAttributes.Public | MemberAttributes.Static; + vertexMethod.ReturnType = new CodeTypeReference(typeof(int)); + vertexMethod.Parameters.Add(new CodeParameterDeclarationExpression(typeof(string), "args")); + vertexMethod.Name = MakeUniqueName(node.NodeType.ToString()); + + CodeTryCatchFinallyStatement tryBlock = new CodeTryCatchFinallyStatement(); + + string startedMsg = "DryadLinqLog.Add(\"Vertex " + vertexMethod.Name + + " started at {0}\", DateTime.Now.ToString(\"MM/dd/yyyy HH:mm:ss.fff\"))"; + vertexMethod.Statements.Add(new CodeSnippetExpression(startedMsg)); + + // We need to call AddCopyResourcesMethod() + vertexMethod.Statements.Add(new CodeSnippetExpression("CopyResources()")); + + if (StaticConfig.LaunchDebugger) + { + // If static config requests it, we do an unconditional Debugger.Launch() at vertex entry. 
+ // Currently this isn't used because StaticConfig.LaunchDebugger is hardcoded to false + System.Console.WriteLine("Launch debugger: may block application"); + + CodeExpression launchExpr = new CodeSnippetExpression("System.Diagnostics.Debugger.Launch()"); + vertexMethod.Statements.Add(new CodeExpressionStatement(launchExpr)); + } + else + { + // Otherwise (the default behavior), we first make sure we emit the debug check helper static method + // and add a call to it at vertex entry. This helper checks an environment variable to decide whether + // to launch the debugger, wait for a manual attach or simply skip straigt into vertex code. + EnsureDebuggerHelperMethodEmitted(); + CodeMethodInvokeExpression debuggerCheckExpr = new CodeMethodInvokeExpression( + new CodeMethodReferenceExpression(new CodeTypeReferenceExpression(VertexClassName), + DebugHelperMethodName)); + + vertexMethod.Statements.Add(new CodeExpressionStatement(debuggerCheckExpr)); + } + + vertexMethod.Statements.Add(MakeDryadVertexParamsDecl(node)); + vertexMethod.Statements.Add(SetDryadVertexParamField("VertexStageName", vertexMethod.Name)); + vertexMethod.Statements.Add(SetDryadVertexParamField("UseLargeBuffer", node.UseLargeWriteBuffer)); + vertexMethod.Statements.Add(SetDryadVertexParamField("KeepInputPortOrder", node.KeepInputPortOrder())); + + // Push the parallel-code settings into HpcLinqVertex + bool multiThreading = this.m_context.Configuration.AllowConcurrentUserDelegatesInSingleProcess; + vertexMethod.Statements.Add(SetDryadVertexParamField("MultiThreading", multiThreading)); + vertexMethod.Statements.Add( + new CodeAssignStatement( + new CodeFieldReferenceExpression(DLVTypeExpr, "s_multiThreading"), + new CodePrimitiveExpression(multiThreading))); + + vertexMethod.Statements.Add(MakeDryadEnvDecl(node)); + + Type[] outputTypes = node.OutputTypes; + string[] writerNames = new string[outputTypes.Length]; + for (int i = 0; i < outputTypes.Length; i++) + { + CodeVariableDeclarationStatement 
+ writerDecl = MakeDryadWriterDecl(outputTypes[i], this.GetStaticFactoryName(outputTypes[i])); + vertexMethod.Statements.Add(writerDecl); + writerNames[i] = writerDecl.Name; + } + + // Add side readers: + node.AddSideReaders(vertexMethod); + + // Generate code based on the node type: + switch (node.NodeType) + { + case QueryNodeType.Where: + case QueryNodeType.OrderBy: + case QueryNodeType.Distinct: + case QueryNodeType.Skip: + case QueryNodeType.SkipWhile: + case QueryNodeType.Take: + case QueryNodeType.TakeWhile: + case QueryNodeType.Merge: + case QueryNodeType.Select: + case QueryNodeType.SelectMany: + case QueryNodeType.GroupBy: + case QueryNodeType.BasicAggregate: + case QueryNodeType.Aggregate: + case QueryNodeType.Contains: + case QueryNodeType.Join: + case QueryNodeType.GroupJoin: + case QueryNodeType.Union: + case QueryNodeType.Intersect: + case QueryNodeType.Except: + case QueryNodeType.RangePartition: + case QueryNodeType.HashPartition: + case QueryNodeType.Apply: + case QueryNodeType.Fork: + case QueryNodeType.Dynamic: + { + Type[] inputTypes = node.InputTypes; + string[] sourceNames = new string[inputTypes.Length]; + for (int i = 0; i < inputTypes.Length; i++) + { + CodeVariableDeclarationStatement + readerDecl = MakeDryadReaderDecl(inputTypes[i], this.GetStaticFactoryName(inputTypes[i])); + vertexMethod.Statements.Add(readerDecl); + sourceNames[i] = readerDecl.Name; + } + string sourceToSink = this.m_vertexCodeGen.AddVertexCode(node, vertexMethod, sourceNames, writerNames); + + if (sourceToSink != null && (node.NodeType == QueryNodeType.Dynamic || node.Parents.Count > 0)) + { + CodeExpression sinkExpr = new CodeMethodInvokeExpression( + new CodeVariableReferenceExpression(writerNames[0]), + "WriteItemSequence", + new CodeVariableReferenceExpression(sourceToSink)); + vertexMethod.Statements.Add(sinkExpr); + } + break; + } + case QueryNodeType.Super: + { + string sourceToSink = this.m_vertexCodeGen.AddVertexCode(node, vertexMethod, null, writerNames); + 
if (sourceToSink != null && node.Parents.Count > 0) + { + CodeExpression sinkExpr = new CodeMethodInvokeExpression( + new CodeVariableReferenceExpression(writerNames[0]), + "WriteItemSequence", + new CodeVariableReferenceExpression(sourceToSink)); + vertexMethod.Statements.Add(sinkExpr); + } + break; + } + default: + { + //@@TODO: this should not be reachable. could change to Assert/InvalidOpEx + throw new DryadLinqException(HpcLinqErrorCode.Internal, + String.Format(SR.AddVertexNotHandled, node.NodeType)); + } + } + + string completedMsg = "DryadLinqLog.Add(\"Vertex " + vertexMethod.Name + + " completed at {0}\", DateTime.Now.ToString(\"MM/dd/yyyy HH:mm:ss.fff\"))"; + vertexMethod.Statements.Add(new CodeSnippetExpression(completedMsg)); + + // add a catch block + CodeCatchClause catchBlock = new CodeCatchClause("e"); + CodeTypeReferenceExpression errorReportClass = new CodeTypeReferenceExpression("HpcLinqVertexEnv"); + CodeMethodReferenceExpression + errorReportMethod = new CodeMethodReferenceExpression(errorReportClass, "ReportVertexError"); + CodeVariableReferenceExpression exRef = new CodeVariableReferenceExpression(catchBlock.LocalName); + catchBlock.Statements.Add(new CodeMethodInvokeExpression(errorReportMethod, exRef)); + tryBlock.CatchClauses.Add(catchBlock); + + // wrap the entire vertex method in a try/catch block + tryBlock.TryStatements.AddRange(vertexMethod.Statements); + vertexMethod.Statements.Clear(); + vertexMethod.Statements.Add(tryBlock); + + // Always add "return 0", to make CLR hosting happy... + vertexMethod.Statements.Add(new CodeMethodReturnStatement(ZeroExpr)); + + this.m_dryadVertexClass.Members.Add(vertexMethod); + return vertexMethod; + } + + internal string AddVertexCode(CodeMemberMethod vertexMethod, Pipeline pipeline) + { + if (pipeline.Length == 0) + { + //@@TODO: this should not be reachable. 
could change to Assert/InvalidOpEx + throw new DryadLinqException(HpcLinqErrorCode.Internal, SR.CannotBeEmpty); + } + DryadQueryNode firstNode = pipeline[0]; + if (firstNode.CanAttachPipeline) + { + firstNode.AttachedPipeline = pipeline; + return this.m_vertexCodeGen.AddVertexCode(firstNode, vertexMethod, pipeline.ReaderNames, pipeline.WriterNames); + } + + int startIndex = 0; + string applySource = pipeline.ReaderNames[0]; + if (!firstNode.IsHomomorphic) + { + applySource = this.m_vertexCodeGen.AddVertexCode(firstNode, vertexMethod, pipeline.ReaderNames, pipeline.WriterNames); + if (pipeline.Length == 1) return applySource; + startIndex = 1; + } + + // The vertex code + Type paramType = pipeline[startIndex].InputTypes[0].MakeArrayType(); + ParameterExpression param = Expression.Parameter(paramType, MakeUniqueName("x")); + CodeExpression pipelineArg = pipeline.BuildExpression(startIndex, param, param); + bool orderPreserving = (m_context.Configuration.SelectiveOrderPreservation || + pipeline[pipeline.Length - 1].OutputDataSetInfo.orderByInfo.IsOrdered); + + CodeExpression applyExpr; + if (m_context.Configuration.AllowConcurrentUserDelegatesInSingleProcess) + { + applyExpr = new CodeMethodInvokeExpression( + HpcLinqCodeGen.DLVTypeExpr, + "PApply", + new CodeVariableReferenceExpression(applySource), + pipelineArg, + new CodePrimitiveExpression(orderPreserving)); + } + else + { + applyExpr = new CodeMethodInvokeExpression( + HpcLinqCodeGen.DLVTypeExpr, + "DryadApply", + new CodeVariableReferenceExpression(applySource), + pipelineArg); + } + CodeVariableDeclarationStatement + sourceDecl = this.MakeVarDeclStatement("var", "source", applyExpr); + vertexMethod.Statements.Add(sourceDecl); + return sourceDecl.Name; + } + + private void GenerateCode(string dummyFile, string srcFile) + { + CSharpCodeProvider provider = new CSharpCodeProvider(); + + // Add a source file containing the following line: + string dummyClass = "namespace Microsoft.Research.DryadLinq { public static 
partial class " + ExtensionClassName + " { } }"; + using (StreamWriter srcWriter = new StreamWriter(dummyFile)) + { + srcWriter.Write(dummyClass); + srcWriter.Close(); + } + + // Generate code for dryadUnit and store in srcFile: + CodeGeneratorOptions options = new CodeGeneratorOptions(); + options.BracingStyle = "C"; + using (IndentedTextWriter srcWriter = new IndentedTextWriter(new StreamWriter(srcFile), " ")) + { + provider.GenerateCodeFromCompileUnit(this.m_dryadLinqUnit, srcWriter, options); + srcWriter.Close(); + } + } + + // Compile this compilation unit. Generate source code if asked + // loadGeneratedAssembly specifies whether the generated assembly + // should get loaded after compilation (into m_loadedVertexAssembly) + private void GenerateCodeAndCompile(string dummyFile, + string srcFile, + string binFile, + bool loadGeneratedAssembly) + { + // Generate the sources: + this.GenerateCode(dummyFile, srcFile); + + // Build the parameters for source compilation. + CompilerParameters cp = new CompilerParameters(); + + // Add assembly references. + HashSet assemblySet = new HashSet(); + assemblySet.Add("System"); + assemblySet.Add("System.Core"); + assemblySet.Add("System.Data"); + assemblySet.Add("System.Data.Linq"); + assemblySet.Add("System.Xml"); + foreach (string name in assemblySet) + { + cp.ReferencedAssemblies.Add(name + ".dll"); + } + + // Add references to assemblies referenced by entry assembly + foreach (Assembly asm in TypeSystem.GetAllAssemblies()) + { + string name = asm.GetName().Name; + if (name != "mscorlib" && + !assemblySet.Contains(name) && + !String.IsNullOrEmpty(asm.Location)) + { + cp.ReferencedAssemblies.Add(asm.Location); + } + } + + // Compiler options. + cp.CompilerOptions = @"/unsafe"; + if (! m_context.Configuration.CompileForVertexDebugging) + { + cp.CompilerOptions += @" /optimize+"; + } + + // Generate PDB. 
+ cp.IncludeDebugInformation = m_context.Configuration.CompileForVertexDebugging; + + // Generate an executable instead of a class library. + cp.GenerateExecutable = false; + + // Set the assembly file name to generate. + cp.OutputAssembly = binFile; + + // Save the assembly as a physical file. + cp.GenerateInMemory = false; + + // Invoke compilation. + IDictionary providerOptions = new Dictionary(); + + // If the user hasn't requested "MatchClientNetFrameworkVersion" then we will force + // the compiler version to be 3.5 (or the same as the .NET version L2H is compiled against) + // This is to make sure we satisfy the minimun guaranteed .NET version on the cluster nodes. + // + // However if the user set config.MatchClientNetFrameworkVersion to true, it means + // they know the cluster nodes match the client's (possibly higher) .NET version, and + // they request compilation to match that. In that case, we will use whatever the default + // compiler version is (== matching client's .NET runtime) + // + // We also explicitly set the string to "v3.5" if the client is a v3.5 process (regardless of the MCNFV flag). + // This is because the default compiler version on .NET 3.5 is actually v2. + if (!m_context.Configuration.MatchClientNetFrameworkVersion || Environment.Version.Major < 4) + { + // Hardcode the compiler version to 3.5, which is what L2H is built against. + // NOTE: if we ever build L2H against a newer .NET version, we need to update this string. + providerOptions["CompilerVersion"] = "v3.5"; + } + + CSharpCodeProvider provider = new CSharpCodeProvider(providerOptions); + CompilerResults cr = provider.CompileAssemblyFromFile(cp, dummyFile, srcFile); + + if (cr.Errors.Count > 0) + { + // Display compilation errors. 
+ Console.Error.WriteLine("Errors building {0}", cr.PathToAssembly); + foreach (CompilerError ce in cr.Errors) + { + HpcClientSideLog.Add(" {0}\n\r", ce.ToString()); + } + throw new DryadLinqException(HpcLinqErrorCode.FailedToBuild, + String.Format(SR.FailedToBuild, binFile, HpcClientSideLog.CLIENT_LOG_FILENAME)); + } + else + { + HpcClientSideLog.Add("{0} was built successfully.", cr.PathToAssembly); + this.m_generatedVertexDllPath = Path.Combine(Directory.GetCurrentDirectory(), binFile); + if (loadGeneratedAssembly) + { + this.m_loadedVertexAssembly = cr.CompiledAssembly; + } + + // @TODO: should we lock the generated DLL if a load wasn't requested? + } + } + + private void BuildAssembly(bool loadGeneratedAssembly) + { + // there's nothing to do if we previously built *and* loaded the vertex assembly + if (this.m_loadedVertexAssembly != null) return; + + // if we previously built a vertex DLL without loading it, and someone is + // requesting the loaded copy now, just load it and return currently there's + // no scenario that would hit this case, but adding it here for completeness sake. 
+ if (loadGeneratedAssembly && this.m_generatedVertexDllPath != null) + { + this.m_loadedVertexAssembly = Assembly.LoadFrom(this.m_generatedVertexDllPath); + return; + } + + int inProcessVertexInstanceID = Interlocked.Increment(ref s_DryadLinqDllVersion); + + string dummyFile = HpcLinqCodeGen.GetPathForGeneratedFile(DummyExtensionSourceFile, null); + string targetName = HpcLinqCodeGen.GetPathForGeneratedFile(TargetDllName, inProcessVertexInstanceID); + string srcFile = HpcLinqCodeGen.GetPathForGeneratedFile(VertexSourceFile, inProcessVertexInstanceID); + + this.GenerateCodeAndCompile(dummyFile, srcFile, targetName, loadGeneratedAssembly); + } + + public static object GetFactory(HpcLinqContext context, Type type) + { + lock (s_codeGenLock) + { + if (s_TypeToFactory.ContainsKey(type)) + { + return s_TypeToFactory[type]; + } + + HpcLinqCodeGen codeGen = new HpcLinqCodeGen(context, new VertexCodeGen()); + codeGen.AddDryadCodeForType(type); + + // build assembly, and load into memory, because we'll next instantiate + // the factory type out of the generated assembly. + codeGen.BuildAssembly(true); + + string factoryTypeFullName = TargetNamespace + "." + HpcLinqFactoryClassName(type); + object factory = codeGen.m_loadedVertexAssembly.CreateInstance(factoryTypeFullName); + s_TypeToFactory.Add(type, factory); + return factory; + } + } + + internal void BuildDryadLinqAssembly(HpcLinqQueryGen queryGen) + { + lock (s_codeGenLock) + { + // this method only gets called from HpcLinqQueryGen.GenerateDryadProgram() before job submission. + // Since we don't load the generated vertex DLL after that, the check for + // "should re-gen?" 
below is based on m_generatedVertexDllPath being set + if (this.m_generatedVertexDllPath== null) + { + queryGen.CodeGenVisit(); + this.BuildAssembly(false); + } + } + } + + internal string GetDryadLinqDllName() + { + if (this.m_generatedVertexDllPath == null) + { + throw new DryadLinqException(HpcLinqErrorCode.Internal, SR.AutogeneratedAssemblyMissing); + } + return Path.GetFileName(this.m_generatedVertexDllPath); + } + + internal string GetTargetLocation() + { + if (this.m_generatedVertexDllPath== null) + { + throw new DryadLinqException(HpcLinqErrorCode.Internal, SR.AutogeneratedAssemblyMissing); + } + return this.m_generatedVertexDllPath; + } + + // + // Utility method for creating a unique path for generated files (vertex source, DLL, object store, query plan XML etc.) + // Each process gets its own directory under the temp path (formatted as LINQTOHPC_), and all generated files go under that directory + // + internal static string GetPathForGeneratedFile(string fileNameTemplate, int? inProcessInstanceID) + { + Process process = System.Diagnostics.Process.GetCurrentProcess(); + + // The temp folder format is: + // :\Users\\AppData\Local\Temp\LINQTOHPC\_ + // + string artifactsPath = Path.Combine(System.IO.Path.GetTempPath(), + String.Format("LINQTOHPC\\{0}_{1}", + Path.GetFileNameWithoutExtension(process.MainModule.ModuleName), + process.Id)); + if (!Directory.Exists(artifactsPath)) + { + Directory.CreateDirectory(artifactsPath); + } + + string fileName = null; + if (inProcessInstanceID != null) + { + // If an in-process instance ID is provided we format the filename template to + // . 
+ // + string baseFileName = Path.GetFileNameWithoutExtension(fileNameTemplate); + string fileExtension = Path.GetExtension(fileNameTemplate); + fileName = String.Format("{0}{1}{2}", baseFileName, inProcessInstanceID.Value, fileExtension); + } + else + { + // otherwise use the filename as is + fileName = fileNameTemplate; + } + + return Path.Combine(artifactsPath, fileName); + } + } +} diff --git a/LinqToDryad/DryadFactory.cs b/LinqToDryad/DryadFactory.cs new file mode 100644 index 0000000..87834de --- /dev/null +++ b/LinqToDryad/DryadFactory.cs @@ -0,0 +1,455 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. 
+// +using System; +using System.Collections; +using System.Collections.Generic; +using System.Text; +using System.Reflection; +using System.Data.SqlTypes; +using Microsoft.Research.DryadLinq; + +namespace Microsoft.Research.DryadLinq.Internal +{ + public abstract class HpcLinqFactory + { + public abstract HpcRecordReader MakeReader(NativeBlockStream nativeStream); + public abstract HpcRecordReader MakeReader(IntPtr handle, UInt32 port); + public abstract HpcRecordWriter MakeWriter(NativeBlockStream nativeStream); + public abstract HpcRecordWriter MakeWriter(IntPtr handle, UInt32 port, Int32 buffSize); + } + + public sealed class HpcLinqFactoryByte : HpcLinqFactory + { + public override HpcRecordReader MakeReader(NativeBlockStream nativeStream) + { + return new HpcRecordByteReader(new HpcBinaryReader(nativeStream)); + } + + public override HpcRecordReader MakeReader(IntPtr handle, UInt32 port) + { + return new HpcRecordByteReader(new HpcBinaryReader(handle, port)); + } + + public override HpcRecordWriter MakeWriter(NativeBlockStream nativeStream) + { + return new HpcRecordByteWriter(new HpcBinaryWriter(nativeStream)); + } + + public override HpcRecordWriter MakeWriter(IntPtr handle, UInt32 port, Int32 buffSize) + { + return new HpcRecordByteWriter(new HpcBinaryWriter(handle, port, buffSize)); + } + } + + public sealed class HpcLinqFactorySByte : HpcLinqFactory + { + public override HpcRecordReader MakeReader(NativeBlockStream nativeStream) + { + return new HpcRecordSByteReader(new HpcBinaryReader(nativeStream)); + } + + public override HpcRecordReader MakeReader(IntPtr handle, UInt32 port) + { + return new HpcRecordSByteReader(new HpcBinaryReader(handle, port)); + } + + public override HpcRecordWriter MakeWriter(NativeBlockStream nativeStream) + { + return new HpcRecordSByteWriter(new HpcBinaryWriter(nativeStream)); + } + + public override HpcRecordWriter MakeWriter(IntPtr handle, UInt32 port, Int32 buffSize) + { + return new HpcRecordSByteWriter(new 
HpcBinaryWriter(handle, port, buffSize)); + } + } + + public sealed class HpcLinqFactoryBool : HpcLinqFactory + { + public override HpcRecordReader MakeReader(NativeBlockStream nativeStream) + { + return new HpcRecordBoolReader(new HpcBinaryReader(nativeStream)); + } + + public override HpcRecordReader MakeReader(IntPtr handle, UInt32 port) + { + return new HpcRecordBoolReader(new HpcBinaryReader(handle, port)); + } + + public override HpcRecordWriter MakeWriter(NativeBlockStream nativeStream) + { + return new HpcRecordBoolWriter(new HpcBinaryWriter(nativeStream)); + } + + public override HpcRecordWriter MakeWriter(IntPtr handle, UInt32 port, Int32 buffSize) + { + return new HpcRecordBoolWriter(new HpcBinaryWriter(handle, port, buffSize)); + } + } + + public sealed class HpcLinqFactoryChar : HpcLinqFactory + { + public override HpcRecordReader MakeReader(NativeBlockStream nativeStream) + { + return new HpcRecordCharReader(new HpcBinaryReader(nativeStream)); + } + + public override HpcRecordReader MakeReader(IntPtr handle, UInt32 port) + { + return new HpcRecordCharReader(new HpcBinaryReader(handle, port)); + } + + public override HpcRecordWriter MakeWriter(NativeBlockStream nativeStream) + { + return new HpcRecordCharWriter(new HpcBinaryWriter(nativeStream)); + } + + public override HpcRecordWriter MakeWriter(IntPtr handle, UInt32 port, Int32 buffSize) + { + return new HpcRecordCharWriter(new HpcBinaryWriter(handle, port, buffSize)); + } + } + + public sealed class HpcLinqFactoryShort : HpcLinqFactory + { + public override HpcRecordReader MakeReader(NativeBlockStream nativeStream) + { + return new HpcRecordShortReader(new HpcBinaryReader(nativeStream)); + } + + public override HpcRecordReader MakeReader(IntPtr handle, UInt32 port) + { + return new HpcRecordShortReader(new HpcBinaryReader(handle, port)); + } + + public override HpcRecordWriter MakeWriter(NativeBlockStream nativeStream) + { + return new HpcRecordShortWriter(new HpcBinaryWriter(nativeStream)); + } + + 
public override HpcRecordWriter MakeWriter(IntPtr handle, UInt32 port, Int32 buffSize) + { + return new HpcRecordShortWriter(new HpcBinaryWriter(handle, port, buffSize)); + } + } + + public sealed class HpcLinqFactoryUShort : HpcLinqFactory + { + public override HpcRecordReader MakeReader(NativeBlockStream nativeStream) + { + return new HpcRecordUShortReader(new HpcBinaryReader(nativeStream)); + } + + public override HpcRecordReader MakeReader(IntPtr handle, UInt32 port) + { + return new HpcRecordUShortReader(new HpcBinaryReader(handle, port)); + } + + public override HpcRecordWriter MakeWriter(NativeBlockStream nativeStream) + { + return new HpcRecordUShortWriter(new HpcBinaryWriter(nativeStream)); + } + + public override HpcRecordWriter MakeWriter(IntPtr handle, UInt32 port, Int32 buffSize) + { + return new HpcRecordUShortWriter(new HpcBinaryWriter(handle, port, buffSize)); + } + } + + public sealed class HpcLinqFactoryInt32 : HpcLinqFactory + { + public override HpcRecordReader MakeReader(NativeBlockStream nativeStream) + { + return new HpcRecordInt32Reader(new HpcBinaryReader(nativeStream)); + } + + public override HpcRecordReader MakeReader(IntPtr handle, UInt32 port) + { + return new HpcRecordInt32Reader(new HpcBinaryReader(handle, port)); + } + + public override HpcRecordWriter MakeWriter(NativeBlockStream nativeStream) + { + return new HpcRecordInt32Writer(new HpcBinaryWriter(nativeStream)); + } + + public override HpcRecordWriter MakeWriter(IntPtr handle, UInt32 port, Int32 buffSize) + { + return new HpcRecordInt32Writer(new HpcBinaryWriter(handle, port, buffSize)); + } + } + + public sealed class HpcLinqFactoryUInt32 : HpcLinqFactory + { + public override HpcRecordReader MakeReader(NativeBlockStream nativeStream) + { + return new HpcRecordUInt32Reader(new HpcBinaryReader(nativeStream)); + } + + public override HpcRecordReader MakeReader(IntPtr handle, UInt32 port) + { + return new HpcRecordUInt32Reader(new HpcBinaryReader(handle, port)); + } + + public 
override HpcRecordWriter MakeWriter(NativeBlockStream nativeStream) + { + return new HpcRecordUInt32Writer(new HpcBinaryWriter(nativeStream)); + } + + public override HpcRecordWriter MakeWriter(IntPtr handle, UInt32 port, Int32 buffSize) + { + return new HpcRecordUInt32Writer(new HpcBinaryWriter(handle, port, buffSize)); + } + } + + public sealed class HpcLinqFactoryInt64 : HpcLinqFactory + { + public override HpcRecordReader MakeReader(NativeBlockStream nativeStream) + { + return new HpcRecordInt64Reader(new HpcBinaryReader(nativeStream)); + } + + public override HpcRecordReader MakeReader(IntPtr handle, UInt32 port) + { + return new HpcRecordInt64Reader(new HpcBinaryReader(handle, port)); + } + + public override HpcRecordWriter MakeWriter(NativeBlockStream nativeStream) + { + return new HpcRecordInt64Writer(new HpcBinaryWriter(nativeStream)); + } + + public override HpcRecordWriter MakeWriter(IntPtr handle, UInt32 port, Int32 buffSize) + { + return new HpcRecordInt64Writer(new HpcBinaryWriter(handle, port, buffSize)); + } + } + + public sealed class HpcLinqFactoryUInt64 : HpcLinqFactory + { + public override HpcRecordReader MakeReader(NativeBlockStream nativeStream) + { + return new HpcRecordUInt64Reader(new HpcBinaryReader(nativeStream)); + } + + public override HpcRecordReader MakeReader(IntPtr handle, UInt32 port) + { + return new HpcRecordUInt64Reader(new HpcBinaryReader(handle, port)); + } + + public override HpcRecordWriter MakeWriter(NativeBlockStream nativeStream) + { + return new HpcRecordUInt64Writer(new HpcBinaryWriter(nativeStream)); + } + + public override HpcRecordWriter MakeWriter(IntPtr handle, UInt32 port, Int32 buffSize) + { + return new HpcRecordUInt64Writer(new HpcBinaryWriter(handle, port, buffSize)); + } + } + + public sealed class HpcLinqFactoryFloat : HpcLinqFactory + { + public override HpcRecordReader MakeReader(NativeBlockStream nativeStream) + { + return new HpcRecordFloatReader(new HpcBinaryReader(nativeStream)); + } + + public 
override HpcRecordReader MakeReader(IntPtr handle, UInt32 port) + { + return new HpcRecordFloatReader(new HpcBinaryReader(handle, port)); + } + + public override HpcRecordWriter MakeWriter(NativeBlockStream nativeStream) + { + return new HpcRecordFloatWriter(new HpcBinaryWriter(nativeStream)); + } + + public override HpcRecordWriter MakeWriter(IntPtr handle, UInt32 port, Int32 buffSize) + { + return new HpcRecordFloatWriter(new HpcBinaryWriter(handle, port, buffSize)); + } + } + + public sealed class HpcLinqFactoryDecimal : HpcLinqFactory + { + public override HpcRecordReader MakeReader(NativeBlockStream nativeStream) + { + return new HpcRecordDecimalReader(new HpcBinaryReader(nativeStream)); + } + + public override HpcRecordReader MakeReader(IntPtr handle, UInt32 port) + { + return new HpcRecordDecimalReader(new HpcBinaryReader(handle, port)); + } + + public override HpcRecordWriter MakeWriter(NativeBlockStream nativeStream) + { + return new HpcRecordDecimalWriter(new HpcBinaryWriter(nativeStream)); + } + + public override HpcRecordWriter MakeWriter(IntPtr handle, UInt32 port, Int32 buffSize) + { + return new HpcRecordDecimalWriter(new HpcBinaryWriter(handle, port, buffSize)); + } + } + + public sealed class HpcLinqFactoryDouble : HpcLinqFactory + { + public override HpcRecordReader MakeReader(NativeBlockStream nativeStream) + { + return new HpcRecordDoubleReader(new HpcBinaryReader(nativeStream)); + } + + public override HpcRecordReader MakeReader(IntPtr handle, UInt32 port) + { + return new HpcRecordDoubleReader(new HpcBinaryReader(handle, port)); + } + + public override HpcRecordWriter MakeWriter(NativeBlockStream nativeStream) + { + return new HpcRecordDoubleWriter(new HpcBinaryWriter(nativeStream)); + } + + public override HpcRecordWriter MakeWriter(IntPtr handle, UInt32 port, Int32 buffSize) + { + return new HpcRecordDoubleWriter(new HpcBinaryWriter(handle, port, buffSize)); + } + } + + public sealed class HpcLinqFactoryDateTime : HpcLinqFactory + { + public 
override HpcRecordReader MakeReader(NativeBlockStream nativeStream) + { + return new HpcRecordDateTimeReader(new HpcBinaryReader(nativeStream)); + } + + public override HpcRecordReader MakeReader(IntPtr handle, UInt32 port) + { + return new HpcRecordDateTimeReader(new HpcBinaryReader(handle, port)); + } + + public override HpcRecordWriter MakeWriter(NativeBlockStream nativeStream) + { + return new HpcRecordDateTimeWriter(new HpcBinaryWriter(nativeStream)); + } + + public override HpcRecordWriter MakeWriter(IntPtr handle, UInt32 port, Int32 buffSize) + { + return new HpcRecordDateTimeWriter(new HpcBinaryWriter(handle, port, buffSize)); + } + } + + public sealed class HpcLinqFactoryString : HpcLinqFactory + { + public override HpcRecordReader MakeReader(NativeBlockStream nativeStream) + { + return new HpcRecordStringReader(new HpcBinaryReader(nativeStream)); + } + + public override HpcRecordReader MakeReader(IntPtr handle, UInt32 port) + { + return new HpcRecordStringReader(new HpcBinaryReader(handle, port)); + } + + public override HpcRecordWriter MakeWriter(NativeBlockStream nativeStream) + { + return new HpcRecordStringWriter(new HpcBinaryWriter(nativeStream)); + } + + public override HpcRecordWriter MakeWriter(IntPtr handle, UInt32 port, Int32 buffSize) + { + return new HpcRecordStringWriter(new HpcBinaryWriter(handle, port, buffSize)); + } + } + + public sealed class HpcLinqFactoryGuid : HpcLinqFactory + { + public override HpcRecordReader MakeReader(NativeBlockStream nativeStream) + { + return new HpcRecordGuidReader(new HpcBinaryReader(nativeStream)); + } + + public override HpcRecordReader MakeReader(IntPtr handle, UInt32 port) + { + return new HpcRecordGuidReader(new HpcBinaryReader(handle, port)); + } + + public override HpcRecordWriter MakeWriter(NativeBlockStream nativeStream) + { + return new HpcRecordGuidWriter(new HpcBinaryWriter(nativeStream)); + } + + public override HpcRecordWriter MakeWriter(IntPtr handle, UInt32 port, Int32 buffSize) + { + return 
new HpcRecordGuidWriter(new HpcBinaryWriter(handle, port, buffSize)); + } + } + + public sealed class HpcLinqFactoryLineRecord : HpcLinqFactory + { + public override HpcRecordReader MakeReader(NativeBlockStream nativeStream) + { + return new HpcRecordTextReader(new HpcTextReader(nativeStream)); + } + + public override HpcRecordReader MakeReader(IntPtr handle, UInt32 port) + { + return new HpcRecordTextReader(new HpcTextReader(handle, port)); + } + + public override HpcRecordWriter MakeWriter(NativeBlockStream nativeStream) + { + return new HpcRecordTextWriter(new HpcTextWriter(nativeStream)); + } + + public override HpcRecordWriter MakeWriter(IntPtr handle, UInt32 port, Int32 buffSize) + { + return new HpcRecordTextWriter(new HpcTextWriter(handle, port, buffSize)); + } + } + + public sealed class HpcLinqFactorySqlDateTime : HpcLinqFactory + { + public override HpcRecordReader MakeReader(NativeBlockStream nativeStream) + { + return new HpcRecordSqlDateTimeReader(new HpcBinaryReader(nativeStream)); + } + + public override HpcRecordReader MakeReader(IntPtr handle, UInt32 port) + { + return new HpcRecordSqlDateTimeReader(new HpcBinaryReader(handle, port)); + } + + public override HpcRecordWriter MakeWriter(NativeBlockStream nativeStream) + { + return new HpcRecordSqlDateTimeWriter(new HpcBinaryWriter(nativeStream)); + } + + public override HpcRecordWriter MakeWriter(IntPtr handle, UInt32 port, Int32 buffSize) + { + return new HpcRecordSqlDateTimeWriter(new HpcBinaryWriter(handle, port, buffSize)); + } + } +} diff --git a/LinqToDryad/DryadLinqCollection.cs b/LinqToDryad/DryadLinqCollection.cs new file mode 100644 index 0000000..20115d6 --- /dev/null +++ b/LinqToDryad/DryadLinqCollection.cs @@ -0,0 +1,1470 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. +// +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using Microsoft.Research.DryadLinq; + +namespace Microsoft.Research.DryadLinq.Internal +{ + [Serializable] + public struct IndexedValue : IEquatable>, IComparable> + { + private int _index; + private T _value; + + public int Index + { + get { return _index; } + set { _index = value; } + } + + public T Value + { + get { return _value; } + set { _value = value; } + } + + public IndexedValue(int index, T value) + { + _index = index; + _value = value; + } + + public bool Equals(IndexedValue val) + { + return this.Index == val.Index; + } + + public int CompareTo(IndexedValue val) + { + return this.Index - val.Index; + } + + public override int GetHashCode() + { + return this.Index; + } + + public override bool Equals(object obj) + { + if (!(obj is IndexedValue)) + return false; + + return this.Equals((IndexedValue)obj); + } + + public static bool operator ==(IndexedValue a, IndexedValue b) + { + return a.Equals(b); + } + + public static bool operator !=(IndexedValue a, IndexedValue b) + { + return !a.Equals(b); + } + } + + public struct HpcLinqGrouping : IGrouping + { + private K m_key; + private IEnumerable m_elems; + + public HpcLinqGrouping(K key, IEnumerable elems) + { + this.m_key = key; + this.m_elems = elems; + } + + public K Key + { + get { return this.m_key; } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); 
+ } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public IEnumerator GetEnumerator() + { + return this.m_elems.GetEnumerator(); + } + } + + public class BigCollection : IEnumerable + { + protected const int ChunkSize = (1 << 21); + + private TElement[][] m_elements; + private Int32 m_pos1; + private Int32 m_pos2; + + public BigCollection() + { + this.m_elements = new TElement[4][]; + this.m_elements[0] = new TElement[2]; + this.m_pos1 = 0; + this.m_pos2 = 0; + } + + public long Count() + { + return (ChunkSize * (long)this.m_pos1) + this.m_pos2; + } + + public TElement this[long index] + { + get { return this.m_elements[index/ChunkSize][index%ChunkSize]; } + set { this.m_elements[index/ChunkSize][index%ChunkSize] = value; } + } + + public void Add(TElement elem) + { + if (this.m_pos2 == this.m_elements[this.m_pos1].Length) + { + if (this.m_pos2 == ChunkSize) + { + this.m_pos1++; + this.m_pos2 = 0; + if (this.m_pos1 == this.m_elements.Length) + { + TElement[][] elems = new TElement[this.m_pos1 * 2][]; + Array.Copy(this.m_elements, 0, elems, 0, this.m_pos1); + this.m_elements = elems; + } + this.m_elements[this.m_pos1] = new TElement[ChunkSize]; + } + else + { + TElement[] newElems = new TElement[this.m_pos2 * 2]; + Array.Copy(this.m_elements[this.m_pos1], 0, newElems, 0, this.m_pos2); + this.m_elements[this.m_pos1] = newElems; + } + } + + this.m_elements[this.m_pos1][this.m_pos2] = elem; + this.m_pos2++; + } + + public IEnumerable Reverse() + { + for (int i = this.m_pos2 - 1; i >= 0; i--) + { + yield return this.m_elements[this.m_pos1][i]; + } + for (int i = this.m_pos1 - 1; i >= 0; i--) + { + TElement[] elems = this.m_elements[i]; + for (int j = elems.Length - 1; j >= 0; j--) + { + yield return elems[j]; + } + } + } + + public IEnumerator GetEnumerator() + { + for (int i = 0; i < this.m_pos1; i++) + { + TElement[] elems = this.m_elements[i]; + for (int j = 0; j < elems.Length; j++) + { + yield return elems[j]; + } + } + for 
(int i = 0; i < this.m_pos2; i++) + { + yield return this.m_elements[this.m_pos1][i]; + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + } + + public class BigDictionary : IEnumerable> + { + private const int Ratio = 2; + private const int LogChunkSize = 21; + private const int ChunkSize = (1 << LogChunkSize); + private const int ChunkMask = ChunkSize - 1; + private const uint MaxCount = UInt32.MaxValue - ChunkSize; + + private IEqualityComparer m_comparer; + private UInt32[] m_buckets; + private Entry[][] m_entries; + private UInt32 m_pos1; + private UInt32 m_pos2; + private Int64 m_count; + private UInt32 m_freeList; + + public BigDictionary() + : this(null, 1024) + { + } + + public BigDictionary(IEqualityComparer comparer) + : this(comparer, 1024) + { + } + + public BigDictionary(IEqualityComparer comparer, int initialCapacity) + { + this.m_comparer = (comparer == null) ? EqualityComparer.Default : comparer; + this.m_buckets = new uint[CollectionHelper.GetNextPrime(initialCapacity)]; + this.m_entries = new Entry[4][]; + this.m_entries[0] = new Entry[ChunkSize]; + this.m_pos1 = 0; + this.m_pos2 = 1; + this.m_count = 0; + this.m_freeList = 0; + } + + public Int64 Count + { + get { return this.m_count; } + } + + public TValue this[TKey key] + { + get { + TValue value; + if (!this.TryGetValue(key, out value)) + { + throw new DryadLinqException(HpcLinqErrorCode.KeyNotFound, SR.KeyNotFound); + } + return value; + } + set { + this.Add(key, value); + } + } + + public bool TryGetValue(TKey key, out TValue value) + { + int hashCode = this.m_comparer.GetHashCode(key) & 0x7FFFFFFF; + int bucket = hashCode % this.m_buckets.Length; + UInt32 index = this.m_buckets[bucket]; + while (index > 0) + { + Entry entry = this.m_entries[index >> LogChunkSize][index & ChunkMask]; + if (this.m_comparer.Equals(entry.m_key, key)) + { + value = entry.m_value; + return true; + } + index = entry.m_next; + } + value = default(TValue); + return false; + 
} + + public bool ContainsKey(TKey key) + { + TValue value; + return this.TryGetValue(key, out value); + } + + public bool Add(TKey key, TValue value) + { + int hashCode = this.m_comparer.GetHashCode(key) & 0x7FFFFFFF; + int bucket = hashCode % this.m_buckets.Length; + UInt32 index = this.m_buckets[bucket]; + while (index > 0) + { + Entry entry = this.m_entries[index >> LogChunkSize][index & ChunkMask]; + if (this.m_comparer.Equals(entry.m_key, key)) + { + return false; + } + index = entry.m_next; + } + + // is not in the dictionary, so add it + if (this.m_freeList > 0) + { + index = this.m_freeList; + this.m_freeList = this.m_entries[index >> LogChunkSize][index & ChunkMask].m_next; + Entry newEntry = new Entry(key, value, this.m_buckets[bucket]); + this.m_entries[index >> LogChunkSize][index & ChunkMask] = newEntry; + } + else + { + if (this.m_count == (this.m_buckets.Length * Ratio)) + { + this.Resize(); + bucket = hashCode % this.m_buckets.Length; + } + Entry newEntry = new Entry(key, value, this.m_buckets[bucket]); + if (this.m_pos2 == ChunkSize) + { + if (this.m_count >= MaxCount) + { + throw new DryadLinqException(HpcLinqErrorCode.TooManyItems, SR.TooManyItems); + } + this.m_pos1++; + this.m_pos2 = 0; + if (this.m_pos1 == this.m_entries.Length) + { + Entry[][] newEntries = new Entry[this.m_pos1 * 2][]; + Array.Copy(this.m_entries, 0, newEntries, 0, this.m_pos1); + this.m_entries = newEntries; + } + this.m_entries[this.m_pos1] = new Entry[ChunkSize]; + } + + this.m_entries[this.m_pos1][this.m_pos2] = newEntry; + index = ((this.m_pos1 << LogChunkSize) | this.m_pos2); + this.m_pos2++; + } + + this.m_buckets[bucket] = index; + this.m_count++; + return true; + } + + // Remove an item from the set. Return true iff the item is in the set and + // is removed successfully from the set. 
+ public bool Remove(TKey key) + { + int hashCode = this.m_comparer.GetHashCode(key) & 0x7FFFFFFF; + int bucket = hashCode % this.m_buckets.Length; + + uint pidx = 0; + uint cidx = this.m_buckets[bucket]; + while (cidx > 0) + { + Entry entry = this.m_entries[cidx >> LogChunkSize][cidx & ChunkMask]; + if (this.m_comparer.Equals(entry.m_key, key)) + { + if (pidx == 0) + { + this.m_buckets[bucket] = entry.m_next; + } + else + { + this.m_entries[pidx >> LogChunkSize][pidx & ChunkMask].m_next = entry.m_next; + } + this.m_entries[cidx >> LogChunkSize][cidx & ChunkMask].m_next = this.m_freeList; + this.m_freeList = cidx; + this.m_count--; + return true; + } + pidx = cidx; + cidx = entry.m_next; + } + return false; + } + + private void Resize() + { + int oldSize = this.m_buckets.Length; + int newSize = CollectionHelper.GetNextPrime(oldSize); + if (newSize > oldSize) + { + this.m_buckets = new uint[newSize]; + for (uint i = 1; i <= this.m_count; i++) + { + Entry entry = this.m_entries[i >> LogChunkSize][i & ChunkMask]; + int bucket = (this.m_comparer.GetHashCode(entry.m_key) & 0x7FFFFFFF) % newSize; + this.m_entries[i >> LogChunkSize][i & ChunkMask].m_next = this.m_buckets[bucket]; + this.m_buckets[bucket] = i; + } + } + } + + public IEnumerable GetKeys() + { + for (int i = 0; i < this.m_buckets.Length; i++) + { + uint index = this.m_buckets[i]; + while (index > 0) + { + Entry entry = this.m_entries[index >> LogChunkSize][index & ChunkMask]; + yield return entry.m_key; + index = entry.m_next; + } + } + } + + public IEnumerable GetValues() + { + for (int i = 0; i < this.m_buckets.Length; i++) + { + uint index = this.m_buckets[i]; + while (index > 0) + { + Entry entry = this.m_entries[index >> LogChunkSize][index & ChunkMask]; + yield return entry.m_value; + index = entry.m_next; + } + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public IEnumerator> GetEnumerator() + { + for (int i = 0; i < this.m_buckets.Length; i++) + { + uint 
index = this.m_buckets[i]; + while (index > 0) + { + Entry entry = this.m_entries[index >> LogChunkSize][index & ChunkMask]; + yield return new Pair(entry.m_key, entry.m_value); + index = entry.m_next; + } + } + } + + private struct Entry + { + public TKey m_key; + public TValue m_value; + public UInt32 m_next; + + public Entry(TKey key, TValue value, UInt32 next) + { + this.m_key = key; + this.m_value = value; + this.m_next = next; + } + } + } + + public class AccumulateDictionary : IEnumerable> + { + private const int Ratio = 2; + private const int LogChunkSize = 21; + private const int ChunkSize = (1 << LogChunkSize); + private const int ChunkMask = ChunkSize - 1; + private const uint MaxCount = UInt32.MaxValue - ChunkSize; + + private IEqualityComparer m_comparer; + private Func m_seed; + private Func m_accumulator; + private UInt32[] m_buckets; + private Entry[][] m_entries; + private UInt32 m_pos1; + private UInt32 m_pos2; + private Int64 m_count; + + public AccumulateDictionary(Func seed, + Func accumulator) + : this(null, 1024, seed, accumulator) + { + } + + public AccumulateDictionary(IEqualityComparer comparer, + Func seed, + Func accumulator) + : this(comparer, 1024, seed, accumulator) + { + } + + public AccumulateDictionary(IEqualityComparer comparer, + int initialCapacity, + Func seed, + Func accumulator) + { + this.m_comparer = (comparer == null) ? 
EqualityComparer.Default : comparer; + this.m_seed = seed; + this.m_accumulator = accumulator; + this.m_buckets = new uint[CollectionHelper.GetNextPrime(initialCapacity)]; + this.m_entries = new Entry[4][]; + this.m_entries[0] = new Entry[ChunkSize]; + this.m_pos1 = 0; + this.m_pos2 = 1; + this.m_count = 0; + } + + public Int64 Count + { + get { return this.m_count; } + } + + public bool TryGetValue(TKey key, out TValue value) + { + int hashCode = this.m_comparer.GetHashCode(key) & 0x7FFFFFFF; + int bucket = hashCode % this.m_buckets.Length; + UInt32 index = this.m_buckets[bucket]; + while (index > 0) + { + Entry entry = this.m_entries[index >> LogChunkSize][index & ChunkMask]; + if (this.m_comparer.Equals(entry.m_key, key)) + { + value = entry.m_value; + return true; + } + index = entry.m_next; + } + value = default(TValue); + return false; + } + + public bool ContainsKey(TKey key) + { + TValue value; + return this.TryGetValue(key, out value); + } + + public void Add(TKey key, TSource elem) + { + int hashCode = this.m_comparer.GetHashCode(key) & 0x7FFFFFFF; + int bucket = hashCode % this.m_buckets.Length; + UInt32 index = this.m_buckets[bucket]; + while (index > 0) + { + Entry entry = this.m_entries[index >> LogChunkSize][index & ChunkMask]; + if (this.m_comparer.Equals(entry.m_key, key)) + { + this.m_entries[index >> LogChunkSize][index & ChunkMask].m_value + = this.m_accumulator(entry.m_value, elem); + return; + } + index = entry.m_next; + } + + // is not in the dictionary, so add it + if (this.m_count == (this.m_buckets.Length * Ratio)) + { + this.Resize(); + bucket = hashCode % this.m_buckets.Length; + } + TValue val = this.m_seed(elem); + Entry newEntry = new Entry(key, val, this.m_buckets[bucket]); + if (this.m_pos2 == ChunkSize) + { + if (this.m_count >= MaxCount) + { + throw new DryadLinqException("Too many items"); + } + this.m_pos1++; + this.m_pos2 = 0; + if (this.m_pos1 == this.m_entries.Length) + { + Entry[][] newEntries = new Entry[this.m_pos1 * 2][]; 
+ Array.Copy(this.m_entries, 0, newEntries, 0, this.m_pos1); + this.m_entries = newEntries; + } + this.m_entries[this.m_pos1] = new Entry[ChunkSize]; + } + + this.m_entries[this.m_pos1][this.m_pos2] = newEntry; + index = ((this.m_pos1 << LogChunkSize) | this.m_pos2); + this.m_pos2++; + + this.m_buckets[bucket] = index; + this.m_count++; + } + + private void Resize() + { + int oldSize = this.m_buckets.Length; + int newSize = CollectionHelper.GetNextPrime(oldSize); + if (newSize > oldSize) + { + this.m_buckets = new uint[newSize]; + for (uint i = 1; i <= this.m_count; i++) + { + Entry entry = this.m_entries[i >> LogChunkSize][i & ChunkMask]; + int bucket = (this.m_comparer.GetHashCode(entry.m_key) & 0x7FFFFFFF) % newSize; + this.m_entries[i >> LogChunkSize][i & ChunkMask].m_next = this.m_buckets[bucket]; + this.m_buckets[bucket] = i; + } + } + } + + public IEnumerable GetKeys() + { + for (int i = 0; i < this.m_buckets.Length; i++) + { + uint index = this.m_buckets[i]; + while (index > 0) + { + Entry entry = this.m_entries[index >> LogChunkSize][index & ChunkMask]; + yield return entry.m_key; + index = entry.m_next; + } + } + } + + public IEnumerable GetValues() + { + for (int i = 0; i < this.m_buckets.Length; i++) + { + uint index = this.m_buckets[i]; + while (index > 0) + { + Entry entry = this.m_entries[index >> LogChunkSize][index & ChunkMask]; + yield return entry.m_value; + index = entry.m_next; + } + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public IEnumerator> GetEnumerator() + { + for (int i = 0; i < this.m_buckets.Length; i++) + { + uint index = this.m_buckets[i]; + while (index > 0) + { + Entry entry = this.m_entries[index >> LogChunkSize][index & ChunkMask]; + yield return new Pair(entry.m_key, entry.m_value); + index = entry.m_next; + } + } + } + + private struct Entry + { + public TKey m_key; + public TValue m_value; + public UInt32 m_next; + + public Entry(TKey key, TValue value, UInt32 next) + { + 
this.m_key = key; + this.m_value = value; + this.m_next = next; + } + } + } + + public class BigHashSet : IEnumerable + { + private const int Ratio = 2; + private const int LogChunkSize = 21; + private const int ChunkSize = (1 << LogChunkSize); + private const int ChunkMask = ChunkSize - 1; + private const uint MaxCount = UInt32.MaxValue - ChunkSize; + + private IEqualityComparer m_comparer; + private UInt32[] m_buckets; + private Entry[][] m_entries; + private UInt32 m_pos1; + private UInt32 m_pos2; + private Int64 m_count; + private UInt32 m_freeList; + + public BigHashSet() + : this(null, 1024) + { + } + + public BigHashSet(IEqualityComparer comparer) + : this(comparer, 1024) + { + } + + public BigHashSet(IEqualityComparer comparer, int initialCapacity) + { + this.m_comparer = (comparer == null) ? EqualityComparer.Default : comparer; + this.m_buckets = new uint[CollectionHelper.GetNextPrime(initialCapacity)]; + this.m_entries = new Entry[4][]; + this.m_entries[0] = new Entry[ChunkSize]; + this.m_pos1 = 0; + this.m_pos2 = 1; + this.m_count = 0; + this.m_freeList = 0; + } + + public Int64 Count + { + get { return this.m_count; } + } + + public bool Contains(TElement item) + { + int hashCode = this.m_comparer.GetHashCode(item) & 0x7FFFFFFF; + int bucket = hashCode % this.m_buckets.Length; + UInt32 index = this.m_buckets[bucket]; + while (index > 0) + { + Entry entry = this.m_entries[index >> LogChunkSize][index & ChunkMask]; + if (this.m_comparer.Equals(entry.m_item, item)) + { + return true; + } + index = entry.m_next; + } + return false; + } + + // Add an item into the set. Return true iff the item is not in the set and + // is added successfully into the set. 
+        public bool Add(TElement item)
+        {
+            // Clear the sign bit so the modulo below always yields a valid bucket index.
+            int hashCode = this.m_comparer.GetHashCode(item) & 0x7FFFFFFF;
+            int bucket = hashCode % this.m_buckets.Length;
+
+            // Walk the bucket's chain looking for an equal item. Entry index 0 is
+            // never handed out (the constructor starts m_pos2 at 1), so 0 marks the
+            // end of a chain.
+            UInt32 index = this.m_buckets[bucket];
+            while (index > 0)
+            {
+                Entry entry = this.m_entries[index >> LogChunkSize][index & ChunkMask];
+                if (this.m_comparer.Equals(entry.m_item, item))
+                {
+                    return false;
+                }
+                index = entry.m_next;
+            }
+
+            // item is not in the set, so add it
+            if (this.m_freeList > 0)
+            {
+                // Reuse a slot previously released by Remove: pop it off the free
+                // list and overwrite it with the new entry, linked in front of the
+                // bucket's current chain.
+                index = this.m_freeList;
+                this.m_freeList = this.m_entries[index >> LogChunkSize][index & ChunkMask].m_next;
+                Entry newEntry = new Entry(item, this.m_buckets[bucket]);
+                this.m_entries[index >> LogChunkSize][index & ChunkMask] = newEntry;
+            }
+            else
+            {
+                // Grow the bucket array once the load factor reaches Ratio.
+                // This must be ">=" rather than "==": Adds that reuse free-list slots
+                // (the branch above) increase m_count without passing through this
+                // check, so the count can step over the exact threshold. With "=="
+                // the table would then never be resized again. Resize() leaves the
+                // table unchanged when no larger prime is available, so calling it
+                // again at the maximum size is harmless.
+                if (this.m_count >= (this.m_buckets.Length * Ratio))
+                {
+                    this.Resize();
+                    // The bucket count may have changed; recompute the target bucket.
+                    bucket = hashCode % this.m_buckets.Length;
+                }
+                Entry newEntry = new Entry(item, this.m_buckets[bucket]);
+                if (this.m_pos2 == ChunkSize)
+                {
+                    // Current chunk is full: move on to a new chunk, refusing to grow
+                    // past MaxCount items.
+                    if (this.m_count >= MaxCount)
+                    {
+                        throw new DryadLinqException(HpcLinqErrorCode.TooManyItems, SR.TooManyItems);
+                    }
+                    this.m_pos1++;
+                    this.m_pos2 = 0;
+                    if (this.m_pos1 == this.m_entries.Length)
+                    {
+                        // Double the chunk table; existing chunks are carried over by reference.
+                        Entry[][] newEntries = new Entry[this.m_pos1 * 2][];
+                        Array.Copy(this.m_entries, 0, newEntries, 0, this.m_pos1);
+                        this.m_entries = newEntries;
+                    }
+                    this.m_entries[this.m_pos1] = new Entry[ChunkSize];
+                }
+
+                // Entry indices pack (chunk number, offset within chunk) into one UInt32.
+                this.m_entries[this.m_pos1][this.m_pos2] = newEntry;
+                index = ((this.m_pos1 << LogChunkSize) | this.m_pos2);
+                this.m_pos2++;
+            }
+
+            // Make the new entry the head of its bucket's chain.
+            this.m_buckets[bucket] = index;
+            this.m_count++;
+            return true;
+        }
+
+        // Remove an item from the set. Return true iff the item is in the set and
+        // is removed successfully from the set.
+ public bool Remove(TElement item) + { + int hashCode = this.m_comparer.GetHashCode(item) & 0x7FFFFFFF; + int bucket = hashCode % this.m_buckets.Length; + + uint pidx = 0; + uint cidx = this.m_buckets[bucket]; + while (cidx > 0) + { + Entry entry = this.m_entries[cidx >> LogChunkSize][cidx & ChunkMask]; + if (this.m_comparer.Equals(entry.m_item, item)) + { + if (pidx == 0) + { + this.m_buckets[bucket] = entry.m_next; + } + else + { + this.m_entries[pidx >> LogChunkSize][pidx & ChunkMask].m_next = entry.m_next; + } + this.m_entries[cidx >> LogChunkSize][cidx & ChunkMask].m_next = this.m_freeList; + this.m_freeList = cidx; + this.m_count--; + return true; + } + pidx = cidx; + cidx = entry.m_next; + } + return false; + } + + private void Resize() + { + int oldSize = this.m_buckets.Length; + int newSize = CollectionHelper.GetNextPrime(oldSize); + if (newSize > oldSize) + { + this.m_buckets = new uint[newSize]; + for (uint i = 1; i <= this.m_count; i++) + { + Entry entry = this.m_entries[i >> LogChunkSize][i & ChunkMask]; + int bucket = (this.m_comparer.GetHashCode(entry.m_item) & 0x7FFFFFFF) % newSize; + this.m_entries[i >> LogChunkSize][i & ChunkMask].m_next = this.m_buckets[bucket]; + this.m_buckets[bucket] = i; + } + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public IEnumerator GetEnumerator() + { + for (int i = 0; i < this.m_buckets.Length; i++) + { + uint index = this.m_buckets[i]; + while (index > 0) + { + Entry entry = this.m_entries[index >> LogChunkSize][index & ChunkMask]; + yield return entry.m_item; + index = entry.m_next; + } + } + } + + private struct Entry + { + public TElement m_item; + public UInt32 m_next; + + public Entry(TElement item, UInt32 next) + { + this.m_item = item; + this.m_next = next; + } + } + } + + public class Grouping : IGrouping + { + private static int MaxCount = Int32.MaxValue / TypeSystem.GetInMemSize(typeof(TElement)); + + private TKey m_key; + private TElement[] m_elements; + 
private int m_count; + private Grouping m_next; + + public Grouping(TKey key) + : this(key, 2) + { + } + + internal Grouping(TKey key, int len) + { + this.m_key = key; + this.m_elements = new TElement[len]; + this.m_count = 0; + this.m_next = null; + } + + public IEnumerator GetEnumerator() + { + for (int i = 0; i < this.m_count; i++) + { + yield return this.m_elements[i]; + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public TKey Key + { + get { return this.m_key; } + } + + public int Count() + { + return this.m_count; + } + + internal TElement[] Elements + { + get { return m_elements; } + } + + internal Grouping Next + { + get { return this.m_next; } + set { this.m_next = value; } + } + + public void AddItem(TElement elem) + { + if (this.m_elements.Length == this.m_count) + { + if (this.m_count >= MaxCount) + { + throw new DryadLinqException("Too many elements in a single group."); + } + int newSize = this.m_count * 2; + if (newSize > MaxCount) newSize = MaxCount; + TElement[] newElements = new TElement[newSize]; + Array.Copy(this.m_elements, 0, newElements, 0, this.m_count); + this.m_elements = newElements; + } + this.m_elements[this.m_count++] = elem; + } + + public override string ToString() + { + return "Grouping[" + this.Key + "]"; + } + } + + public class BigGrouping : IGrouping + { + protected const int ChunkSize = (1 << 21); + + private TKey m_key; + private TElement[][] m_elements; + private int m_pos1; + private int m_pos2; + private BigGrouping m_next; + + public BigGrouping(TKey key) + { + this.m_key = key; + this.m_elements = new TElement[2][]; + this.m_elements[0] = new TElement[2]; + this.m_pos1 = 0; + this.m_pos2 = 0; + this.m_next = null; + } + + public TKey Key + { + get { return this.m_key; } + } + + public long Count() + { + return (ChunkSize * (long)this.m_pos1) + this.m_pos2; + } + + internal BigGrouping Next + { + get { return this.m_next; } + set { this.m_next = value; } + } + + public void 
AddItem(TElement elem) + { + if (this.m_pos2 == this.m_elements[this.m_pos1].Length) + { + if (this.m_pos2 == ChunkSize) + { + this.m_pos1++; + this.m_pos2 = 0; + if (this.m_pos1 == this.m_elements.Length) + { + TElement[][] elems = new TElement[this.m_pos1 * 2][]; + Array.Copy(this.m_elements, 0, elems, 0, this.m_pos1); + this.m_elements = elems; + } + this.m_elements[this.m_pos1] = new TElement[ChunkSize]; + } + else + { + TElement[] newElems = new TElement[this.m_pos2 * 2]; + Array.Copy(this.m_elements[this.m_pos1], 0, newElems, 0, this.m_pos2); + this.m_elements[this.m_pos1] = newElems; + } + } + + this.m_elements[this.m_pos1][this.m_pos2] = elem; + this.m_pos2++; + } + + public IEnumerator GetEnumerator() + { + for (int i = 0; i < this.m_pos1; i++) + { + TElement[] elems = this.m_elements[i]; + for (int j = 0; j < elems.Length; j++) + { + yield return elems[j]; + } + } + for (int j = 0; j < this.m_pos2; j++) + { + yield return this.m_elements[this.m_pos1][j]; + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public override string ToString() + { + return "Grouping[" + this.Key + "]"; + } + } + + public class Grouping + { + private TKey m_key; + private TElement[] m_elems; + private TResult[] m_results; + private int m_count; + private Grouping m_next; + + public Grouping(TKey key) + : this(key, 2) + { + } + + internal Grouping(TKey key, int len) + { + this.m_key = key; + this.m_elems = new TElement[len]; + this.m_count = 0; + this.m_results = new TResult[2]; + this.m_next = null; + } + + public TKey Key + { + get { return this.m_key; } + } + + public int Count() + { + return this.ElemCount; + } + + private int ElemCount + { + get { return (this.m_count & 0xFFFFFF); } + } + + private int ResCount + { + get { return (this.m_count >> 24); } + } + + internal Grouping Next + { + get { return this.m_next; } + set { this.m_next = value; } + } + + public void Reduce(Func, TResult> resultSelector, + Func, TResult> combiner) 
+ { + int elemCnt = this.ElemCount; + if (elemCnt > 0) + { + TElement[] curElems = this.m_elems; + if (elemCnt < this.m_elems.Length) + { + curElems = new TElement[elemCnt]; + Array.Copy(this.m_elems, 0, curElems, 0, elemCnt); + } + int resCount = this.ResCount; + this.m_results[resCount] = resultSelector(curElems); + if (resCount == 1) + { + this.m_results[0] = combiner(this.m_results); + } + + this.m_elems = new TElement[2]; + this.m_count = 0x1000000; + } + } + + public TResult GetResult(Func, TResult> resultSelector, + Func, TResult> combiner) + { + int elemCnt = this.ElemCount; + if (elemCnt > 0) + { + TElement[] curElems = this.m_elems; + if (elemCnt < this.m_elems.Length) + { + curElems = new TElement[elemCnt]; + Array.Copy(this.m_elems, 0, curElems, 0, elemCnt); + } + int resCount = this.ResCount; + this.m_results[resCount] = resultSelector(curElems); + if (resCount == 1) + { + this.m_results[0] = combiner(this.m_results); + } + } + return this.m_results[0]; + } + + public void AddItem(TElement elem) + { + int elemCnt = this.ElemCount; + if (this.m_elems.Length == elemCnt) + { + if (elemCnt >= 0x800000) + { + throw new DryadLinqException(HpcLinqErrorCode.TooManyElementsBeforeReduction, + SR.TooManyElementsBeforeReduction); + } + TElement[] newElems = new TElement[elemCnt * 2]; + Array.Copy(this.m_elems, 0, newElems, 0, elemCnt); + this.m_elems = newElems; + } + this.m_elems[elemCnt] = elem; + this.m_count++; + } + + public override string ToString() + { + return "Grouping[" + this.Key + "]"; + } + } + + public class GroupingHashSet : IEnumerable> + { + private const int Ratio = 2; + private const int MaxGroupSize = 64; + + private Grouping[] m_buckets; + private IEqualityComparer m_comparer; + private long m_count; + private int m_maxGroupSize; + + public GroupingHashSet(IEqualityComparer comparer) + : this(comparer, 1024, MaxGroupSize) + { + } + + internal GroupingHashSet(IEqualityComparer comparer, int capacity) + : this(comparer, capacity, MaxGroupSize) 
+ { + } + + /* Sizes the bucket array with CollectionHelper.GetNextPrime(capacity) and stores the comparer and the per-group size limit. */ + internal GroupingHashSet(IEqualityComparer comparer, int capacity, int maxGroupSize) + { + int size = CollectionHelper.GetNextPrime(capacity); + this.m_buckets = new Grouping[size]; + this.m_comparer = comparer; + this.m_count = 0; + this.m_maxGroupSize = maxGroupSize; + } + + /* Walks the bucket chain selected by the key hash and returns the group whose key equals the given key under m_comparer, or null when the chain has no match. */ + internal Grouping GetGroup(TKey key) + { + int hashCode = this.m_comparer.GetHashCode(key); + int startIdx = (hashCode & 0x7FFFFFFF) % this.m_buckets.Length; + for (Grouping g = this.m_buckets[startIdx]; g != null; g = g.Next) + { + if (hashCode == this.m_comparer.GetHashCode(g.Key) && + this.m_comparer.Equals(key, g.Key)) + { + return g; + } + } + return null; + } + + /* Adds elem to the existing group for key, or links a new single-element group at the head of the bucket chain. The table is resized first when the group count equals the bucket count times Ratio. Returns the group that received elem. */ + public Grouping AddItem(TKey key, TElement elem) + { + int hashCode = this.m_comparer.GetHashCode(key); + int startIdx = (hashCode & 0x7FFFFFFF) % this.m_buckets.Length; + for (Grouping g = this.m_buckets[startIdx]; g != null; g = g.Next) + { + if (hashCode == this.m_comparer.GetHashCode(g.Key) && + this.m_comparer.Equals(key, g.Key)) + { + g.AddItem(elem); + return g; + } + } + + // Add a new group for the element: + if (this.m_count == (this.m_buckets.Length * Ratio)) + { + this.Resize(); + startIdx = (hashCode & 0x7FFFFFFF) % this.m_buckets.Length; + } + Grouping newGroup = new Grouping(key); + newGroup.AddItem(elem); + newGroup.Next = this.m_buckets[startIdx]; + this.m_buckets[startIdx] = newGroup; + this.m_count++; + return newGroup; + } + + /* Looks only at the head group of the bucket. When it matches key and holds fewer than m_maxGroupSize elements, elem is added to it and null is returned. Otherwise a new single-element group replaces the bucket head and the previous head (possibly null) is returned. NOTE(review): the replacement group gets no Next link, so groups chained behind the old head are no longer reachable from the table; presumably this method is never mixed with AddItem on one instance - confirm. */ + internal Grouping AddItemPartial(TKey key, TElement elem) + { + int hashCode = this.m_comparer.GetHashCode(key); + int startIdx = (hashCode & 0x7FFFFFFF) % this.m_buckets.Length; + Grouping g = this.m_buckets[startIdx]; + if (g != null && + hashCode == this.m_comparer.GetHashCode(g.Key) && + this.m_comparer.Equals(key, g.Key) && + g.Count() < this.m_maxGroupSize) + { + g.AddItem(elem); + return null; + } + Grouping newGroup = new Grouping(key); + newGroup.AddItem(elem); + this.m_buckets[startIdx] = newGroup; + if (g == null) this.m_count++; + return g; + } + + /* Moves every group into a bucket array sized by CollectionHelper.GetNextPrime(oldSize); nothing happens when that call does not return a larger size. */ + private void 
Resize() + { + int oldSize = this.m_buckets.Length; + int newSize = CollectionHelper.GetNextPrime(oldSize); + if (newSize > oldSize) + { + Grouping[] oldBuckets = this.m_buckets; + this.m_buckets = new Grouping[newSize]; + for (int i = 0; i < oldBuckets.Length; i++) + { + Grouping curGroup = oldBuckets[i]; + while (curGroup != null) + { + // Add the group: + Grouping nextGroup = curGroup.Next; + int hashCode = this.m_comparer.GetHashCode(curGroup.Key); + int startIdx = (hashCode & 0x7FFFFFFF) % newSize; + curGroup.Next = this.m_buckets[startIdx]; + this.m_buckets[startIdx] = curGroup; + curGroup = nextGroup; + } + } + } + } + + /* Yields every group, bucket by bucket, following each bucket chain through Next. */ + public IEnumerator> GetEnumerator() + { + for (int i = 0; i < this.m_buckets.Length; i++) + { + Grouping g = this.m_buckets[i]; + while (g != null) + { + yield return g; + g = g.Next; + } + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + } + + /* Chained hash table of groups that tracks the number of elements added (ElemCount) in addition to the number of groups (GroupCount). */ + internal class GroupingHashSet : IEnumerable> + { + /* AddItem resizes the table when the group count equals the bucket count times Ratio. */ + private const int Ratio = 2; + /* Per-group element limit passed to the three-argument constructor by the shorter constructors. */ + private const int MaxGroupSize = 64; + + private Grouping[] m_buckets; + private IEqualityComparer m_comparer; + private long m_count; + private long m_elemCount; + private int m_maxGroupSize; + + public GroupingHashSet(IEqualityComparer comparer) + : this(comparer, 1024, MaxGroupSize) + { + } + + internal GroupingHashSet(IEqualityComparer comparer, int capacity) + : this(comparer, capacity, MaxGroupSize) + { + } + + /* Sizes the bucket array with CollectionHelper.GetNextPrime(capacity) and zeroes both counters. */ + internal GroupingHashSet(IEqualityComparer comparer, int capacity, int maxGroupSize) + { + int size = CollectionHelper.GetNextPrime(capacity); + this.m_buckets = new Grouping[size]; + this.m_comparer = comparer; + this.m_count = 0; + this.m_elemCount = 0; + this.m_maxGroupSize = maxGroupSize; + } + + public long GroupCount + { + get { return this.m_count; } + } + + public long ElemCount + { + get { return this.m_elemCount; } + } + + /* Counts the element, then adds it to the existing group for key or to a new group linked at the head of the bucket chain (resizing first when the group count equals the bucket count times Ratio). Returns the group that received elem. */ + public Grouping AddItem(TKey key, TElement elem) + { + this.m_elemCount++; + int hashCode = this.m_comparer.GetHashCode(key); + int 
startIdx = (hashCode & 0x7FFFFFFF) % this.m_buckets.Length; + for (Grouping g = this.m_buckets[startIdx]; g != null; g = g.Next) + { + if (hashCode == this.m_comparer.GetHashCode(g.Key) && + this.m_comparer.Equals(key, g.Key)) + { + g.AddItem(elem); + return g; + } + } + + // Add a new group for the element: + if (this.m_count == (this.m_buckets.Length * Ratio)) + { + this.Resize(); + startIdx = (hashCode & 0x7FFFFFFF) % this.m_buckets.Length; + } + Grouping newGroup = new Grouping(key); + newGroup.AddItem(elem); + newGroup.Next = this.m_buckets[startIdx]; + this.m_buckets[startIdx] = newGroup; + this.m_count++; + return newGroup; + } + + /* Looks only at the head group of the bucket. When it matches key and holds fewer than m_maxGroupSize elements, elem is added and null is returned. Otherwise a new single-element group replaces the bucket head, the element count is reduced by the size of the evicted group (if any), and the evicted group (possibly null) is returned. */ + internal Grouping AddItemPartial(TKey key, TElement elem) + { + this.m_elemCount++; + int hashCode = this.m_comparer.GetHashCode(key); + int startIdx = (hashCode & 0x7FFFFFFF) % this.m_buckets.Length; + Grouping g = this.m_buckets[startIdx]; + if (g != null && + hashCode == this.m_comparer.GetHashCode(g.Key) && + this.m_comparer.Equals(key, g.Key) && + g.Count() < this.m_maxGroupSize) + { + g.AddItem(elem); + return null; + } + Grouping g1 = new Grouping(key); + g1.AddItem(elem); + this.m_buckets[startIdx] = g1; + if (g == null) + { + this.m_count++; + } + else + { + this.m_elemCount -= g.Count(); + } + return g; + } + + /* Calls Reduce(resultSelector, combiner) on every group in the table and then resets the element counter to zero. */ + internal void Reduce(Func, TResult> resultSelector, + Func, TResult> combiner) + { + for (int i = 0; i < this.m_buckets.Length; i++) + { + Grouping curGroup = this.m_buckets[i]; + while (curGroup != null) + { + curGroup.Reduce(resultSelector, combiner); + curGroup = curGroup.Next; + } + } + this.m_elemCount = 0; + } + + /* Moves every group into a bucket array sized by CollectionHelper.GetNextPrime(oldSize); nothing happens when that call does not return a larger size. */ + private void Resize() + { + int oldSize = this.m_buckets.Length; + int newSize = CollectionHelper.GetNextPrime(oldSize); + if (newSize > oldSize) + { + Grouping[] oldBuckets = this.m_buckets; + this.m_buckets = new Grouping[newSize]; + for (int i = 0; i < oldBuckets.Length; i++) + { + Grouping curGroup = oldBuckets[i]; + while (curGroup != null) + { + // Add the group: + Grouping nextGroup = 
curGroup.Next; + int hashCode = this.m_comparer.GetHashCode(curGroup.Key); + int startIdx = (hashCode & 0x7FFFFFFF) % newSize; + curGroup.Next = this.m_buckets[startIdx]; + this.m_buckets[startIdx] = curGroup; + curGroup = nextGroup; + } + } + } + } + + /* Yields every group, bucket by bucket, following each bucket chain through Next. */ + public IEnumerator> GetEnumerator() + { + for (int i = 0; i < this.m_buckets.Length; i++) + { + Grouping g = this.m_buckets[i]; + while (g != null) + { + yield return g; + g = g.Next; + } + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + } + + /* Fixed table of bucket-array sizes shared by the hash sets above. */ + internal static class CollectionHelper + { + private static readonly int[] primes = new int[] { 2053, 16411, 1048583, 8388617, 16777259, 33554467, 67108879 }; + + /* Returns the first table entry strictly greater than p, or the last table entry when no entry is greater. */ + internal static int GetNextPrime(int p) + { + int len = primes.Length; + for (int i = 0; i < len; i++) + { + if (primes[i] > p) return primes[i]; + } + return primes[len-1]; + } + } +} diff --git a/LinqToDryad/DryadLinqDecomposition.cs b/LinqToDryad/DryadLinqDecomposition.cs new file mode 100644 index 0000000..3dc8a58 --- /dev/null +++ b/LinqToDryad/DryadLinqDecomposition.cs @@ -0,0 +1,917 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// © Microsoft Corporation. All rights reserved. 
+// +using System; +using System.Collections; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Reflection; +using System.Linq; +using System.Linq.Expressions; +using System.Diagnostics; +using Microsoft.Research.DryadLinq.Internal; + +namespace Microsoft.Research.DryadLinq +{ + /* Holds the original aggregate call together with the four lambdas it was split into, exposed through read-only properties. FinalReducer may be null: several built-in cases in Decomposition.GetDecompositionInfo pass null for it. */ + internal class DecompositionInfo + { + private Expression m_func; // The original function call + private LambdaExpression m_seed; // (TSource) => TAccumulate + private LambdaExpression m_accumulator; // (TAccumulate, TSource) => TAccumulate + private LambdaExpression m_recursiveAccumulator; // (TAccumulate, TAccumulate) => TAccumulate + private LambdaExpression m_finalReducer; // (TAccumulate) => TResult + + public DecompositionInfo(Expression func, + LambdaExpression seed, + LambdaExpression accumulator, + LambdaExpression recursiveAccumulator, + LambdaExpression finalReducer) + { + this.m_func = func; + this.m_seed = seed; + this.m_accumulator = accumulator; + this.m_recursiveAccumulator = recursiveAccumulator; + this.m_finalReducer = finalReducer; + } + + public Expression Func + { + get { return this.m_func; } + } + + public LambdaExpression Seed + { + get { return this.m_seed; } + } + + public LambdaExpression Accumulator + { + get { return this.m_accumulator; } + } + + public LambdaExpression RecursiveAccumulator + { + get { return this.m_recursiveAccumulator; } + } + + public LambdaExpression FinalReducer + { + get { return this.m_finalReducer; } + } + } + + /* Static helpers that split the aggregate calls found in a group result selector into DecompositionInfo entries and that compose the resulting accumulator lambdas. */ + internal class Decomposition + { + /* Entry point. resultSelectExpr takes either one parameter (the group) or two (the key, then the group). Returns the DecompositionInfo list collected from its body, or null when some sub-expression cannot be decomposed. */ + internal static List + GetDecompositionInfoList(LambdaExpression resultSelectExpr, HpcLinqCodeGen codeGen) + { + ParameterExpression keyParam; + ParameterExpression groupParam; + if (resultSelectExpr.Parameters.Count == 1) + { + keyParam = null; + groupParam = resultSelectExpr.Parameters[0]; + } + else + { + Debug.Assert(resultSelectExpr.Parameters.Count == 2); + keyParam = resultSelectExpr.Parameters[0]; + groupParam = resultSelectExpr.Parameters[1]; + } + 
+ List infoList = new List(1); + bool isDecomposed = GetDecompositionInfoList(keyParam, groupParam, + resultSelectExpr.Body, + infoList, codeGen); + if (isDecomposed) + { + return infoList; + } + return null; + } + + /* Decomposes the expressions held by one member binding: the assigned expression of a MemberAssignment, the nested bindings of a MemberMemberBinding, or the initializer arguments of a MemberListBinding. Returns false as soon as any nested expression cannot be decomposed. */ + private static bool GetDecompositionInfoList(ParameterExpression keyParam, + ParameterExpression groupParam, + MemberBinding mbinding, + List infoList, + HpcLinqCodeGen codeGen) + { + if (mbinding is MemberAssignment) + { + Expression expr = ((MemberAssignment)mbinding).Expression; + return GetDecompositionInfoList(keyParam, groupParam, expr, infoList, codeGen); + } + else if (mbinding is MemberMemberBinding) + { + foreach (MemberBinding mb in ((MemberMemberBinding)mbinding).Bindings) + { + bool isDecomposed = GetDecompositionInfoList(keyParam, groupParam, mb, infoList, codeGen); + if (!isDecomposed) return false; + } + } + else if (mbinding is MemberListBinding) + { + foreach (ElementInit ei in ((MemberListBinding)mbinding).Initializers) + { + foreach (Expression arg in ei.Arguments) + { + bool isDecomposed = GetDecompositionInfoList(keyParam, groupParam, arg, infoList, codeGen); + if (!isDecomposed) return false; + } + } + } + return true; + } + + /* Recursively walks expr. Constants are accepted as they are; a method call that GetDecompositionInfo recognises is appended to infoList; composite expressions are accepted only when every operand is. Any other expression is accepted only when it is a member-access chain rooted at the key: groupParam.Key when keyParam is null, otherwise keyParam itself. Returns true when the whole expression is decomposable. */ + private static bool GetDecompositionInfoList(ParameterExpression keyParam, + ParameterExpression groupParam, + Expression expr, + List infoList, + HpcLinqCodeGen codeGen) + { + IEnumerable argList = null; + if (HpcLinqExpression.IsConstant(expr)) + { + return true; + } + else if (expr is BinaryExpression) + { + BinaryExpression be = (BinaryExpression)expr; + argList = new Expression[] { be.Left, be.Right }; + } + else if (expr is UnaryExpression) + { + UnaryExpression ue = (UnaryExpression)expr; + return GetDecompositionInfoList(keyParam, groupParam, ue.Operand, infoList, codeGen); + } + else if (expr is ConditionalExpression) + { + ConditionalExpression ce = (ConditionalExpression)expr; + argList = new Expression[] { ce.Test, ce.IfTrue, ce.IfFalse }; + } + else if (expr is MethodCallExpression) + { + 
MethodCallExpression mcExpr = (MethodCallExpression)expr; + DecompositionInfo dinfo = GetDecompositionInfo(groupParam, mcExpr, codeGen); + if (dinfo != null) + { + infoList.Add(dinfo); + return true; + } + if (mcExpr.Object != null) + { + bool isDecomposed = GetDecompositionInfoList(keyParam, groupParam, + mcExpr.Object, + infoList, codeGen); + if (!isDecomposed) return false; + } + argList = mcExpr.Arguments; + } + else if (expr is NewExpression) + { + argList = ((NewExpression)expr).Arguments; + } + else if (expr is NewArrayExpression) + { + argList = ((NewArrayExpression)expr).Expressions; + } + else if (expr is ListInitExpression) + { + ListInitExpression li = (ListInitExpression)expr; + bool isDecomposed = GetDecompositionInfoList(keyParam, groupParam, + li.NewExpression, + infoList, codeGen); + /* The constructor call must itself be decomposable, exactly as in the MemberInitExpression branch below. */ + if (!isDecomposed) return false; + for (int i = 0, n = li.Initializers.Count; i < n; i++) + { + ElementInit ei = li.Initializers[i]; + foreach (Expression arg in ei.Arguments) + { + isDecomposed = GetDecompositionInfoList(keyParam, groupParam, arg, infoList, codeGen); + if (!isDecomposed) return false; + } + } + return true; + } + else if (expr is MemberInitExpression) + { + MemberInitExpression mi = (MemberInitExpression)expr; + bool isDecomposed = GetDecompositionInfoList(keyParam, groupParam, + mi.NewExpression, + infoList, codeGen); + if (!isDecomposed) return false; + foreach (MemberBinding mb in mi.Bindings) + { + isDecomposed = GetDecompositionInfoList(keyParam, groupParam, mb, infoList, codeGen); + if (!isDecomposed) return false; + } + return true; + } + else if (keyParam == null) + { + /* One-parameter selector: only a member chain that passes through groupParam.Key is allowed. */ + while (expr is MemberExpression) + { + MemberExpression me = (MemberExpression)expr; + if (me.Expression == groupParam && + me.Member.Name == "Key") + { + return true; + } + expr = me.Expression; + } + return false; + } + else + { + /* Two-parameter selector: only a member chain rooted at keyParam is allowed. */ + while (expr is MemberExpression) + { + expr = ((MemberExpression)expr).Expression; + } + return (expr == keyParam); + } + + foreach (var argExpr in argList) + { + bool isDecomposed = 
GetDecompositionInfoList(keyParam, groupParam, argExpr, infoList, codeGen); + if (!isDecomposed) return false; + } + return true; + } + + /* Builds the DecompositionInfo for one method call whose first argument is the group parameter itself and whose other arguments do not mention the group. Returns null when the call does not have that shape or is not a decomposable operator. */ + private static DecompositionInfo GetDecompositionInfo(ParameterExpression groupParam, + MethodCallExpression mcExpr, + HpcLinqCodeGen codeGen) + { + if (mcExpr.Arguments.Count == 0 || mcExpr.Arguments[0] != groupParam) + { + return null; + } + for (int i = 1; i < mcExpr.Arguments.Count; i++) + { + if (HpcLinqExpression.Contains(groupParam, mcExpr.Arguments[i])) + { + return null; + } + } + + ExpressionSimplifier evaluator = new ExpressionSimplifier(); + Type[] paramTypeArgs = groupParam.Type.GetGenericArguments(); + Type sourceElemType = paramTypeArgs[paramTypeArgs.Length - 1]; + Type resultType = mcExpr.Type; + Type decomposerType = null; + + /* A Decomposable attribute on the called method takes precedence over the built-in operator table. */ + DecomposableAttribute decomposableAttrib = AttributeSystem.GetDecomposableAttrib(mcExpr); + if (decomposableAttrib != null) + { + decomposerType = decomposableAttrib.DecompositionType; + } + else + { + MethodInfo mInfo = mcExpr.Method; + if (mInfo.DeclaringType == typeof(System.Linq.Enumerable) || + mInfo.DeclaringType == typeof(System.Linq.Queryable)) + { + // For built-in decomposable operators. 
+ switch (mInfo.Name) + { + case "Count": + case "LongCount": + { + /* Seed and accumulate add 1 per element (or 1 only when the predicate holds, for the predicate overload); partial counts are combined by addition. */ + Type outputType; + Expression body; + Expression zero; + if (mInfo.Name == "Count") + { + outputType = typeof(Int32); + body = Expression.Constant(1, outputType); + zero = Expression.Constant(0, outputType); + } + else + { + outputType = typeof(Int64); + body = Expression.Constant((long)1, outputType); + zero = Expression.Constant((long)0, outputType); + } + ParameterExpression param1 = Expression.Parameter(outputType, "a"); + ParameterExpression param2; + if (mcExpr.Arguments.Count == 1) + { + param2 = Expression.Parameter(sourceElemType, "e"); + } + else + { + /* Count(predicate): an element contributes 1 only when the predicate holds, otherwise 0. */ + LambdaExpression predExpr = HpcLinqExpression.GetLambda(mcExpr.Arguments[1]); + param2 = predExpr.Parameters[0]; + body = Expression.Condition(predExpr.Body, body, zero); + } + LambdaExpression seedExpr = Expression.Lambda(body, param2); + body = Expression.AddChecked(param1, body); + LambdaExpression accumulateExpr = Expression.Lambda(body, param1, param2); + param2 = Expression.Parameter(outputType, "b"); + body = Expression.AddChecked(param1, param2); + LambdaExpression recursiveAccumulateExpr = Expression.Lambda(body, param1, param2); + + return new DecompositionInfo(mcExpr, seedExpr, accumulateExpr, recursiveAccumulateExpr, null); + } + case "Any": + { + /* The accumulator is a bool that is OR-ed with the predicate result (constant true for the overload without a predicate). */ + ParameterExpression param1 = Expression.Parameter(typeof(bool), "a"); + ParameterExpression param2; + Expression body; + if (mcExpr.Arguments.Count == 1) + { + param2 = Expression.Parameter(sourceElemType, "e"); + body = Expression.Constant(true, typeof(bool)); + } + else + { + LambdaExpression predExpr = HpcLinqExpression.GetLambda(mcExpr.Arguments[1]); + param2 = predExpr.Parameters[0]; + body = predExpr.Body; + } + + LambdaExpression seedExpr = Expression.Lambda(body, param2); + LambdaExpression accumulateExpr = Expression.Lambda(Expression.Or(param1, body), param1, param2); + param2 = Expression.Parameter(typeof(bool), "b"); + body = Expression.Or(param1, param2); + LambdaExpression recursiveAccumulateExpr = Expression.Lambda(body, param1, param2); + + return new DecompositionInfo(mcExpr, seedExpr, accumulateExpr, recursiveAccumulateExpr, null); + } + case "All": + { + /* The accumulator is a bool that is AND-ed with the predicate result. */ + ParameterExpression param1 = Expression.Parameter(typeof(bool), "a"); + LambdaExpression predExpr = HpcLinqExpression.GetLambda(mcExpr.Arguments[1]); + 
ParameterExpression param2 = predExpr.Parameters[0]; + + Expression body = predExpr.Body; + LambdaExpression seedExpr = Expression.Lambda(body, param2); + LambdaExpression accumulateExpr = Expression.Lambda(Expression.And(param1, body), param1, param2); + param2 = Expression.Parameter(typeof(bool), "b"); + body = Expression.And(param1, param2); + LambdaExpression recursiveAccumulateExpr = Expression.Lambda(body, param1, param2); + + return new DecompositionInfo(mcExpr, seedExpr, accumulateExpr, recursiveAccumulateExpr, null); + } + case "First": + { + /* The accumulator below always keeps its first argument. First(predicate) would need a no-match-yet state that a single TSource accumulator cannot carry, so that overload is left undecomposed instead of silently ignoring the predicate. */ + if (mcExpr.Arguments.Count != 1) return null; + ParameterExpression param1 = Expression.Parameter(sourceElemType, "a"); + ParameterExpression param2 = Expression.Parameter(sourceElemType, "e"); + + LambdaExpression seedExpr = Expression.Lambda(param2, param2); + LambdaExpression accumulateExpr = Expression.Lambda(param1, param1, param2); + LambdaExpression recursiveAccumulateExpr = accumulateExpr; + + return new DecompositionInfo(mcExpr, seedExpr, accumulateExpr, recursiveAccumulateExpr, null); + } + case "Last": + { + /* The accumulator below always keeps its second argument. Last(predicate) is left undecomposed for the same reason as First(predicate). */ + if (mcExpr.Arguments.Count != 1) return null; + ParameterExpression param1 = Expression.Parameter(sourceElemType, "a"); + ParameterExpression param2 = Expression.Parameter(sourceElemType, "e"); + + LambdaExpression seedExpr = Expression.Lambda(param2, param2); + LambdaExpression accumulateExpr = Expression.Lambda(param2, param1, param2); + LambdaExpression recursiveAccumulateExpr = accumulateExpr; + + return new DecompositionInfo(mcExpr, seedExpr, accumulateExpr, recursiveAccumulateExpr, null); + } + case "Sum": + { + /* arg2 is the summed value: the element itself, or the body of the selector lambda. A generic (nullable) value type is accumulated in its underlying type through HpcLinqVertex.SumAccumulate and converted back by the final reducer. */ + ParameterExpression param1; + ParameterExpression param2; + Expression arg2; + if (mInfo.GetParameters().Length == 1) + { + param2 = Expression.Parameter(sourceElemType, "e"); + arg2 = param2; + } + else + { + LambdaExpression selectExpr = HpcLinqExpression.GetLambda(mcExpr.Arguments[1]); + param2 = selectExpr.Parameters[0]; + arg2 = selectExpr.Body; + } + + Expression abody, sbody; + if (arg2.Type.IsGenericType) + { + param1 = Expression.Parameter(arg2.Type.GetGenericArguments()[0], 
"a"); + MethodInfo accumulateInfo = typeof(HpcLinqVertex).GetMethod( + "SumAccumulate", + new Type[] { param1.Type, arg2.Type }); + /* Expression.Constant rejects a boxed Int32 for any other numeric type, so build the zero seed in the accumulator type itself. */ + sbody = Expression.Constant(System.Convert.ChangeType(0, param1.Type), param1.Type); + sbody = Expression.Call(accumulateInfo, sbody, arg2); + abody = Expression.Call(accumulateInfo, param1, arg2); + } + else + { + param1 = Expression.Parameter(arg2.Type, "a"); + sbody = arg2; + abody = Expression.AddChecked(param1, arg2); + } + + LambdaExpression seedExpr = Expression.Lambda(sbody, param2); + LambdaExpression accumulateExpr = Expression.Lambda(abody, param1, param2); + param2 = Expression.Parameter(param1.Type, "b"); + Expression rbody = Expression.AddChecked(param1, param2); + LambdaExpression recursiveAccumulateExpr = Expression.Lambda(rbody, param1, param2); + Expression fbody = Expression.Convert(param1, arg2.Type); + LambdaExpression finalReduceExpr = Expression.Lambda(fbody, param1); + return new DecompositionInfo(mcExpr, seedExpr, accumulateExpr, + recursiveAccumulateExpr, finalReduceExpr); + } + case "Max": + case "Min": + { + /* The seed is the (selected) element value; accumulation and combination both call the matching HpcLinqVertex MaxAccumulate or MinAccumulate method, using the Generic variant for the generic overload without a selector. */ + ParameterExpression param2; + Expression abody; + if (mInfo.GetParameters().Length == 1) + { + param2 = Expression.Parameter(sourceElemType, "e"); + abody = param2; + } + else + { + LambdaExpression selectExpr = HpcLinqExpression.GetLambda(mcExpr.Arguments[1]); + param2 = selectExpr.Parameters[0]; + abody = selectExpr.Body; + } + + ParameterExpression param1 = Expression.Parameter(abody.Type, "a"); + Expression sbody = abody; + MethodInfo accumulateInfo; + string methodName = (mInfo.Name == "Max") ? 
"MaxAccumulate" : "MinAccumulate"; + if (mInfo.IsGenericMethod && (mInfo.GetParameters().Length == 1)) + { + accumulateInfo = typeof(HpcLinqVertex).GetMethod(methodName + "Generic"); + accumulateInfo = accumulateInfo.MakeGenericMethod(sourceElemType); + } + else + { + accumulateInfo = typeof(HpcLinqVertex).GetMethod( + methodName, + new Type[] { param1.Type, abody.Type }); + } + abody = Expression.Call(accumulateInfo, param1, abody); + + LambdaExpression seedExpr = Expression.Lambda(sbody, param2); + LambdaExpression accumulateExpr = Expression.Lambda(abody, param1, param2); + param2 = Expression.Parameter(param1.Type, "b"); + Expression rbody = Expression.Call(accumulateInfo, param1, param2); + LambdaExpression recursiveAccumulateExpr = Expression.Lambda(rbody, param1, param2); + return new DecompositionInfo(mcExpr, seedExpr, accumulateExpr, recursiveAccumulateExpr, null); + } + case "Aggregate": + { + /* Without an explicit seed the first element is the seed. With one, the seed value is evaluated up front and substituted into the accumulator body. Only accumulators that HpcLinqExpression.IsAssociative accepts are decomposed. */ + ParameterExpression elemParam = Expression.Parameter(sourceElemType, "e"); + LambdaExpression accumulateExpr; + LambdaExpression seedExpr; + if (mcExpr.Arguments.Count == 2) + { + accumulateExpr = HpcLinqExpression.GetLambda(mcExpr.Arguments[1]); + seedExpr = Expression.Lambda(elemParam, elemParam); + } + else + { + accumulateExpr = HpcLinqExpression.GetLambda(mcExpr.Arguments[2]); + object seedVal = evaluator.Eval(mcExpr.Arguments[1]); + Expression body = Expression.Constant(seedVal, seedVal.GetType()); + ParameterSubst subst = new ParameterSubst(accumulateExpr.Parameters[0], body); + body = subst.Visit(accumulateExpr.Body); + seedExpr = Expression.Lambda(body, accumulateExpr.Parameters[1]); + } + if (!HpcLinqExpression.IsAssociative(accumulateExpr)) + { + return null; + } + LambdaExpression recursiveAccumulateExpr = HpcLinqExpression.GetAssociativeCombiner(accumulateExpr); + /* Aggregate(seed, func, resultSelector): use the result selector as the final reducer instead of dropping it. */ + LambdaExpression finalReduceExpr = null; + if (mcExpr.Arguments.Count == 4) + { + finalReduceExpr = HpcLinqExpression.GetLambda(mcExpr.Arguments[3]); + } + return new DecompositionInfo(mcExpr, seedExpr, accumulateExpr, recursiveAccumulateExpr, finalReduceExpr); + } + case "Average": + { + ParameterExpression param2; + Expression abody; + if 
(mInfo.GetParameters().Length == 1) + { + param2 = Expression.Parameter(sourceElemType, "e"); + abody = param2; + } + else + { + LambdaExpression selectExpr = HpcLinqExpression.GetLambda(mcExpr.Arguments[1]); + param2 = selectExpr.Parameters[0]; + abody = selectExpr.Body; + } + /* Pick the type the running sum is kept in: int and long values (nullable or not) are summed as long, float and double as double, decimal as decimal. */ + Type aggValueType = abody.Type; + if (aggValueType == typeof(int) || + aggValueType == typeof(int?)) + { + aggValueType = typeof(long); + } + else if (aggValueType == typeof(long?)) + { + aggValueType = typeof(long); + } + else if (aggValueType == typeof(float) || + aggValueType == typeof(float?)) + { + aggValueType = typeof(double); + } + else if (aggValueType == typeof(double?)) + { + aggValueType = typeof(double); + } + else if (aggValueType == typeof(decimal?)) + { + aggValueType = typeof(decimal); + } + + Type sumAndCountType = typeof(AggregateValue<>).MakeGenericType(aggValueType); + ParameterExpression param1 = Expression.Parameter(sumAndCountType, "a"); + MethodInfo accumulateInfo = typeof(HpcLinqVertex).GetMethod( + "AverageAccumulate", + new Type[] { sumAndCountType, abody.Type }); + + // Seed: + Expression sbody = Expression.New(sumAndCountType); + sbody = Expression.Call(accumulateInfo, sbody, abody); + LambdaExpression seedExpr = Expression.Lambda(sbody, param2); + + // Accumulate: + abody = Expression.Call(accumulateInfo, param1, abody); + LambdaExpression accumulateExpr = Expression.Lambda(abody, param1, param2); + + // RecursiveAccumulate: + param2 = Expression.Parameter(param1.Type, "b"); + PropertyInfo valueInfo = sumAndCountType.GetProperty("Value"); + PropertyInfo countInfo = sumAndCountType.GetProperty("Count"); + Expression sumExpr1 = Expression.Property(param1, valueInfo); + Expression countExpr1 = Expression.Property(param1, countInfo); + Expression sumExpr2 = Expression.Property(param2, valueInfo); + Expression countExpr2 = Expression.Property(param2, countInfo); + Expression sumExpr = Expression.AddChecked(sumExpr1, sumExpr2); + Expression countExpr = 
Expression.AddChecked(countExpr1, countExpr2); + ConstructorInfo cinfo = sumAndCountType.GetConstructor(new Type[] { sumExpr.Type, countExpr.Type }); + Expression rbody = Expression.New(cinfo, sumExpr, countExpr); + LambdaExpression recursiveAccumulateExpr = Expression.Lambda(rbody, param1, param2); + + // FinalReduce: + if (sumExpr1.Type == typeof(long)) + { + sumExpr1 = Expression.Convert(sumExpr1, typeof(double)); + } + /* Expression.Divide requires both operands to have the same type, so convert the count to the sum type when they differ. */ + Expression divisorExpr = countExpr1; + if (divisorExpr.Type != sumExpr1.Type) + { + divisorExpr = Expression.Convert(divisorExpr, sumExpr1.Type); + } + Expression fbody = Expression.Divide(sumExpr1, divisorExpr); + fbody = Expression.Convert(fbody, resultType); + if (resultType.IsGenericType) + { + /* Nullable result: yield null when no value was counted. The zero constant is built in the type of the Count property, because Expression.Constant rejects a boxed Int32 for any other type. */ + Expression zeroExpr = Expression.Constant(System.Convert.ChangeType(0, countExpr1.Type), countExpr1.Type); + Expression condExpr = Expression.GreaterThan(countExpr1, zeroExpr); + Expression nullExpr = Expression.Constant(null, resultType); + fbody = Expression.Condition(condExpr, fbody, nullExpr); + } + LambdaExpression finalReduceExpr = Expression.Lambda(fbody, param1); + return new DecompositionInfo(mcExpr, seedExpr, accumulateExpr, recursiveAccumulateExpr, finalReduceExpr); + } + case "Contains": + { + decomposerType = typeof(ContainsDecomposition<>).MakeGenericType(sourceElemType); + break; + } + case "Distinct": + { + decomposerType = typeof(DistinctDecomposition<>).MakeGenericType(sourceElemType); + break; + } + default: + { + return null; + } + } + } + } + + if (decomposerType == null) return null; + + /* Locate the single IDecomposable<,,> interface implemented by the decomposer type. */ + Type implementedInterface = null; + Type[] interfaces = decomposerType.GetInterfaces(); + foreach (Type intf in interfaces) + { + /* GetGenericTypeDefinition throws on a non-generic interface, so test IsGenericType first. */ + if (intf.IsGenericType && intf.GetGenericTypeDefinition() == typeof(IDecomposable<,,>)) + { + if (implementedInterface != null) + { + throw new DryadLinqException("Decomposition class can implement only one decomposable interface."); + } + implementedInterface = intf; + } + } + + if (implementedInterface == null || + implementedInterface.GetGenericArguments().Length != 3) + { + throw new DryadLinqException("Decomposition class " + decomposerType.FullName + + " must implement IDecomposable<,,>"); + } + + // The second type of the 
implemented interface definition is the accumulatorType. + Type accumulatorType = implementedInterface.GetGenericArguments()[1]; + + // Now check that all the types match up. + Type decomposerInterface = typeof(IDecomposable<,,>).MakeGenericType( + sourceElemType, accumulatorType, resultType); + if (!decomposerInterface.IsAssignableFrom(decomposerType)) + { + throw new DryadLinqException("Decomposition class must match the function that it decorates."); + } + /* NOTE(review): the assignability check above runs before an open generic decomposerType is closed over sourceElemType here - confirm that an open generic decomposer can actually reach this branch. */ + if (decomposerType.ContainsGenericParameters) + { + if (decomposerType.GetGenericArguments().Length != 1 || + !decomposerType.GetGenericArguments()[0].IsGenericParameter) + { + throw new DryadLinqException(decomposerType.Name + " must match the function it annotates."); + } + decomposerType = decomposerType.MakeGenericType(sourceElemType); + } + if (decomposerType.GetConstructor(Type.EmptyTypes) == null) + { + throw new DryadLinqException("Decomposition class must have a default constructor."); + } + + // Add to the codegen a call of the static Initializer of decomposerType + /* Every call argument after the source is converted to object and packed into an object[] expression that is handed to codeGen.AddDecompositionInitializer; the name it returns becomes the name of the decomposer parameter used by the lambdas below. */ + Expression[] args = new Expression[mcExpr.Arguments.Count-1]; + for (int i = 0; i < args.Length; i++) + { + args[i] = Expression.Convert(mcExpr.Arguments[i+1], typeof(object)); + } + Expression stateExpr = Expression.NewArrayInit(typeof(object), args); + string decomposerName = codeGen.AddDecompositionInitializer(decomposerType, stateExpr); + ParameterExpression decomposer = Expression.Parameter(decomposerType, decomposerName); + + // Seed: TSource => TAccumulate + MethodInfo seedInfo1 = decomposerType.GetMethod("Seed"); + ParameterExpression p2 = Expression.Parameter(sourceElemType, "e"); + Expression sbody1 = Expression.Call(decomposer, seedInfo1, p2); + LambdaExpression seedExpr1 = Expression.Lambda(sbody1, p2); + + // Accumulate: (TAccumulate, TSource) => TAccumulate + MethodInfo accumulateInfo1 = decomposerType.GetMethod("Accumulate"); + ParameterExpression p1 = Expression.Parameter(accumulatorType, "a"); + Expression abody1 = 
Expression.Call(decomposer, accumulateInfo1, p1, p2); + LambdaExpression accumulateExpr1 = Expression.Lambda(abody1, p1, p2); + + // RecursiveAccumulate: (TAccumulate, TAccumulate) => TAccumulate + MethodInfo recursiveAccumulateInfo1 = decomposerType.GetMethod("RecursiveAccumulate"); + p2 = Expression.Parameter(accumulatorType, "e"); + Expression rbody1 = Expression.Call(decomposer, recursiveAccumulateInfo1, p1, p2); + LambdaExpression recursiveAccumulateExpr1 = Expression.Lambda(rbody1, p1, p2); + + // FinalReduce: TAccumulate => TResult + MethodInfo finalReduceInfo1 = decomposerType.GetMethod("FinalReduce"); + Expression fbody1 = Expression.Call(decomposer, finalReduceInfo1, p1); + LambdaExpression finalReduceExpr1 = Expression.Lambda(fbody1, p1); + + return new DecompositionInfo(mcExpr, seedExpr1, accumulateExpr1, recursiveAccumulateExpr1, finalReduceExpr1); + } + + /* Substitutes valueExpr and elemParam for the two parameters of the accumulator lambda of dInfoList[idx]. For the last list entry the substituted body is returned directly. Otherwise valueExpr is read as a pair: its Key property feeds this accumulator, its Value property is passed to the recursive call for idx + 1, and both results are wrapped in a new Pair. */ + // Precondition: idx < dInfoList.Count + internal static Expression AccumulateList(Expression valueExpr, + ParameterExpression elemParam, + List dInfoList, + int idx) + { + LambdaExpression accumulateExpr = dInfoList[idx].Accumulator; + if (dInfoList.Count == idx + 1) + { + ParameterSubst subst = new ParameterSubst(accumulateExpr.Parameters[0], valueExpr); + Expression resultExpr = subst.Visit(accumulateExpr.Body); + subst = new ParameterSubst(accumulateExpr.Parameters[1], elemParam); + return subst.Visit(resultExpr); + } + else + { + PropertyInfo keyPropInfo = valueExpr.Type.GetProperty("Key"); + Expression keyValueExpr = Expression.Property(valueExpr, keyPropInfo); + ParameterSubst subst = new ParameterSubst(accumulateExpr.Parameters[0], keyValueExpr); + Expression expr1 = subst.Visit(accumulateExpr.Body); + subst = new ParameterSubst(accumulateExpr.Parameters[1], elemParam); + expr1 = subst.Visit(expr1); + + PropertyInfo valuePropInfo = valueExpr.Type.GetProperty("Value"); + Expression valueValueExpr = Expression.Property(valueExpr, valuePropInfo); + Expression expr2 = 
AccumulateList(valueValueExpr, elemParam, dInfoList, idx + 1); + + /* The pair is built with the first constructor that reflection reports for the closed Pair type. */ + Type pairType = typeof(Pair<,>).MakeGenericType(expr1.Type, expr2.Type); + return Expression.New(pairType.GetConstructors()[0], expr1, expr2); + } + } + + /* Substitutes valueExpr1 and valueExpr2 for the two parameters of the recursive accumulator lambda of dInfoList[idx]. For the last list entry the substituted body is returned directly. Otherwise both values are read as pairs: their Key properties feed this combiner, their Value properties are passed to the recursive call for idx + 1, and both results are wrapped in a new Pair. */ + // Precondition: idx < dInfoList.Count + internal static Expression RecursiveAccumulateList(Expression valueExpr1, + Expression valueExpr2, + List dInfoList, + int idx) + { + LambdaExpression recursiveAccumulateExpr = dInfoList[idx].RecursiveAccumulator; + if (dInfoList.Count == idx + 1) + { + ParameterSubst subst = new ParameterSubst(recursiveAccumulateExpr.Parameters[0], valueExpr1); + Expression resultExpr = subst.Visit(recursiveAccumulateExpr.Body); + subst = new ParameterSubst(recursiveAccumulateExpr.Parameters[1], valueExpr2); + return subst.Visit(resultExpr); + } + else + { + PropertyInfo keyPropInfo1 = valueExpr1.Type.GetProperty("Key"); + Expression keyValueExpr1 = Expression.Property(valueExpr1, keyPropInfo1); + PropertyInfo keyPropInfo2 = valueExpr2.Type.GetProperty("Key"); + Expression keyValueExpr2 = Expression.Property(valueExpr2, keyPropInfo2); + ParameterSubst subst = new ParameterSubst(recursiveAccumulateExpr.Parameters[0], keyValueExpr1); + Expression expr1 = subst.Visit(recursiveAccumulateExpr.Body); + subst = new ParameterSubst(recursiveAccumulateExpr.Parameters[1], keyValueExpr2); + expr1 = subst.Visit(expr1); + + PropertyInfo valuePropInfo1 = valueExpr1.Type.GetProperty("Value"); + Expression valueValueExpr1 = Expression.Property(valueExpr1, valuePropInfo1); + PropertyInfo valuePropInfo2 = valueExpr2.Type.GetProperty("Value"); + Expression valueValueExpr2 = Expression.Property(valueExpr2, valuePropInfo2); + Expression expr2 = RecursiveAccumulateList(valueValueExpr1, valueValueExpr2, dInfoList, idx + 1); + + Type pairType = typeof(Pair<,>).MakeGenericType(expr1.Type, expr2.Type); + return Expression.New(pairType.GetConstructors()[0], expr1, expr2); + } + } + } + + /* Decomposer that the Contains case of Decomposition.GetDecompositionInfo selects for the built-in Contains operator. */ + public class ContainsDecomposition : IDecomposable + { + 
private TSource m_value; + private IEqualityComparer m_comparer; + + public void Initialize(object state) + { + object[] args = state as object[]; + this.m_value = (TSource)args[0]; + if (args.Length > 1) + { + this.m_comparer = (IEqualityComparer)args[1]; + } + else + { + this.m_comparer = EqualityComparer.Default; + } + } + + public bool Seed(TSource val) + { + return this.m_comparer.Equals(this.m_value, val); + } + + public bool Accumulate(bool acc, TSource val) + { + return acc || this.m_comparer.Equals(this.m_value, val); + } + + public bool RecursiveAccumulate(bool acc, bool val) + { + return acc || val; + } + + public bool FinalReduce(bool val) + { + return val; + } + } + + public class DistinctDecomposition + : IDecomposable, IEnumerable> + { + private IEqualityComparer m_comparer; + + public void Initialize(object state) + { + object[] args = state as object[]; + if (args.Length > 0) + { + this.m_comparer = (IEqualityComparer)args[0]; + } + else + { + this.m_comparer = EqualityComparer.Default; + } + } + + public DistinctSet Seed(TSource val) + { + DistinctSet set = new DistinctSet(); + set.Add(val, this.m_comparer); + return set; + } + + public DistinctSet Accumulate(DistinctSet acc, TSource val) + { + acc.Add(val, this.m_comparer); + return acc; + } + + public DistinctSet RecursiveAccumulate(DistinctSet acc, + DistinctSet val) + { + foreach (TSource x in val.GetElems(this.m_comparer)) + { + acc.Add(x, this.m_comparer); + } + return acc; + } + + public IEnumerable FinalReduce(DistinctSet val) + { + return val.ToArray(this.m_comparer); + } + } + + public class DistinctSet + { + private const Int32 MaxCount = 32; + private static readonly TSource[] Empty = new TSource[0]; + + private TSource[] m_distinctElems; + private TSource[] m_elems; + private Int32 m_count; + + public DistinctSet() + { + this.m_distinctElems = Empty; + this.m_elems = new TSource[1]; + this.m_count = 0; + } + + public void Add(TSource elem, IEqualityComparer comparer) + { + if 
(this.m_count == this.m_elems.Length) + { + if (this.m_count < MaxCount) + { + TSource[] newElems = new TSource[this.m_count * 2]; + Array.Copy(this.m_elems, 0, newElems, 0, this.m_count); + this.m_elems = newElems; + } + else + { + this.m_distinctElems = this.ToArray(comparer); + this.m_elems = new TSource[2]; + this.m_count = 0; + } + } + this.m_elems[this.m_count++] = elem; + } + + public IEnumerable GetElems(IEqualityComparer comparer) + { + HashSet set = new HashSet(comparer); + for (int i = 0; i < this.m_count; i++) + { + if (set.Add(this.m_elems[i])) + { + yield return this.m_elems[i]; + } + } + foreach (var elem in this.m_distinctElems) + { + if (!set.Contains(elem)) + { + yield return elem; + } + } + } + + public TSource[] ToArray(IEqualityComparer comparer) + { + HashSet set = new HashSet(comparer); + for (int i = 0; i < this.m_count; i++) + { + set.Add(this.m_elems[i]); + } + Int32 idx = 0; + for (int i = 0; i < this.m_distinctElems.Length; i++) + { + if (!set.Contains(this.m_distinctElems[i])) + { + this.m_distinctElems[idx++] = this.m_distinctElems[i]; + } + } + TSource[] distinctElems = new TSource[idx + set.Count]; + Array.Copy(this.m_distinctElems, 0, distinctElems, 0, idx); + foreach (var x in set) + { + distinctElems[idx++] = x; + } + return distinctElems; + } + } +} diff --git a/LinqToDryad/DryadLinqException.cs b/LinqToDryad/DryadLinqException.cs new file mode 100644 index 0000000..236eeee --- /dev/null +++ b/LinqToDryad/DryadLinqException.cs @@ -0,0 +1,96 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. +// +using System; +using System.Collections.Generic; +using System.Text; +using System.Linq.Expressions; +using System.Runtime.Serialization; +using Microsoft.Research.DryadLinq.Internal; + +namespace Microsoft.Research.DryadLinq +{ + /// + /// + /// + [Serializable] + public class DryadLinqException : Exception + { + private int m_errorCode; + + public DryadLinqException(string message) : base(message) + { + } + + public DryadLinqException(string message, Exception inner) : base(message, inner) + { + } + + protected DryadLinqException(SerializationInfo info, StreamingContext context) + : base(info, context) + { + if (info != null) + { + this.m_errorCode = int.Parse(info.GetString("ErrorCode")); + } + } + + internal DryadLinqException(int errorCode, string message) + : base(message) + { + m_errorCode = errorCode; + } + + + internal DryadLinqException(int errorCode, string message, Exception innerException) + : base(message, innerException) + { + m_errorCode = errorCode; + } + + /// + /// Exception's error code. Maps to values in HpcLinqErrorCode. 
+ /// + public int ErrorCode { get { return m_errorCode; } } + + internal static Exception Create(int errorCode, string msg, Expression expr) + { + StringBuilder sb = new StringBuilder(); + sb.Append(msg); + sb.Append(" Expression : "); + sb.AppendLine(HpcLinqExpression.Summarize(expr, 1)); + + return new DryadLinqException(errorCode, sb.ToString()); + } + + public override void GetObjectData(SerializationInfo info, StreamingContext context) + { + base.GetObjectData(info, context); + + if (info != null) + { + info.AddValue("ErrorCode", this.m_errorCode); + } + } + } +} diff --git a/LinqToDryad/DryadLinqExpression.cs b/LinqToDryad/DryadLinqExpression.cs new file mode 100644 index 0000000..3cbee41 --- /dev/null +++ b/LinqToDryad/DryadLinqExpression.cs @@ -0,0 +1,1774 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. 
+// +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Text; +using System.IO; +using System.Reflection; +using System.Linq; +using System.Linq.Expressions; +using System.Diagnostics; +using System.Text.RegularExpressions; +using System.Runtime.Serialization.Formatters.Binary; +using System.Security.Cryptography; +using System.Runtime.Serialization; +using Microsoft.Research.DryadLinq.Internal; + +namespace Microsoft.Research.DryadLinq +{ + // Various methods to support Expression manipulation. + internal static class HpcLinqExpression + { + #region TOCSHARPSTRING + public static ParameterExpression GetParameterMemberAccess(Expression expr) + { + while (expr is MemberExpression) + { + if (!(((MemberExpression)expr).Member is FieldInfo) && + !(((MemberExpression)expr).Member is PropertyInfo)) + { + return null; + } + expr = ((MemberExpression)expr).Expression; + } + return (expr as ParameterExpression); + } + + // Builds the chained member access expr.name1.name2... for the given names. + // Each name is looked up on the type of the access built so far (not on the + // root expression's type), so that nested paths resolve against the inner type. + // Throws DryadLinqException when a name is neither a field nor a property of that type. + public static Expression CreateMemberAccess(Expression expr, params string[] fieldNames) + { + Expression resultExpr = expr; + foreach (string name in fieldNames) + { + if (resultExpr.Type.GetField(name) != null) + { + resultExpr = Expression.Field(resultExpr, name); + } + else if (resultExpr.Type.GetProperty(name) != null) + { + resultExpr = Expression.Property(resultExpr, name); + } + else + { + throw new DryadLinqException(HpcLinqErrorCode.Internal, + String.Format(SR.TypeDoesNotContainMember, resultExpr.Type, name)); + } + } + return resultExpr; + } + + // TBD: Not quite complete + public static bool IsConstant(Expression expr) + { + if (expr is ConstantExpression) return true; + if (expr is MemberExpression) + { + Expression expr1 = ((MemberExpression)expr).Expression; + return expr1 == null || IsConstant(expr1); + } + return false; + } + + public static LambdaExpression GetLambda(Expression expr) + { + while (expr.NodeType == ExpressionType.Quote) + { + expr = ((UnaryExpression)expr).Operand; + } + return expr as 
LambdaExpression; + } + + public static bool Contains(ParameterExpression param, Expression expr) + { + FreeParameters freeParams = new FreeParameters(); + freeParams.Visit(expr); + return freeParams.Parameters.Contains(param); + } + + internal static bool IsAssociative(LambdaExpression expr) + { + if (AttributeSystem.GetAssociativeAttrib(expr) != null) + { + return true; + } + ParameterExpression param1 = expr.Parameters[0]; + ParameterExpression param2 = expr.Parameters[1]; + Expression operand1 = null; + Expression operand2 = null; + BinaryExpression body = expr.Body as BinaryExpression; + if (body == null) + { + MethodCallExpression mcExpr = expr.Body as MethodCallExpression; + if (mcExpr != null && mcExpr.Method.DeclaringType == typeof(System.Math)) + { + if (mcExpr.Method.Name == "Max" || mcExpr.Method.Name == "Min") + { + operand1 = mcExpr.Arguments[0]; + operand2 = mcExpr.Arguments[1]; + } + } + } + else if (body.NodeType == ExpressionType.Add || + body.NodeType == ExpressionType.AddChecked || + body.NodeType == ExpressionType.Multiply || + body.NodeType == ExpressionType.MultiplyChecked || + body.NodeType == ExpressionType.And || + body.NodeType == ExpressionType.Or || + body.NodeType == ExpressionType.ExclusiveOr) + { + if (body.Method == null) + { + operand1 = body.Left; + operand2 = body.Right; + } + } + + if (operand1 != null) + { + if (operand1 == param1) + { + return !Contains(param1, operand2); + } + if (operand2 == param1) + { + return !Contains(param1, operand1); + } + } + return false; + } + + internal static ExpressionType GetNodeType(string opName) + { + switch (opName) + { + case "op_Addition": + return ExpressionType.Add; + case "op_Subtraction": + return ExpressionType.Subtract; + case "op_Multiply": + return ExpressionType.Multiply; + case "op_Division": + return ExpressionType.Divide; + case "op_Modulus": + return ExpressionType.Modulo; + case "op_BitwiseAnd": + return ExpressionType.And; + case "op_BitwiseOr": + return ExpressionType.Or; 
+ case "op_ExclusiveOr": + return ExpressionType.ExclusiveOr; + case "op_LeftShift": + return ExpressionType.LeftShift; + case "op_RightShift": + return ExpressionType.RightShift; + case "op_Equality": + return ExpressionType.Equal; + case "op_Inequality": + return ExpressionType.NotEqual; + case "op_LessThan": + return ExpressionType.LessThan; + case "op_GreaterThan": + return ExpressionType.GreaterThan; + case "op_LessThanOrEqual": + return ExpressionType.LessThanOrEqual; + case "op_GreaterThanOrEqual": + return ExpressionType.GreaterThanOrEqual; + case "op_UnaryPlus": + return ExpressionType.UnaryPlus; + case "op_UnaryNegation": + return ExpressionType.Negate; + case "op_LogicalNot": + return ExpressionType.Not; + default: + // @TODO: does this have to appear in HpcLinqErrorCode? Consider a generic "Internal error" fault code, with English only messages. + throw new DryadLinqException(HpcLinqErrorCode.UnrecognizedOperatorName, + String.Format(SR.UnrecognizedOperatorName , opName)); + } + } + + internal static LambdaExpression GetAssociativeCombiner(LambdaExpression expr) + { + AssociativeAttribute attrib = AttributeSystem.GetAssociativeAttrib(expr); + if (attrib == null) + { + BinaryExpression bexpr = expr.Body as BinaryExpression; + if (bexpr == null) + { + MethodCallExpression mcExpr = expr.Body as MethodCallExpression; + ParameterExpression px = Expression.Parameter(mcExpr.Arguments[0].Type, "x"); + ParameterExpression py = Expression.Parameter(mcExpr.Arguments[1].Type, "y"); + Expression body = Expression.Call(mcExpr.Method, px, py); + return Expression.Lambda(body, px, py); + } + else + { + ParameterExpression px = Expression.Parameter(bexpr.Left.Type, "x"); + ParameterExpression py = Expression.Parameter(bexpr.Right.Type, "y"); + Expression body = Expression.MakeBinary(bexpr.NodeType, px, py, bexpr.IsLiftedToNull, bexpr.Method); + return Expression.Lambda(body, px, py); + } + } + else + { + Type[] funcTypeArgs = expr.Type.GetGenericArguments(); + 
MethodInfo cInfo = null; + Type associativeType = attrib.AssociativeType; + if (associativeType == null) + { + if (expr.Body is MethodCallExpression) + { + cInfo = ((MethodCallExpression)expr.Body).Method; + } + else if (expr.Body is BinaryExpression) + { + cInfo = ((BinaryExpression)expr.Body).Method; + } + // cInfo is still null when the body carries no MethodInfo (a BinaryExpression whose Method is null, + // or a body of some other node kind); report that as a wrong-form error rather than dereferencing null. + ParameterInfo[] pInfos = (cInfo == null) ? null : cInfo.GetParameters(); + if (pInfos == null || pInfos.Length != 2) + { + throw new DryadLinqException(HpcLinqErrorCode.AssociativeMethodHasWrongForm, + string.Format(SR.AssociativeMethodHasWrongForm, + (cInfo == null) ? expr.Body.ToString() : cInfo.Name)); + } + if (funcTypeArgs[0] != pInfos[0].ParameterType || pInfos[0].ParameterType != pInfos[1].ParameterType) + { + throw new DryadLinqException(HpcLinqErrorCode.AssociativeMethodHasWrongForm, + string.Format(SR.AssociativeMethodHasWrongForm, cInfo.Name)); + } + } + else + { + // determine if the attribute specifies an IAssociative. + if (associativeType.ContainsGenericParameters) + { + if (associativeType.GetGenericArguments().Length != 1) + { + throw new DryadLinqException(HpcLinqErrorCode.AssociativeTypeDoesNotImplementInterface, + String.Format(SR.AssociativeTypeDoesNotImplementInterface, associativeType.FullName)); + } + if (associativeType.GetGenericArguments()[0].IsGenericParameter) + { + associativeType = associativeType.MakeGenericType(funcTypeArgs[0]); + } + } + + Type implementedInterface = null; + Type[] interfaces = associativeType.GetInterfaces(); + foreach (Type inter in interfaces) + { + // GetGenericTypeDefinition throws InvalidOperationException on a non-generic interface, + // so skip those by testing IsGenericType first. + if (inter.IsGenericType && inter.GetGenericTypeDefinition() == typeof(IAssociative<>)) + { + if (implementedInterface != null) + { + throw new DryadLinqException(HpcLinqErrorCode.AssociativeTypeImplementsTooManyInterfaces, + String.Format(SR.AssociativeTypeImplementsTooManyInterfaces, associativeType.FullName)); + } + implementedInterface = inter; + } + } + if (implementedInterface == null) + { + throw new DryadLinqException(HpcLinqErrorCode.AssociativeTypeDoesNotImplementInterface, + String.Format(SR.AssociativeTypeDoesNotImplementInterface, 
associativeType.FullName)); + } + if (implementedInterface.GetGenericArguments()[0] != funcTypeArgs[0]) + { + throw new DryadLinqException(HpcLinqErrorCode.AssociativeTypesDoNotMatch, + String.Format(SR.AssociativeTypesDoNotMatch, associativeType.FullName)); + } + if (!associativeType.IsPublic && !associativeType.IsNestedPublic) + { + throw new DryadLinqException(HpcLinqErrorCode.AssociativeTypeMustBePublic, + String.Format(SR.AssociativeTypeMustBePublic, associativeType.FullName)); + } + try + { + associativeType = typeof(GenericAssociative<,>).MakeGenericType(associativeType, funcTypeArgs[0]); + } + catch (Exception) + { + throw new DryadLinqException(HpcLinqErrorCode.AssociativeTypeDoesNotHavePublicDefaultCtor, + String.Format(SR.AssociativeTypeDoesNotHavePublicDefaultCtor, associativeType.FullName)); + } + cInfo = associativeType.GetMethod("Accumulate", new Type[] { funcTypeArgs[0], funcTypeArgs[0] }); + Debug.Assert(cInfo != null, "problem finding method on associativeType"); + } + + ParameterExpression px = Expression.Parameter(funcTypeArgs[0], "x"); + ParameterExpression py = Expression.Parameter(funcTypeArgs[0], "y"); + Expression body; + if (cInfo.Name.StartsWith("op_", StringComparison.Ordinal)) + { + body = Expression.MakeBinary(GetNodeType(cInfo.Name), px, py, false, cInfo); + } + else + { + body = Expression.Call(cInfo, px, py); + } + return Expression.Lambda(body, px, py); + } + } + + public static LambdaExpression Rewrite(LambdaExpression expr, LambdaExpression selector, Substitution pSubst) + { + if (expr != null) + { + Type resultType = selector.Body.Type; + ParameterExpression resultParam = Expression.Parameter(resultType, "key_1"); + + // Perform substitutions + ExpressionSubst subst = new ExpressionSubst(pSubst); + subst.AddSubst(selector.Body, resultParam); + if (selector.Body is NewExpression) + { + NewExpression newBody = (NewExpression)selector.Body; + if (newBody.Constructor != null) + { + resultType = newBody.Constructor.DeclaringType; + } 
+ if (TypeSystem.IsAnonymousType(resultType)) + { + PropertyInfo[] props = resultType.GetProperties(); + + //the following test is never expected to occur, and an assert would most likely suffice. + if (props.Length != newBody.Arguments.Count) + { + throw new DryadLinqException(HpcLinqErrorCode.Internal, SR.BugInHandlingAnonymousClass); + } + + for (int i = 0; i < props.Length; i++) + { + Expression leftExpr = newBody.Arguments[i]; + Expression rightExpr = CreateMemberAccess(resultParam, props[i].Name); + subst.AddSubst(leftExpr, rightExpr); + } + } + } + if (selector.Body is MemberInitExpression) + { + ReadOnlyCollection bindings = ((MemberInitExpression)selector.Body).Bindings; + for (int i = 0; i < bindings.Count; i++) + { + if (bindings[i] is MemberAssignment) + { + Expression leftExpr = ((MemberAssignment)bindings[i]).Expression; + Expression rightExpr = CreateMemberAccess(resultParam, ((MemberAssignment)bindings[i]).Member.Name); + subst.AddSubst(leftExpr, rightExpr); + } + } + } + else + { + FieldMappingAttribute[] attribs = AttributeSystem.GetFieldMappingAttribs(selector); + if (attribs != null) + { + foreach (FieldMappingAttribute attrib in attribs) + { + string[] srcFieldNames = attrib.Source.Split('.'); + string paramName = srcFieldNames[0]; + + ParameterInfo[] paramInfos = null; + if (selector.Body is MethodCallExpression) + { + paramInfos = ((MethodCallExpression)selector.Body).Method.GetParameters(); + } + else if (selector.Body is NewExpression) + { + paramInfos = ((NewExpression)selector.Body).Constructor.GetParameters(); + } + + if (paramInfos != null) + { + int argIdx = -1; + for (int i = 0; i < paramInfos.Length; i++) + { + if (paramInfos[i].Name == paramName) + { + argIdx = i; + break; + } + } + + Expression leftExpr = null; + if (argIdx != -1) + { + if (selector.Body is MethodCallExpression) + { + leftExpr = ((MethodCallExpression)selector.Body).Arguments[argIdx]; + } + else if (selector.Body is NewExpression) + { + leftExpr = 
((NewExpression)selector.Body).Arguments[argIdx]; + } + } + if (leftExpr == null) + { + throw new DryadLinqException(HpcLinqErrorCode.Internal, + "The source of the FieldMapping annotation was wrong. " + + paramName + " is not a formal parameter."); + } + + // srcFieldNames[0] is the parameter name; the remaining segments are the member path. + // Copy segment i into slot i - 1 so the path keeps its order and stays inside the array bounds. + string[] fieldNames = new string[srcFieldNames.Length - 1]; + for (int i = 1; i < srcFieldNames.Length; i++) + { + fieldNames[i - 1] = srcFieldNames[i]; + } + leftExpr = CreateMemberAccess(leftExpr, fieldNames); + Expression rightExpr = CreateMemberAccess(resultParam, attrib.Destination.Split('.')); + subst.AddSubst(leftExpr, rightExpr); + } + } + } + } + Expression resultBody = subst.Visit(expr.Body); + + // Check if the substitutions are complete + FreeParameters freeParams = new FreeParameters(); + freeParams.Visit(resultBody); + if (freeParams.Parameters.Count == 1 && freeParams.Parameters.Contains(resultParam)) + { + Type funcType = typeof(Func<,>).MakeGenericType(resultType, expr.Body.Type); + return Expression.Lambda(funcType, resultBody, resultParam); + } + } + return null; + } + + public static string EscapeString(char ch) + { + switch (ch) + { + case '\'': + case '\"': + case '\\': + return "\\" + ch; + case '\0': + return "\\0"; + case '\a': + return "\\a"; + case '\b': + return "\\b"; + case '\f': + return "\\f"; + case '\n': + return "\\n"; + case '\r': + return "\\r"; + case '\t': + return "\\t"; + case '\v': + return "\\v"; + default: + return null; + } + } + + private static Dictionary s_transIdMap = new Dictionary(); + public static string ToCSharpString(Expression expr) + { + return ToCSharpString(expr, new Dictionary()); + } + + public static string ToCSharpString(Expression expr, Dictionary typeNames) + { + StringBuilder builder = new StringBuilder(); + BuildExpression(builder, expr, typeNames); + return builder.ToString(); + } + + private static void BuildExpression(StringBuilder builder, + Expression expr, + Dictionary typeNames) + { + if (IsConstant(expr)) + { + object val = 
ExpressionSimplifier.Evaluate(expr); + if (val == null) + { + builder.Append("(" + TypeSystem.TypeName(expr.Type, typeNames) + ")"); + builder.Append("null"); + } + else + { + Type valType = val.GetType(); + if (valType.IsPrimitive) + { + TypeCode tcode = Type.GetTypeCode(expr.Type); + if (tcode == TypeCode.Boolean) + { + builder.Append(((bool)val) ? "true" : "false"); + } + else if (tcode == TypeCode.Char) + { + string escapeStr = EscapeString((char)val); + if (escapeStr == null) + { + builder.Append("'" + val + "'"); + } + else + { + builder.Append("'" + escapeStr + "'"); + } + } + else + { + builder.Append("((" + TypeSystem.TypeName(valType, typeNames) + ")("); + builder.Append(val + "))"); + } + } + else if (valType.IsEnum) + { + builder.Append(TypeSystem.TypeName(valType, typeNames) + "." + val); + } + else if (val is string) + { + builder.Append("@\""); + builder.Append(((string)val).Replace("\"", "\"\"")); + builder.Append("\""); + } + else if (val is Expression) + { + BuildExpression(builder, (Expression)val, typeNames); + } + else + { + int valIdx = HpcLinqObjectStore.Put(val); + builder.Append("((" + TypeSystem.TypeName(expr.Type, typeNames) + ")"); + builder.Append("HpcLinqObjectStore.Get(" + valIdx + "))"); + } + } + } + else if (expr is BinaryExpression) + { + BuildBinaryExpression(builder, (BinaryExpression)expr, typeNames); + } + else if (expr is ConditionalExpression) + { + BuildConditionalExpression(builder, (ConditionalExpression)expr, typeNames); + } + else if (expr is ConstantExpression) + { + BuildConstantExpression(builder, (ConstantExpression)expr, typeNames); + } + else if (expr is InvocationExpression) + { + BuildInvocationExpression(builder, (InvocationExpression)expr, typeNames); + } + else if (expr is LambdaExpression) + { + BuildLambdaExpression(builder, (LambdaExpression)expr, typeNames); + } + else if (expr is MemberExpression) + { + BuildMemberExpression(builder, (MemberExpression)expr, typeNames); + } + else if (expr is 
MethodCallExpression) + { + BuildMethodCallExpression(builder, (MethodCallExpression)expr, typeNames); + } + else if (expr is NewExpression) + { + BuildNewExpression(builder, (NewExpression)expr, typeNames); + } + else if (expr is NewArrayExpression) + { + BuildNewArrayExpression(builder, (NewArrayExpression)expr, typeNames); + } + else if (expr is MemberInitExpression) + { + BuildMemberInitExpression(builder, (MemberInitExpression)expr, typeNames); + } + else if (expr is ListInitExpression) + { + BuildListInitExpression(builder, (ListInitExpression)expr, typeNames); + } + else if (expr is ParameterExpression) + { + BuildParameterExpression(builder, (ParameterExpression)expr, typeNames); + } + else if (expr is TypeBinaryExpression) + { + BuildTypeBinaryExpression(builder, (TypeBinaryExpression)expr, typeNames); + } + else if (expr is UnaryExpression) + { + BuildUnaryExpression(builder, (UnaryExpression)expr, typeNames); + } + else + { + throw new DryadLinqException(HpcLinqErrorCode.UnsupportedExpressionsType, + String.Format(SR.UnsupportedExpressionsType, expr.NodeType)); + } + } + + private static void BuildInvocationExpression(StringBuilder builder, + InvocationExpression expr, + Dictionary typeNames) + { + builder.Append("("); + builder.Append("("); + + // type cast to method + builder.Append("("); + builder.Append(TypeSystem.TypeName(expr.Expression.Type, typeNames)); + builder.Append(")"); + // method name + builder.Append("("); + BuildExpression(builder, expr.Expression, typeNames); + builder.Append(")"); + + builder.Append(")"); + + // method invocation + builder.Append("("); + bool isFirst = true; + foreach (Expression arg in expr.Arguments) + { + if (isFirst) + { + isFirst = false; + } + else + { + builder.Append(", "); + } + BuildExpression(builder, arg, typeNames); + } + builder.Append(")"); + builder.Append(")"); + } + + private static void BuildBinaryExpression(StringBuilder builder, + BinaryExpression expr, + Dictionary typeNames) + { + if 
(expr.NodeType == ExpressionType.ArrayIndex) + { + BuildExpression(builder, expr.Left, typeNames); + builder.Append("["); + BuildExpression(builder, expr.Right, typeNames); + builder.Append("]"); + } + else + { + string op = GetBinaryOperator(expr); + if (op != null) + { + bool isChecked = (expr.NodeType == ExpressionType.AddChecked || + expr.NodeType == ExpressionType.SubtractChecked || + expr.NodeType == ExpressionType.MultiplyChecked); + if (isChecked) builder.Append("checked("); + builder.Append("("); + BuildExpression(builder, expr.Left, typeNames); + builder.Append(" "); + builder.Append(op); + builder.Append(" "); + BuildExpression(builder, expr.Right, typeNames); + builder.Append(")"); + if (isChecked) builder.Append(")"); + } + else { + builder.Append(expr.NodeType); + builder.Append("("); + BuildExpression(builder, expr.Left, typeNames); + builder.Append(", "); + BuildExpression(builder, expr.Right, typeNames); + builder.Append(")"); + } + } + } + + internal static string GetBinaryOperator(BinaryExpression expr) + { + switch (expr.NodeType) + { + case ExpressionType.Add: + case ExpressionType.AddChecked: + return "+"; + case ExpressionType.Subtract: + case ExpressionType.SubtractChecked: + return "-"; + case ExpressionType.Multiply: + case ExpressionType.MultiplyChecked: + return "*"; + case ExpressionType.Divide: + return "/"; + case ExpressionType.Modulo: + return "%"; + case ExpressionType.And: + return (expr.Type == typeof(bool) || expr.Type == typeof(bool?)) ? "&&" : "&"; + case ExpressionType.AndAlso: + return "&&"; + case ExpressionType.Or: + return (expr.Type == typeof(bool) || expr.Type == typeof(bool?)) ? 
"||" : "|"; + case ExpressionType.OrElse: + return "||"; + case ExpressionType.LessThan: + return "<"; + case ExpressionType.LessThanOrEqual: + return "<="; + case ExpressionType.GreaterThan: + return ">"; + case ExpressionType.GreaterThanOrEqual: + return ">="; + case ExpressionType.Equal: + return "=="; + case ExpressionType.NotEqual: + return "!="; + case ExpressionType.Coalesce: + return "??"; + case ExpressionType.RightShift: + return ">>"; + case ExpressionType.LeftShift: + return "<<"; + case ExpressionType.ExclusiveOr: + case ExpressionType.Power: + return "^"; + } + return null; + } + + private static void BuildConditionalExpression(StringBuilder builder, + ConditionalExpression expr, + Dictionary typeNames) + { + builder.Append("("); + builder.Append("("); + BuildExpression(builder, expr.Test, typeNames); + builder.Append(") ? ("); + BuildExpression(builder, expr.IfTrue, typeNames); + builder.Append(") : ("); + BuildExpression(builder, expr.IfFalse, typeNames); + builder.Append(")"); + builder.Append(")"); + } + + private static void BuildConstantExpression(StringBuilder builder, + ConstantExpression expr, + Dictionary typeNames) + { + if (expr.Value == null) + { + builder.Append("null"); + } + else + { + if (expr.Value is string) + { + builder.Append("@\""); + builder.Append(expr.Value); + builder.Append("\""); + } + else if (expr.Value.ToString() == expr.Value.GetType().ToString()) + { + builder.Append("value("); + builder.Append(expr.Value); + builder.Append(")"); + } + else + { + builder.Append(expr.Value); + } + } + } + + private static void BuildLambdaExpression(StringBuilder builder, + LambdaExpression expr, + Dictionary typeNames) + { + foreach (ParameterExpression param in expr.Parameters) + { + if (TypeSystem.IsTransparentIdentifier(param.Name)) + { + string newName = HpcLinqCodeGen.MakeUniqueName("h__TransparentIdentifier"); + s_transIdMap[param.Name] = newName; + } + } + if (expr.Parameters.Count == 1) + { + BuildExpression(builder, 
expr.Parameters[0], typeNames); + } + else + { + builder.Append("("); + for (int i = 0, n = expr.Parameters.Count; i < n; i++) + { + if (i > 0) builder.Append(", "); + BuildExpression(builder, expr.Parameters[i], typeNames); + } + builder.Append(")"); + } + builder.Append(" => "); + BuildExpression(builder, expr.Body, typeNames); + } + + private static void BuildMemberExpression(StringBuilder builder, + MemberExpression expr, + Dictionary typeNames) + { + if (expr.Expression == null) + { + builder.Append(TypeSystem.TypeName(expr.Member.DeclaringType, typeNames)); + } + else + { + ParameterExpression param = expr.Expression as ParameterExpression; + if (param != null) + { + BuildExpression(builder, param, typeNames); + } + else + { + BuildExpression(builder, expr.Expression, typeNames); + } + } + builder.Append("."); + string memberName = expr.Member.Name; + if (TypeSystem.IsTransparentIdentifier(memberName)) + { + if (s_transIdMap.ContainsKey(memberName)) + { + memberName = s_transIdMap[memberName]; + } + else + { + string newName = HpcLinqCodeGen.MakeUniqueName("h__TransparentIdentifier"); + s_transIdMap[memberName] = newName; + memberName = newName; + } + } + builder.Append(memberName); + } + + private static void BuildMethodCallExpression(StringBuilder builder, + MethodCallExpression expr, + Dictionary typeNames) + { + Expression obj = expr.Object; + int start = 0; + if (Attribute.GetCustomAttribute(expr.Method, typeof(System.Runtime.CompilerServices.ExtensionAttribute)) != null) + { + start = 1; + obj = expr.Arguments[0]; + } + bool desugar = expr.Method.IsStatic && !TypeSystem.IsQueryOperatorCall(expr); + if (obj == null || desugar) + { + Type type = expr.Method.DeclaringType; + builder.Append(TypeSystem.TypeName(type, typeNames)); + } + else + { + BuildExpression(builder, obj, typeNames); + } + + if (TypeSystem.IsProperty(expr.Method)) + { + // Special case: an indexer property + builder.Append("["); + for (int i = start, n = expr.Arguments.Count; i < n; i++) 
+ { + if (i > start) builder.Append(", "); + BuildExpression(builder, expr.Arguments[i], typeNames); + } + builder.Append("]"); + } + else + { + bool isArrayIndexer = (expr.Method.DeclaringType.IsArray && expr.Method.Name == "Get"); + if (isArrayIndexer) + { + builder.Append("["); + } + else + { + builder.Append("."); + builder.Append(expr.Method.Name); + if (expr.Method.IsGenericMethod && + !TypeSystem.ContainsAnonymousType(expr.Method.GetGenericArguments())) + { + builder.Append("<"); + bool first = true; + foreach (Type t in expr.Method.GetGenericArguments()) + { + if (first) + { + first = false; + } + else + { + builder.Append(","); + } + builder.Append(TypeSystem.TypeName(t, typeNames)); + } + builder.Append(">"); + } + builder.Append("("); + } + + bool isFirst = true; + if (obj != null && desugar) + { + isFirst = false; + BuildExpression(builder, obj, typeNames); + } + for (int i = start, n = expr.Arguments.Count; i < n; i++) + { + if (isFirst) + { + isFirst = false; + } + else + { + builder.Append(", "); + } + BuildExpression(builder, expr.Arguments[i], typeNames); + } + + builder.Append((isArrayIndexer) ? "]" : ")"); + } + } + + private static void BuildNewExpression(StringBuilder builder, + NewExpression expr, + Dictionary typeNames) + { + Type type = (expr.Constructor == null) ? 
expr.Type : expr.Constructor.DeclaringType; + builder.Append("new "); + string typeName = null; + if (TypeSystem.IsAnonymousType(type)) + { + if (!typeNames.TryGetValue(type, out typeName)) + { + PropertyInfo[] props = type.GetProperties(); + System.Array.Sort(props, (x, y) => x.MetadataToken.CompareTo(y.MetadataToken)); + builder.Append("{"); + for (int i = 0; i < props.Length; i++) + { + if (i > 0) builder.Append(", "); + string propName = props[i].Name; + if (TypeSystem.IsTransparentIdentifier(propName)) + { + if (s_transIdMap.ContainsKey(propName)) + { + propName = s_transIdMap[propName]; + } + else + { + string newName = HpcLinqCodeGen.MakeUniqueName("h__TransparentIdentifier"); + s_transIdMap.Add(propName, newName); + propName = newName; + } + } + builder.Append(propName + " = "); + BuildExpression(builder, expr.Arguments[i], typeNames); + } + builder.Append("}"); + return; + } + } + else + { + typeName = TypeSystem.TypeName(type, typeNames); + } + + builder.Append(typeName); + builder.Append("("); + for (int i = 0; i < expr.Arguments.Count; i++) + { + if (i > 0) builder.Append(", "); + BuildExpression(builder, expr.Arguments[i], typeNames); + } + builder.Append(")"); + } + + private static void BuildNewArrayExpression(StringBuilder builder, + NewArrayExpression expr, + Dictionary typeNames) + { + builder.Append("new "); + if (expr.NodeType == ExpressionType.NewArrayBounds) + { + Type baseType = expr.Type.GetElementType(); + while (baseType.IsArray) + { + baseType = baseType.GetElementType(); + } + builder.Append(TypeSystem.TypeName(baseType, typeNames)); + builder.Append("["); + for (int i = 0, n = expr.Expressions.Count; i < n; i++) + { + if (i > 0) builder.Append(", "); + BuildExpression(builder, expr.Expressions[i], typeNames); + } + builder.Append("]"); + + Type elemType = expr.Type.GetElementType(); + while (elemType.IsArray) + { + builder.Append("["); + int rank = elemType.GetArrayRank(); + for (int i = 1; i < rank; i++) + { + builder.Append(','); + } 
+ builder.Append("]"); + elemType = elemType.GetElementType(); + } + } + else + { + Debug.Assert(expr.NodeType == ExpressionType.NewArrayInit); + builder.Append(TypeSystem.TypeName(expr.Type, typeNames)); + builder.Append(" {"); + for (int i = 0, n = expr.Expressions.Count; i < n; i++) + { + if (i > 0) builder.Append(", "); + BuildExpression(builder, expr.Expressions[i], typeNames); + } + builder.Append("}"); + } + } + + private static void BuildMemberInitExpression(StringBuilder builder, + MemberInitExpression expr, + Dictionary typeNames) + { + if (expr.NewExpression.Arguments.Count == 0 && + expr.NewExpression.Type.Name.Contains("<")) + { + // anonymous type constructor + builder.Append("new"); + } + else + { + BuildExpression(builder, expr.NewExpression, typeNames); + } + builder.Append(" {"); + for (int i = 0, n = expr.Bindings.Count; i < n; i++) + { + if (i > 0) builder.Append(", "); + BuildMemberBinding(builder, expr.Bindings[i], typeNames); + } + builder.Append("}"); + } + + private static void BuildListInitExpression(StringBuilder builder, + ListInitExpression expr, + Dictionary typeNames) + { + BuildExpression(builder, expr.NewExpression, typeNames); + builder.Append(" {"); + for (int i = 0, n = expr.Initializers.Count; i < n; i++) + { + if (i > 0) builder.Append(", "); + BuildElementInit(builder, expr.Initializers[i], typeNames); + } + builder.Append("}"); + } + + private static void BuildParameterExpression(StringBuilder builder, + ParameterExpression expr, + Dictionary typeNames) + { + if (expr.Name == null) + { + throw new DryadLinqException(HpcLinqErrorCode.Internal, SR.UnnamedParameterExpression); + } + string paramName = expr.Name; + if (s_transIdMap.ContainsKey(paramName)) + { + paramName = s_transIdMap[paramName]; + } + builder.Append(paramName); + } + + private static void BuildTypeBinaryExpression(StringBuilder builder, + TypeBinaryExpression expr, + Dictionary typeNames) + { + Debug.Assert(expr.NodeType == ExpressionType.TypeIs); + 
builder.Append("("); + BuildExpression(builder, expr.Expression, typeNames); + builder.Append(" is "); + builder.Append(TypeSystem.TypeName(expr.TypeOperand, typeNames)); + builder.Append(")"); + } + + // Appends the C# source text for a unary expression node to builder. + // builder - receives the generated text. + // expr - the unary node to print; the emitted form is chosen by expr.NodeType. + // typeNames - passed through unchanged to TypeSystem.TypeName and BuildExpression. + // Node types without a dedicated case are written as NodeType(operand). + private static void BuildUnaryExpression(StringBuilder builder, + UnaryExpression expr, + Dictionary typeNames) + { + switch (expr.NodeType) + { + case ExpressionType.ArrayLength: + { + BuildExpression(builder, expr.Operand, typeNames); + builder.Append(".Length"); + break; + } + case ExpressionType.Convert: + case ExpressionType.ConvertChecked: + { + // Both conversion kinds share the cast syntax. ConvertChecked has to be listed + // here: without the label isChecked is always false and a checked conversion + // falls through to the default case, which emits "ConvertChecked(...)". + bool isChecked = (expr.NodeType == ExpressionType.ConvertChecked); + if (isChecked) builder.Append("checked("); + + builder.Append("(("); + builder.Append(TypeSystem.TypeName(expr.Type, typeNames)); + builder.Append(")"); + BuildExpression(builder, expr.Operand, typeNames); + builder.Append(")"); + + if (isChecked) builder.Append(")"); + break; + } + case ExpressionType.TypeAs: + { + builder.Append("("); + BuildExpression(builder, expr.Operand, typeNames); + builder.Append(" as "); + builder.Append(TypeSystem.TypeName(expr.Type, typeNames)); + builder.Append(")"); + break; + } + case ExpressionType.Not: + { + //bug 15050.. Not is represented in C# two ways, depending on operand type.
+ // see http://msdn.microsoft.com/en-us/library/bb361179.aspx + if (expr.Operand.Type == typeof(bool) || expr.Operand.Type == typeof(bool?)) + { + builder.Append("(!("); + BuildExpression(builder, expr.Operand, typeNames); + builder.Append("))"); + } + else + { + builder.Append("(~("); + BuildExpression(builder, expr.Operand, typeNames); + builder.Append("))"); + } + + break; + } + case ExpressionType.Negate: + { + builder.Append("(-("); + BuildExpression(builder, expr.Operand, typeNames); + builder.Append("))"); + break; + } + case ExpressionType.NegateChecked: + { + builder.Append("checked(-("); + BuildExpression(builder, expr.Operand, typeNames); + builder.Append("))"); + break; + } + case ExpressionType.Quote: + { + BuildExpression(builder, expr.Operand, typeNames); + break; + } + default: + { + builder.Append(expr.NodeType); + builder.Append("("); + BuildExpression(builder, expr.Operand, typeNames); + builder.Append(")"); + break; + } + } + } + + private static void BuildMemberBinding(StringBuilder builder, + MemberBinding binding, + Dictionary typeNames) + { + if (binding is MemberAssignment) + { + builder.Append(binding.Member.Name); + builder.Append(" = "); + BuildExpression(builder, ((MemberAssignment)binding).Expression, typeNames); + } + else if (binding is MemberMemberBinding) + { + builder.Append(binding.Member.Name); + builder.Append(" = {"); + + MemberMemberBinding mmBinding = (MemberMemberBinding)binding; + for (int i = 0, n = mmBinding.Bindings.Count; i < n; i++) + { + if (i > 0) builder.Append(", "); + BuildMemberBinding(builder, mmBinding.Bindings[i], typeNames); + } + builder.Append("}"); + } + else + { + Debug.Assert(binding is MemberListBinding); + builder.Append(binding.Member.Name); + builder.Append(" = {"); + + MemberListBinding mlBinding = (MemberListBinding)binding; + for (int i = 0, n = mlBinding.Initializers.Count; i < n; i++) + { + if (i > 0) builder.Append(", "); + BuildElementInit(builder, mlBinding.Initializers[i], typeNames); + } 
+ builder.Append("}"); + } + } + + private static void BuildElementInit(StringBuilder builder, + ElementInit elemInit, + Dictionary typeNames) + { + // An ElementInitExpression is something like: "Add(arg1, arg2..)" + // which corresponds to C# syntax such as: + // 1. new Dictionary { {arg1, arg2} }; + // AND/OR + // 2. Dictionary x = new Dictionary(); x.Add(arg1,arg2); + // + // The caller of BuildElementInit looks after the multiple *items* added to a collection, but this method must cope with + // Add() methods that take one or more arguments. + // + // The main example for multi-argument Add() methods is Dictionary.Add(key,val), but user-defined classes + // can also participate. C# uses duck-typing for this and similar language extensions. + + // Bug 15049: when emitting inline (syntax #1) don't emit elemInit.AddMethod, but do emit braces to demarcate the parameters + builder.Append("{"); + bool isFirst = true; + foreach (Expression argument in elemInit.Arguments) + { + if (isFirst) + { + isFirst = false; + } + else + { + builder.Append(","); + } + BuildExpression(builder, argument, typeNames); + } + builder.Append("}"); + } + #endregion + + #region SUMMARIZE + // summarizing expressions: like ToString, but more compact + public static string Summarize(Expression expr) + { + return Summarize(expr, 0); + } + + public static string Summarize(Expression expr, int level) + { + StringBuilder builder = new StringBuilder(); + Summarize(builder, expr, level); + return builder.ToString(); + } + + private static void Summarize(StringBuilder builder, Expression expr, int level) + { + if (IsConstant(expr)) + { + object val = ExpressionSimplifier.Evaluate(expr); + if (val == null) + { + builder.Append("null"); + } + else + { + Type valType = val.GetType(); + if (valType.IsPrimitive) + { + TypeCode tcode = Type.GetTypeCode(expr.Type); + if (tcode == TypeCode.Char) + { + string escapeStr = EscapeString((char)val); + if (escapeStr == null) + { + builder.Append("'" + val + 
"'"); + } + else + { + builder.Append("'" + escapeStr + "'"); + } + } + else + { + builder.Append(val.ToString()); + } + } + else if (val is Expression) + { + Summarize(builder, (Expression)val, level); + } + else if (val is Delegate) + { + builder.Append(((Delegate)val).Method.Name); + } + else + { + builder.Append('_'); + } + } + } + else if (expr is BinaryExpression) + { + SummarizeBinaryExpression(builder, (BinaryExpression)expr, level); + } + else if (expr is ConditionalExpression) + { + SummarizeConditionalExpression(builder, (ConditionalExpression)expr, level); + } + else if (expr is ConstantExpression) + { + SummarizeConstantExpression(builder, (ConstantExpression)expr); + } + else if (expr is InvocationExpression) + { + SummarizeInvocationExpression(builder, (InvocationExpression)expr, level); + } + else if (expr is LambdaExpression) + { + SummarizeLambdaExpression(builder, (LambdaExpression)expr, level); + } + else if (expr is MemberExpression) + { + SummarizeMemberExpression(builder, (MemberExpression)expr, level); + } + else if (expr is MethodCallExpression) + { + SummarizeMethodCallExpression(builder, (MethodCallExpression)expr, level); + } + else if (expr is NewExpression) + { + SummarizeNewExpression(builder, (NewExpression)expr, level); + } + else if (expr is NewArrayExpression) + { + SummarizeNewArrayExpression(builder, (NewArrayExpression)expr, level); + } + else if (expr is MemberInitExpression) + { + SummarizeMemberInitExpression(builder, (MemberInitExpression)expr, level); + } + else if (expr is ListInitExpression) + { + SummarizeListInitExpression(builder, (ListInitExpression)expr, level); + } + else if (expr is ParameterExpression) + { + SummarizeParameterExpression(builder, (ParameterExpression)expr); + } + else if (expr is TypeBinaryExpression) + { + SummarizeTypeBinaryExpression(builder, (TypeBinaryExpression)expr, level); + } + else if (expr is UnaryExpression) + { + SummarizeUnaryExpression(builder, (UnaryExpression)expr, level); + } + 
else + { + throw new DryadLinqException(HpcLinqErrorCode.UnsupportedExpressionType, + String.Format(SR.UnsupportedExpressionType, expr.NodeType)); + } + } + + private static void SummarizeInvocationExpression(StringBuilder builder, InvocationExpression expr, int level) + { + bool isConstant = expr.Expression is ConstantExpression; + if (!isConstant) + builder.Append("("); + Summarize(builder, expr.Expression, level); + if (!isConstant) + builder.Append(")"); + + // method invocation + builder.Append("("); + bool isFirst = true; + foreach (Expression arg in expr.Arguments) + { + if (isFirst) + { + isFirst = false; + } + else + { + builder.Append(", "); + } + Summarize(builder, arg, level); + } + builder.Append(")"); + } + + private static void SummarizeBinaryExpression(StringBuilder builder, BinaryExpression expr, int level) + { + if (expr.NodeType == ExpressionType.ArrayIndex) + { + Summarize(builder, expr.Left, level); + builder.Append("["); + Summarize(builder, expr.Right, level); + builder.Append("]"); + } + else + { + string op = GetBinaryOperator(expr); + if (op != null) + { + builder.Append("("); + Summarize(builder, expr.Left, level); + builder.Append(" "); + builder.Append(op); + builder.Append(" "); + Summarize(builder, expr.Right, level); + builder.Append(")"); + } + else + { + builder.Append(expr.NodeType); + builder.Append("("); + Summarize(builder, expr.Left, level); + builder.Append(", "); + Summarize(builder, expr.Right, level); + builder.Append(")"); + } + } + } + + private static void SummarizeConditionalExpression(StringBuilder builder, ConditionalExpression expr, int level) + { + builder.Append("("); + Summarize(builder, expr.Test, level); + builder.Append(") ? 
("); + Summarize(builder, expr.IfTrue, level); + builder.Append(") : ("); + Summarize(builder, expr.IfFalse, level); + builder.Append(")"); + } + + private static void SummarizeConstantExpression(StringBuilder builder, ConstantExpression expr) + { + builder.Append('_'); + } + + private static void SummarizeLambdaExpression(StringBuilder builder, LambdaExpression expr, int level) + { + if (expr.Parameters.Count == 1) + { + Summarize(builder, expr.Parameters[0], level); + } + else + { + builder.Append("("); + for (int i = 0, n = expr.Parameters.Count; i < n; i++) + { + if (i > 0) builder.Append(", "); + Summarize(builder, expr.Parameters[i], level); + } + builder.Append(")"); + } + builder.Append(" => "); + Summarize(builder, expr.Body, level); + } + + private static void SummarizeMemberExpression(StringBuilder builder, MemberExpression expr, int level) + { + if (expr.Expression == null) + { + builder.Append(TypeSystem.TypeName(expr.Member.DeclaringType)); + } + else + { + Summarize(builder, expr.Expression, level); + } + builder.Append("."); + builder.Append(expr.Member.Name); + } + + private static void SummarizeMethodCallExpression(StringBuilder builder, MethodCallExpression expr, int level) + { + Expression obj = expr.Object; + int start = 0; + if (Attribute.GetCustomAttribute(expr.Method, typeof(System.Runtime.CompilerServices.ExtensionAttribute)) != null) + { + start = 1; + obj = expr.Arguments[0]; + } + if (obj == null) + { + Type type = expr.Method.DeclaringType; + builder.Append(HpcLinqUtil.SimpleName(TypeSystem.TypeName(type))); + } + else if (level > 0) + { + builder.Append('_'); + } + else + { + Summarize(builder, obj, level); + } + + if (TypeSystem.IsProperty(expr.Method)) + { + // Special case: an indexer property + builder.Append("["); + for (int i = start, n = expr.Arguments.Count; i < n; i++) + { + if (i > start) builder.Append(", "); + Summarize(builder, expr.Arguments[i], level); + } + builder.Append("]"); + } + else + { + builder.Append("."); + 
builder.Append(HpcLinqUtil.SimpleName(expr.Method.Name)); + builder.Append("("); + for (int i = start, n = expr.Arguments.Count; i < n; i++) + { + if (i > start) builder.Append(", "); + Summarize(builder, expr.Arguments[i], level); + } + builder.Append(")"); + } + } + + private static string SummarizeType(string typename) + { + // drop common 'System.*' prefixes + string result = Regex.Replace(typename, @"System[a-zA-Z\.]*\.([a-zA-Z]+)", "$1"); + return result; + } + + private static void SummarizeNewExpression(StringBuilder builder, NewExpression expr, int level) + { + Type type = (expr.Constructor == null) ? expr.Type : expr.Constructor.DeclaringType; + builder.Append("new "); + builder.Append(SummarizeType(TypeSystem.TypeName(type))); + + int n = expr.Arguments.Count; + builder.Append("("); + for (int i = 0; i < n; i++) + { + if (i > 0) builder.Append(", "); + Summarize(builder, expr.Arguments[i], level); + } + builder.Append(")"); + } + + private static void SummarizeNewArrayExpression(StringBuilder builder, NewArrayExpression expr, int level) + { + if (expr.NodeType == ExpressionType.NewArrayBounds) + { + builder.Append("new "); + builder.Append(expr.Type.ToString()); + builder.Append("("); + for (int i = 0, n = expr.Expressions.Count; i < n; i++) + { + if (i > 0) builder.Append(", "); + Summarize(builder, expr.Expressions[i], level); + } + builder.Append(")"); + } + else + { + Debug.Assert(expr.NodeType == ExpressionType.NewArrayInit); + builder.Append("new "); + builder.Append("[] {"); + for (int i = 0, n = expr.Expressions.Count; i < n; i++) + { + if (i > 0) builder.Append(", "); + Summarize(builder, expr.Expressions[i], level); + } + builder.Append("}"); + } + } + + private static void SummarizeMemberInitExpression(StringBuilder builder, MemberInitExpression expr, int level) + { + if (expr.NewExpression.Arguments.Count == 0 && + expr.NewExpression.Type.Name.Contains("<")) + { + // anonymous type constructor + builder.Append("new"); + } + else + { + 
Summarize(builder, expr.NewExpression, level); + } + builder.Append(" {"); + for (int i = 0, n = expr.Bindings.Count; i < n; i++) + { + if (i > 0) builder.Append(", "); + SummarizeMemberBinding(builder, expr.Bindings[i], level); + } + builder.Append("}"); + } + + private static void SummarizeListInitExpression(StringBuilder builder, ListInitExpression expr, int level) + { + Summarize(builder, expr.NewExpression, level); + builder.Append(" {"); + for (int i = 0, n = expr.Initializers.Count; i < n; i++) + { + if (i > 0) builder.Append(", "); + SummarizeElementInit(builder, expr.Initializers[i], level); + } + builder.Append("}"); + } + + private static void SummarizeParameterExpression(StringBuilder builder, ParameterExpression expr) + { + if (expr.Name != null) + { + builder.Append(expr.Name); + } + else + { + builder.Append('_'); + } + } + + private static void SummarizeTypeBinaryExpression(StringBuilder builder, TypeBinaryExpression expr, int level) + { + Debug.Assert(expr.NodeType == ExpressionType.TypeIs); + builder.Append("("); + Summarize(builder, expr.Expression, level); + builder.Append(" is "); + builder.Append(TypeSystem.TypeName(expr.TypeOperand)); + builder.Append(")"); + } + + private static void SummarizeUnaryExpression(StringBuilder builder, UnaryExpression expr, int level) + { + switch (expr.NodeType) + { + case ExpressionType.ArrayLength: + { + Summarize(builder, expr.Operand, level); + builder.Append(".Length"); + break; + } + case ExpressionType.Convert: + { + // do not show type casts + Summarize(builder, expr.Operand, level); + break; + } + case ExpressionType.TypeAs: + { + // do not show type casts + Summarize(builder, expr.Operand, level); + break; + } + case ExpressionType.Not: + { + builder.Append("!("); + Summarize(builder, expr.Operand, level); + builder.Append(")"); + break; + } + case ExpressionType.Negate: + { + builder.Append("-("); + Summarize(builder, expr.Operand, level); + builder.Append(")"); + break; + } + case 
ExpressionType.Quote: + { + Summarize(builder, expr.Operand, level); + break; + } + default: + { + builder.Append(expr.NodeType); + builder.Append("("); + Summarize(builder, expr.Operand, level); + builder.Append(")"); + break; + } + } + } + + private static void SummarizeMemberBinding(StringBuilder builder, MemberBinding binding, int level) + { + if (binding is MemberAssignment) + { + builder.Append(binding.Member.Name); + builder.Append(" = "); + Summarize(builder, ((MemberAssignment)binding).Expression, level); + } + else if (binding is MemberMemberBinding) + { + builder.Append(binding.Member.Name); + builder.Append(" = {"); + + MemberMemberBinding mmBinding = (MemberMemberBinding)binding; + for (int i = 0, n = mmBinding.Bindings.Count; i < n; i++) + { + if (i > 0) builder.Append(", "); + SummarizeMemberBinding(builder, mmBinding.Bindings[i], level); + } + builder.Append("}"); + } + else + { + Debug.Assert(binding is MemberListBinding); + builder.Append(binding.Member.Name); + builder.Append(" = {"); + + MemberListBinding mlBinding = (MemberListBinding)binding; + for (int i = 0, n = mlBinding.Initializers.Count; i < n; i++) + { + if (i > 0) builder.Append(", "); + SummarizeElementInit(builder, mlBinding.Initializers[i], level); + } + builder.Append("}"); + } + } + + private static void SummarizeElementInit(StringBuilder builder, ElementInit elemInit, int level) + { + builder.Append(elemInit.AddMethod.Name); + builder.Append("("); + bool isFirst = true; + foreach (Expression argument in elemInit.Arguments) + { + if (isFirst) + { + isFirst = false; + } + else + { + builder.Append(","); + } + Summarize(builder, argument, level); + } + builder.Append(")"); + } + #endregion + } +} diff --git a/LinqToDryad/DryadLinqExtension.cs b/LinqToDryad/DryadLinqExtension.cs new file mode 100644 index 0000000..56c8440 --- /dev/null +++ b/LinqToDryad/DryadLinqExtension.cs @@ -0,0 +1,247 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// © Microsoft Corporation. All rights reserved. +// +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Globalization; +using System.Reflection; +using System.Linq.Expressions; +using System.Linq; +using Microsoft.Research.DryadLinq.Internal; + +namespace Microsoft.Research.DryadLinq +{ + /// + /// This provides some useful classes and operators that are commonly used + /// in applications. The operators are defined using DryadLINQ operators.
+ /// + [Serializable] + public struct Pair<T1, T2> : IEquatable<Pair<T1, T2>> + { + private T1 m_key; + private T2 m_value; + + [FieldMapping("x", "Key")] + [FieldMapping("y", "Value")] + public Pair(T1 x, T2 y) + { + this.m_key = x; + this.m_value = y; + } + + public T1 Key + { + get { return this.m_key; } + } + + public T2 Value + { + get { return this.m_value; } + } + + // Returns true when obj is a Pair of the same type arguments with an equal key and an equal value. + public override bool Equals(Object obj) + { + if (!(obj is Pair<T1, T2>)) return false; + return this.Equals((Pair<T1, T2>)obj); + } + + // Compares key with key and value with value. EqualityComparer<T>.Default is used + // so that a null key or value compares equal to another null instead of throwing + // NullReferenceException. + public bool Equals(Pair<T1, T2> val) + { + return EqualityComparer<T1>.Default.Equals(this.m_key, val.m_key) && EqualityComparer<T2>.Default.Equals(this.m_value, val.m_value); + } + + public static bool Equals(Pair<T1, T2> a, Pair<T1, T2> b) + { + return a.Equals(b); + } + + public static bool operator ==(Pair<T1, T2> a, Pair<T1, T2> b) + { + return a.Equals(b); + } + + public static bool operator !=(Pair<T1, T2> a, Pair<T1, T2> b) + { + return !a.Equals(b); + } + + // Combines the key and value hash codes. A null key or value contributes 0, and + // unchecked keeps the multiplication from throwing if the assembly is compiled + // with overflow checking enabled. + public override int GetHashCode() + { + int keyHash = (this.m_key == null) ? 0 : this.m_key.GetHashCode(); + int valueHash = (this.m_value == null) ? 0 : this.m_value.GetHashCode(); + return unchecked((-1521134295 * keyHash) + valueHash); + } + + public override string ToString() + { + return "<" + this.Key + ", " + this.Value + ">"; + } + } + + public static class HpcLinqExtension + { + /// + /// The standard MapReduce. + /// + /// The type of the records of input dataset + /// The type of the resulting records of mapper + /// The type of the keys for hash exchange + /// The type of the resulting records of reducer + /// The input dataset + /// The map function + /// The key extraction function + /// The reduce function + /// The result dataset of MapReduce + public static IQueryable + MapReduce( + this IQueryable source, + Expression>> mapper, + Expression> keySelector, + Expression, TResult>> reducer) + { + return source.SelectMany(mapper).GroupBy(keySelector, reducer); + } + + /// + /// Compute the cross product of two datasets. The function procFunc is applied to each + /// pair of the cross product to form the output dataset.
+ /// + /// The type of the records of dataset source1 + /// The type of the records of dataset source2 + /// The type of the records of the result dataset + /// The first input dataset + /// The second input dataset + /// The function to apply to each pair of the cross product + /// The output dataset + public static IQueryable + CrossProduct(this IQueryable source1, + IQueryable source2, + Expression> procFunc) + { + return source1.ApplyPerPartition(source2, (x_1, y_1) => HpcLinqHelper.Cross(x_1, y_1, procFunc), true); + } + + /// + /// Conditional DoWhile loop. + /// + /// The type of the input records + /// The input dataset + /// The code body of the DoWhile loop + /// The termination condition of the DoWhile loop + /// The output dataset + public static IQueryable + DoWhile(this IQueryable source, + Func, IQueryable> body, + Func, IQueryable, IQueryable> cond, + Int32 count) + { + if (count < 0) + { + throw new ArgumentOutOfRangeException("count"); + } + if (count == 0) return source; + + IQueryable before = source; + while (true) + { + IQueryable after = before; + for (int i = 0; i < count; i++) + { + after = body(after); + } + var more = cond(before, after); + HpcLinqQueryable.SubmitAndWait(after, more); + if (!more.Single()) return after; + before = after; + } + } + + /// + /// Conditional DoWhile loop. 
+ /// + /// The type of the input records + /// The input dataset + /// The code body of the DoWhile loop + /// The termination condition of the DoWhile loop + /// The output dataset + public static IQueryable + DoWhile(this IQueryable source, + Func, IQueryable> body, + Func, IQueryable, IQueryable> cond) + { + IQueryable before = source; + while (true) + { + IQueryable after = body(before); + var more = cond(before, after); + HpcLinqQueryable.SubmitAndWait(after, more); + if (!more.Single()) return after; + before = after; + } + } + + /// + /// Broadcast a dataset to multiple partitions + /// + /// The type of the input records + /// The input dataset + /// The output dataset, which consists of multiple copies of source + public static IQueryable BroadCast(this IQueryable source) + { + return source.ApplyPerPartition(source, (x, y) => HpcLinqHelper.SelectSecond(x, y), true); + } + + /// + /// Broadcast a dataset to n partitions. + /// + /// The type of the input records + /// The input dataset + /// The output dataset, each partition of which is a copy of source + public static IQueryable BroadCast(this IQueryable source, int n) + { + var dummy = source.ApplyPerPartition(x => HpcLinqHelper.ValueZero(x)) + .HashPartition(x => x, n); + return dummy.ApplyPerPartition(source, (x, y) => HpcLinqHelper.SelectSecond(x, y), true); + } + + /// + /// Check if each partition of the input dataset is ordered. 
+ /// + /// The type of the records of the input dataset + /// The type of the keys on which ordering is based + /// The input dataset + /// The key extraction function + /// A Comparer on TKey to compare records + /// True if the check is for descending + /// The same dataset as the input + public static IQueryable + CheckOrderBy(this IQueryable source, + Expression> keySelector, + IComparer comparer, + bool isDescending) + { + return source.ApplyPerPartition(x_1 => HpcLinqHelper.CheckSort(x_1, keySelector, comparer, isDescending)); + } + } +} diff --git a/LinqToDryad/DryadLinqFaultCodes.cs b/LinqToDryad/DryadLinqFaultCodes.cs new file mode 100644 index 0000000..9f41197 --- /dev/null +++ b/LinqToDryad/DryadLinqFaultCodes.cs @@ -0,0 +1,335 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// © Microsoft Corporation. All rights reserved. +// +using System; +using System.Collections.Generic; +using System.Text; +using System.Linq.Expressions; +using System.Runtime.Serialization; + using System.Reflection; + +namespace Microsoft.Research.DryadLinq +{ + + // Make HpcLinq error categories internal for SP2. Tracked via bug 13313.
+ // Each member is the base value of one error category. The constants in + // HpcLinqErrorCode are formed by adding a small per-error index to one of these bases. + // Consecutive bases are 0x01000000 apart (the same value as HpcLinqErrorCode.codesPerCategory); + // Unknown sits apart at 0x0f000000. + internal enum HpcLinqErrorCodeCategory : int + { + QueryAPI = 0x01000000, + CodeGen = 0x02000000, + JobSubmission = 0x03000000, + Serialization = 0x04000000, + DscClient= 0x05000000, + VertexRuntime = 0x06000000, + LocalDebug = 0x07000000, + Unknown = 0x0f000000 + } + + //@@TODO: when possible, remove the sr.txt entries for all the items marked "DEL" + + /// + /// Lists all error codes in HpcLinq + /// + /// + /// NOTE: New error codes must be appended to a category + /// NOTE: Error codes cannot be deleted + /// + public static class HpcLinqErrorCode + { + internal const int codesPerCategory = 0x01000000; + + #region CodeGen + public const int TypeRequiredToBePublic = (int) HpcLinqErrorCodeCategory.CodeGen + 0; + public const int CustomSerializerMustSupportDefaultCtor = (int) HpcLinqErrorCodeCategory.CodeGen + 1; + public const int CustomSerializerMustBeClassOrStruct = (int) HpcLinqErrorCodeCategory.CodeGen + 2; + public const int TypeNotSerializable = (int) HpcLinqErrorCodeCategory.CodeGen + 3; + public const int CannotHandleSubtypes = (int)HpcLinqErrorCodeCategory.CodeGen + 4; + public const int UDTMustBeConcreteType = (int)HpcLinqErrorCodeCategory.CodeGen + 5; + public const int UDTHasFieldOfNonPublicType = (int)HpcLinqErrorCodeCategory.CodeGen + 6; + public const int UDTIsDelegateType = (int)HpcLinqErrorCodeCategory.CodeGen + 7; + public const int FailedToBuild = (int) HpcLinqErrorCodeCategory.CodeGen + 8; + public const int OutputTypeCannotBeAnonymous = (int) HpcLinqErrorCodeCategory.CodeGen + 9; + public const int InputTypeCannotBeAnonymous = (int) HpcLinqErrorCodeCategory.CodeGen + 10; + public const int BranchOfForkNotUsed = (int) HpcLinqErrorCodeCategory.CodeGen + 11; + public const int ComparerMustBeSpecifiedOrKeyTypeMustBeIComparable= (int) HpcLinqErrorCodeCategory.CodeGen + 12; + public const int ComparerMustBeSpecifiedOrKeyTypeMustBeIEquatable = (int)HpcLinqErrorCodeCategory.CodeGen + 13; + public const int
ComparerExpressionMustBeSpecifiedOrElementTypeMustBeIEquatable = (int)HpcLinqErrorCodeCategory.CodeGen + 14; + public const int TooManyHomomorphicAttributes = (int) HpcLinqErrorCodeCategory.CodeGen + 15; + public const int HomomorphicApplyNeedsSamePartitionCount = (int)HpcLinqErrorCodeCategory.CodeGen + 16; + public const int UnrecognizedDataSource = (int) HpcLinqErrorCodeCategory.CodeGen + 17; + // [deleted] #18 ReducerDoesntExist + // [deleted] #19 CombinerDoesntExist + // [deleted] #20 CombinerReturnTypeMismatch + public const int CannotConcatDatasetsWithDifferentCompression = (int) HpcLinqErrorCodeCategory.CodeGen + 21; + // [deleted] #22 CannotCreateTablesWithDifferentCompression + public const int AggregateOperatorNotSupported = (int) HpcLinqErrorCodeCategory.CodeGen + 23; + public const int FinalizerReturnTypeMismatch = (int)HpcLinqErrorCodeCategory.CodeGen + 24; + // [deleted] #25 FinalizerDoesNotExist + public const int CannotHandleCircularTypes = (int)HpcLinqErrorCodeCategory.CodeGen + 26; + public const int OperatorNotSupported = (int)HpcLinqErrorCodeCategory.CodeGen + 27; + public const int AggregationOperatorRequiresIComparable = (int)HpcLinqErrorCodeCategory.CodeGen + 28; + public const int DecomposerTypeDoesNotImplementInterface = (int)HpcLinqErrorCodeCategory.CodeGen + 29; + public const int DecomposerTypeImplementsTooManyInterfaces = (int)HpcLinqErrorCodeCategory.CodeGen + 30; + public const int DecomposerTypesDoNotMatch = (int)HpcLinqErrorCodeCategory.CodeGen + 31; + public const int DecomposerTypeMustBePublic = (int)HpcLinqErrorCodeCategory.CodeGen + 32; + public const int DecomposerTypeDoesNotHavePublicDefaultCtor = (int)HpcLinqErrorCodeCategory.CodeGen + 33; + public const int AssociativeMethodHasWrongForm = (int)HpcLinqErrorCodeCategory.QueryAPI + 34; + public const int AssociativeTypeDoesNotImplementInterface = (int)HpcLinqErrorCodeCategory.CodeGen + 35; + public const int AssociativeTypeImplementsTooManyInterfaces = 
(int)HpcLinqErrorCodeCategory.CodeGen + 36; + public const int AssociativeTypesDoNotMatch = (int)HpcLinqErrorCodeCategory.CodeGen + 37; + public const int AssociativeTypeMustBePublic = (int)HpcLinqErrorCodeCategory.CodeGen + 38; + public const int AssociativeTypeDoesNotHavePublicDefaultCtor = (int)HpcLinqErrorCodeCategory.CodeGen + 39; + // [deleted] #40 ClientNETVersion; + //->Internal public const int Internal_CannotBeUsedForValueType = (int)HpcLinqErrorCodeCategory.CodeGen + 41; + //->Internal public const int Internal_TypeDoesNotContainRequestedField = (int)HpcLinqErrorCodeCategory.CodeGen + 42; + public const int CannotCreatePartitionNodeRandom = (int)HpcLinqErrorCodeCategory.CodeGen + 43; + public const int PartitionKeysNotProvided = (int)HpcLinqErrorCodeCategory.CodeGen + 44; + public const int PartitionKeysAreNotConsistentlyOrdered = (int)HpcLinqErrorCodeCategory.CodeGen + 45; + public const int IsDescendingIsInconsistent = (int)HpcLinqErrorCodeCategory.CodeGen + 46; + //DEL public const int Internal_FailedToRemoveMergeNode = (int)HpcLinqErrorCodeCategory.CodeGen + 47; + //->Internal public const int Internal_CannotAttach = (int)HpcLinqErrorCodeCategory.CodeGen + 48; + //->Internal public const int Internal_CannotAddTeeToNode = (int)HpcLinqErrorCodeCategory.CodeGen + 49; + //DEL public const int Internal_ShouldNotCreateCodeForInput = (int)HpcLinqErrorCodeCategory.CodeGen + 50; + //DEL public const int Internal_ShouldNotCreateCodeForOutput = (int)HpcLinqErrorCodeCategory.CodeGen + 51; + //DEL public const int Internal_ShouldNotCreateCodeForConcat = (int)HpcLinqErrorCodeCategory.CodeGen + 52; + //->Internal public const int Internal_DynamicManagerType = (int)HpcLinqErrorCodeCategory.CodeGen + 53; + //DEL public const int Internal_ShouldNotCreateCodeForTee = (int)HpcLinqErrorCodeCategory.CodeGen + 54; + //->Internal public const int Internal_IllegalDynamicManagerType = (int)HpcLinqErrorCodeCategory.CodeGen + 55; + //->Internal public const int 
Internal_OptimizationPhaseError = (int)HpcLinqErrorCodeCategory.CodeGen + 56; + //->Internal public const int Internal_DistinctOnlyTakesTwoArgs = (int)HpcLinqErrorCodeCategory.CodeGen + 57; + //DEL public const int Internal_InputArityMustEqualChildren = (int)HpcLinqErrorCodeCategory.CodeGen + 58; + //->Internal public const int Internal_AddVertexNotHandled = (int)HpcLinqErrorCodeCategory.CodeGen + 59; + //->Internal public const int Internal_CannotBeEmpty = (int)HpcLinqErrorCodeCategory.CodeGen + 60; + //DEL public const int Internal_MustSpecifyOutputAssemblyFileName = (int)HpcLinqErrorCodeCategory.CodeGen + 61; + //->Internal public const int Internal_AutogeneratedAssemblyMissing = (int)HpcLinqErrorCodeCategory.CodeGen + 62; + //DEL public const int Internal_ShouldNotCreateCodeForDummyNode = (int)HpcLinqErrorCodeCategory.CodeGen + 63; + //->Internal public const int Internal_CannotBeUsedForReferenceType = (int)HpcLinqErrorCodeCategory.CodeGen + 64; + public const int BadSeparatorCount = (int)HpcLinqErrorCodeCategory.CodeGen + 65; + public const int TypeMustHaveDataMembers = (int)HpcLinqErrorCodeCategory.CodeGen + 66; + public const int CannotHandleObjectFields = (int)HpcLinqErrorCodeCategory.CodeGen + 67; + public const int CannotHandleDerivedtypes = (int)HpcLinqErrorCodeCategory.CodeGen + 68; + public const int MultipleOutputsWithSameDscUri = (int)HpcLinqErrorCodeCategory.CodeGen + 69; + public const int OutputUriAlsoQueryInput = (int)HpcLinqErrorCodeCategory.CodeGen + 70; + + //The "internal" code is used for internal errors that should not be hit by users. + //The messages may be informative, but the error code doesn't need to be and it avoids users + //seeing all the error codes in intellisense and/or wondering if they should catch & deal with them etc. 
+ public const int Internal = (int)HpcLinqErrorCodeCategory.CodeGen + 71; + + #endregion + + #region DscClient + public const int DSCStreamError = (int) HpcLinqErrorCodeCategory.DscClient + 0; + public const int StreamDoesNotExist = (int) HpcLinqErrorCodeCategory.DscClient + 1; + public const int StreamAlreadyExists = (int) HpcLinqErrorCodeCategory.DscClient + 2; + public const int AttemptToReadFromAWriteStream = (int) HpcLinqErrorCodeCategory.DscClient + 3; + public const int FailedToCreateStream = (int) HpcLinqErrorCodeCategory.DscClient + 4; + public const int JobToCreateTableWasCanceled = (int) HpcLinqErrorCodeCategory.DscClient + 5; + public const int FailedToGetReadPathsForStream = (int) HpcLinqErrorCodeCategory.DscClient + 6; + public const int CannotAccesFilePath = (int)HpcLinqErrorCodeCategory.DscClient + 7; + public const int PositionNotSupported = (int)HpcLinqErrorCodeCategory.DscClient + 8; + public const int GetFileSizeError = (int)HpcLinqErrorCodeCategory.DscClient + 9; + public const int ReadFileError = (int)HpcLinqErrorCodeCategory.DscClient + 10; + public const int UnknownCompressionScheme = (int)HpcLinqErrorCodeCategory.DscClient + 11; + public const int WriteFileError = (int)HpcLinqErrorCodeCategory.DscClient + 12; + public const int MultiBlockEmptyPartitionList = (int)HpcLinqErrorCodeCategory.DscClient + 13; + public const int GetURINotSupported = (int)HpcLinqErrorCodeCategory.DscClient + 14; + public const int SetCalcFPNotSupported = (int)HpcLinqErrorCodeCategory.DscClient + 15; + public const int GetFPNotSupported = (int)HpcLinqErrorCodeCategory.DscClient + 16; + public const int FailedToAllocateNewNativeBuffer = (int)HpcLinqErrorCodeCategory.DscClient + 17; + public const int FailedToReadFromInputChannel = (int)HpcLinqErrorCodeCategory.DscClient + 18; + public const int FailedToWriteToOutputChannel = (int)HpcLinqErrorCodeCategory.DscClient + 19; + //DEL public const int Internal_PrefixAlreadyUsedForOtherProvider = 
(int)HpcLinqErrorCodeCategory.DscClient + 20; + //->Internal public const int Internal_UnknownProvider = (int)HpcLinqErrorCodeCategory.DscClient + 21; + //DEL public const int Internal_CannotCallPartitionInfoOnType = (int)HpcLinqErrorCodeCategory.DscClient + 22; + //DEL public const int Internal_IllFormedUriArguments = (int)HpcLinqErrorCodeCategory.DscClient + 23; + //DEL public const int Internal_OpenForWriteError = (int)HpcLinqErrorCodeCategory.DscClient + 24; + public const int MultiBlockCannotAccesFilePath = (int)HpcLinqErrorCodeCategory.DscClient + 25; + #endregion + + #region JobSubmission + public const int DryadHomeMustBeSpecified = (int) HpcLinqErrorCodeCategory.JobSubmission + 0; + public const int ClusterNameMustBeSpecified = (int) HpcLinqErrorCodeCategory.JobSubmission + 1; + public const int UnexpectedJobStatus = (int) HpcLinqErrorCodeCategory.JobSubmission + 2; + public const int JobStatusQueryError = (int) HpcLinqErrorCodeCategory.JobSubmission + 3; + public const int JobOptionNotImplemented = (int) HpcLinqErrorCodeCategory.JobSubmission + 4; + public const int HpcLinqJobMinMustBe2OrMore = (int) HpcLinqErrorCodeCategory.JobSubmission + 5; + public const int SubmissionFailure = (int)HpcLinqErrorCodeCategory.JobSubmission + 6; + public const int UnsupportedSchedulerType = (int)HpcLinqErrorCodeCategory.JobSubmission + 7; + public const int UnsupportedExecutionKind = (int)HpcLinqErrorCodeCategory.JobSubmission + 8; + public const int DidNotCompleteSuccessfully = (int)HpcLinqErrorCodeCategory.JobSubmission + 9; + public const int Binaries32BitNotSupported = (int)HpcLinqErrorCodeCategory.JobSubmission + 10; + #endregion + + #region QueryAPI + public const int DistinctAttributeComparerNotDefined = (int) HpcLinqErrorCodeCategory.QueryAPI + 0; + public const int SerializerTypeMustBeNonNull = (int) HpcLinqErrorCodeCategory.QueryAPI + 1; + public const int SerializerTypeMustSupportIHpcSerializer = (int) HpcLinqErrorCodeCategory.QueryAPI + 2; + public const int 
UnrecognizedOperatorName = (int) HpcLinqErrorCodeCategory.QueryAPI + 3; + //[deleted] #4 CouldNotInferCombiner + //[deleted] #5 CombinerHasWrongType + //[deleted] #6 UnknownCombiner + public const int UnsupportedExpressionsType = (int) HpcLinqErrorCodeCategory.QueryAPI + 7; + public const int UnsupportedExpressionType = (int) HpcLinqErrorCodeCategory.QueryAPI + 8; + //[deleted/duplicate] #9 IndexOutOfRange + public const int IndexTooSmall = (int)HpcLinqErrorCodeCategory.QueryAPI + 10; + public const int MultiQueryableKeyOutOfRange = (int) HpcLinqErrorCodeCategory.QueryAPI + 11; + public const int IndexOutOfRange = (int) HpcLinqErrorCodeCategory.QueryAPI + 12; + //->ArgumentException public const int NotAHpcLinqQuery = (int) HpcLinqErrorCodeCategory.QueryAPI + 13; + public const int ToDscUsedIncorrectly = (int) HpcLinqErrorCodeCategory.QueryAPI + 14; + public const int ExpressionTypeNotHandled = (int) HpcLinqErrorCodeCategory.QueryAPI + 15; + public const int FailedToGetStreamProps = (int) HpcLinqErrorCodeCategory.QueryAPI + 16; + public const int MetadataRecordType = (int) HpcLinqErrorCodeCategory.QueryAPI + 17; + //[deleted] #18 MetadataCompressionScheme + //[deleted] #19 CannotHaveZeroPartitions + public const int JobToCreateTableFailed = (int) HpcLinqErrorCodeCategory.QueryAPI + 20; + //[deleted/duplicate] #21 UnrecognizedDataSource + public const int OnlyAvailableForPhysicalData = (int) HpcLinqErrorCodeCategory.QueryAPI + 22; + public const int FileSetMustBeSealed = (int)HpcLinqErrorCodeCategory.QueryAPI + 23; + public const int FileSetCouldNotBeOpened = (int)HpcLinqErrorCodeCategory.QueryAPI + 24; + public const int FileSetMustHaveAtLeastOneFile = (int)HpcLinqErrorCodeCategory.QueryAPI + 25; + //->ArgumentException public const int AtLeastOneOperatorRequired = (int)HpcLinqErrorCodeCategory.QueryAPI + 26; + public const int CouldNotGetClientVersion = (int)HpcLinqErrorCodeCategory.QueryAPI + 27; + public const int CouldNotGetServerVersion = 
(int)HpcLinqErrorCodeCategory.QueryAPI + 28; + public const int ContextDisposed = (int)HpcLinqErrorCodeCategory.QueryAPI + 29; + public const int UnhandledQuery = (int)HpcLinqErrorCodeCategory.QueryAPI + 30; //@@TODO: when possible, reword the sr.txt entry. + public const int ExpressionMustBeMethodCall= (int)HpcLinqErrorCodeCategory.QueryAPI + 31; + public const int UntypedProviderMethodsNotSupported = (int)HpcLinqErrorCodeCategory.QueryAPI + 32; + public const int ErrorReadingMetadata = (int)HpcLinqErrorCodeCategory.QueryAPI + 33; + public const int MustStartFromContext = (int)HpcLinqErrorCodeCategory.QueryAPI + 34; + public const int ToHdfsUsedIncorrectly = (int)HpcLinqErrorCodeCategory.QueryAPI + 35; + //->Internal public const int Internal_TypeDoesNotContainMember = (int)HpcLinqErrorCodeCategory.QueryAPI + 35; + //->Internal public const int Internal_BugInHandlingAnonymousClass = (int)HpcLinqErrorCodeCategory.QueryAPI + 36; + //->Internal public const int Internal_FieldAnnotationIncorrectParameter = (int)HpcLinqErrorCodeCategory.QueryAPI + 37; + //->Internal public const int Internal_UnnamedParameterExpression = (int)HpcLinqErrorCodeCategory.QueryAPI + 38; + //DEL public const int Internal_UriIsNotDsc = (int)HpcLinqErrorCodeCategory.QueryAPI + 39; + //->ArgumentException public const int AlreadySubmitted = (int)HpcLinqErrorCodeCategory.QueryAPI + 40; + #endregion + + #region Serialization + public const int FailedToReadFrom = (int) HpcLinqErrorCodeCategory.Serialization + 0; + public const int EndOfStreamEncountered = (int) HpcLinqErrorCodeCategory.Serialization + 1; + public const int SettingPositionNotSupported = (int) HpcLinqErrorCodeCategory.Serialization + 2; + public const int FingerprintDisabled = (int) HpcLinqErrorCodeCategory.Serialization + 3; + public const int RecordSizeMax2GB = (int) HpcLinqErrorCodeCategory.Serialization + 4; + //[deleted/duplicate] #5 SettingPositionNotSupported + public const int ReadByteNotAllowed = (int) 
HpcLinqErrorCodeCategory.Serialization + 6; + public const int ReadNotAllowed = (int) HpcLinqErrorCodeCategory.Serialization + 7; + public const int SeekNotSupported = (int) HpcLinqErrorCodeCategory.Serialization + 8; + public const int SetLengthNotSupported = (int) HpcLinqErrorCodeCategory.Serialization + 9; + public const int FailedToDeserialize = (int) HpcLinqErrorCodeCategory.Serialization + 10; + public const int ChannelCannotBeReadMoreThanOnce = (int) HpcLinqErrorCodeCategory.Serialization + 11; + //[deleted/duplicate] #12 IndexOutOfRange + public const int WriteNotSupported = (int)HpcLinqErrorCodeCategory.Serialization + 13; + public const int WriteByteNotSupported = (int)HpcLinqErrorCodeCategory.Serialization + 14; + public const int CannotSerializeHpcLinqQuery = (int)HpcLinqErrorCodeCategory.Serialization + 15; + public const int CannotSerializeObject = (int)HpcLinqErrorCodeCategory.Serialization + 16; + public const int GeneralSerializeFailure = (int)HpcLinqErrorCodeCategory.Serialization + 17; + //DEL public const int Internal_ShouldNotCallReset = (int)HpcLinqErrorCodeCategory.Serialization + 18; + #endregion + + #region VertexRuntime + public const int SourceOfMergesortMustBeMultiEnumerable = (int) HpcLinqErrorCodeCategory.VertexRuntime + 1; + public const int ThenByNotSupported = (int) HpcLinqErrorCodeCategory.VertexRuntime + 2; + public const int AggregateNoElements = (int) HpcLinqErrorCodeCategory.VertexRuntime + 3; + public const int FirstNoElementsFirst = (int) HpcLinqErrorCodeCategory.VertexRuntime + 4; + public const int SingleMoreThanOneElement = (int) HpcLinqErrorCodeCategory.VertexRuntime + 5; + public const int SingleNoElements = (int) HpcLinqErrorCodeCategory.VertexRuntime + 6; + public const int LastNoElements = (int) HpcLinqErrorCodeCategory.VertexRuntime + 7; + public const int MinNoElements = (int) HpcLinqErrorCodeCategory.VertexRuntime + 8; + public const int MaxNoElements = (int) HpcLinqErrorCodeCategory.VertexRuntime + 9; + public 
const int AverageNoElements = (int) HpcLinqErrorCodeCategory.VertexRuntime + 10; + public const int RangePartitionKeysMissing = (int) HpcLinqErrorCodeCategory.VertexRuntime + 11; + public const int PartitionFuncReturnValueExceedsNumPorts = (int) HpcLinqErrorCodeCategory.VertexRuntime + 12; + public const int FailureInExcept = (int) HpcLinqErrorCodeCategory.VertexRuntime + 13; + public const int FailureInIntersect = (int) HpcLinqErrorCodeCategory.VertexRuntime + 14; + public const int FailureInSort = (int) HpcLinqErrorCodeCategory.VertexRuntime + 15; + public const int RangePartitionInputOutputMismatch = (int) HpcLinqErrorCodeCategory.VertexRuntime + 16; + //DEL public const int Internal_CannotHaveMoreThanOneOutput = (int)HpcLinqErrorCodeCategory.VertexRuntime + 17; + public const int KeyNotFound = (int)HpcLinqErrorCodeCategory.VertexRuntime + 18; + public const int TooManyItems = (int)HpcLinqErrorCodeCategory.VertexRuntime + 19; + public const int FailureInHashGroupBy = (int)HpcLinqErrorCodeCategory.VertexRuntime + 20; + public const int FailureInSortGroupBy = (int)HpcLinqErrorCodeCategory.VertexRuntime + 21; + public const int FailureInHashJoin = (int)HpcLinqErrorCodeCategory.VertexRuntime + 22; + public const int FailureInHashGroupJoin = (int)HpcLinqErrorCodeCategory.VertexRuntime + 23; + public const int FailureInDistinct = (int)HpcLinqErrorCodeCategory.VertexRuntime + 24; + public const int FailureInOperator = (int)HpcLinqErrorCodeCategory.VertexRuntime + 25; + public const int FailureInUserApplyFunction = (int)HpcLinqErrorCodeCategory.VertexRuntime + 26; + public const int FailureInOrderedGroupBy = (int)HpcLinqErrorCodeCategory.VertexRuntime + 27; + //DEL public const int Internal_NullSelector = (int)HpcLinqErrorCodeCategory.VertexRuntime + 28; + //DEL public const int Internal_CannotResetIEnumerator = (int)HpcLinqErrorCodeCategory.VertexRuntime + 29; + //DEL public const int Internal_WrongFlagCombination = (int)HpcLinqErrorCodeCategory.VertexRuntime + 30; + 
//DEL public const int Internal_SourceMustBeDryadVertexReader = (int)HpcLinqErrorCodeCategory.VertexRuntime + 31; + //->Internal public const int Internal_SortedChunkCannotBeEmpty = (int)HpcLinqErrorCodeCategory.VertexRuntime + 32; + public const int TooManyElementsBeforeReduction = (int)HpcLinqErrorCodeCategory.VertexRuntime + 33; //@@TODO: when possible, reword the sr.txt entry. + #endregion + + #region LocalDebug + public const int CreatingDscDataFromLocalDebugFailed = (int)HpcLinqErrorCodeCategory.LocalDebug + 0; + #endregion + + #region Unknown + public const int UnknownError = (int) HpcLinqErrorCodeCategory.Unknown + 0; + + #endregion + + /// + /// Returns the category of the specified error code + /// + internal static HpcLinqErrorCodeCategory Category(int code) + { + if ((code >= (int) HpcLinqErrorCodeCategory.QueryAPI) && (code < (int) HpcLinqErrorCodeCategory.QueryAPI+ codesPerCategory)) + { + return HpcLinqErrorCodeCategory.QueryAPI; + } + else if ((code >= (int) HpcLinqErrorCodeCategory.CodeGen) && (code < (int) HpcLinqErrorCodeCategory.CodeGen+ codesPerCategory)) + { + return HpcLinqErrorCodeCategory.CodeGen; + } + else if ((code >= (int) HpcLinqErrorCodeCategory.JobSubmission) && (code < (int) HpcLinqErrorCodeCategory.JobSubmission+ codesPerCategory)) + { + return HpcLinqErrorCodeCategory.JobSubmission; + } + else if ((code >= (int) HpcLinqErrorCodeCategory.Serialization) && (code < (int) HpcLinqErrorCodeCategory.Serialization+ codesPerCategory)) + { + return HpcLinqErrorCodeCategory.Serialization; + } + else if ((code >= (int)HpcLinqErrorCodeCategory.DscClient) && (code < (int)HpcLinqErrorCodeCategory.DscClient+ codesPerCategory)) + { + return HpcLinqErrorCodeCategory.DscClient; + } + else if ((code >= (int)HpcLinqErrorCodeCategory.VertexRuntime) && (code < (int)HpcLinqErrorCodeCategory.VertexRuntime+ codesPerCategory)) + { + return HpcLinqErrorCodeCategory.VertexRuntime; + } + else if ((code >= (int)HpcLinqErrorCodeCategory.LocalDebug) && (code < 
(int)HpcLinqErrorCodeCategory.LocalDebug+ codesPerCategory)) + { + return HpcLinqErrorCodeCategory.LocalDebug; + } + else + { + return HpcLinqErrorCodeCategory.Unknown; + } + } + + } +} diff --git a/LinqToDryad/DryadLinqFileStream.cs b/LinqToDryad/DryadLinqFileStream.cs new file mode 100644 index 0000000..d9c5afa --- /dev/null +++ b/LinqToDryad/DryadLinqFileStream.cs @@ -0,0 +1,249 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. +// +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.IO.Compression; +using System.Text; +using System.Reflection; +using System.Diagnostics; +using Microsoft.Win32.SafeHandles; +using System.Runtime.InteropServices; +using Microsoft.Research.DryadLinq; + + +namespace Microsoft.Research.DryadLinq.Internal +{ + // This class directly talks to NTFS files. 
+ internal unsafe class HpcLinqFileStream : NativeBlockStream + { + // Size in bytes of every block allocated by ReadDataBlock (8192*32 = 256 KB). + private const int DefaultBuffSize = 8192*32; + + private FileStream m_fstream; + private SafeFileHandle m_fhandle; + private DscCompressionScheme m_compressionScheme; + private bool m_isClosed; + // Created lazily by the first compressed read or write; null until then. + private Stream m_compressStream; + + /// <summary> + /// Wraps an already opened FileStream, using its SafeFileHandle for the native calls. + /// </summary> + internal HpcLinqFileStream(FileStream fstream, DscCompressionScheme scheme) + { + this.m_fstream = fstream; + this.m_fhandle = fstream.SafeFileHandle; + this.m_compressionScheme = scheme; + this.m_isClosed = false; + this.m_compressStream = null; + } + + /// <summary> + /// Opens filePath and resets the stream state. A failure to open the file is rethrown as a + /// DryadLinqException (CannotAccesFilePath) that carries the original exception. + /// </summary> + private void Initialize(string filePath, FileMode mode, FileAccess access, DscCompressionScheme scheme) + { + try + { + this.m_fstream = new FileStream(filePath, mode, access); + } + catch(Exception e) + { + throw new DryadLinqException(HpcLinqErrorCode.CannotAccesFilePath, + String.Format(SR.CannotAccesFilePath , filePath),e); + } + this.m_fhandle = m_fstream.SafeFileHandle; + this.m_isClosed = false; + this.m_compressionScheme = scheme; + this.m_compressStream = null; + } + + /// <summary> + /// Opens filePath with FileMode.Open for read access and FileMode.OpenOrCreate otherwise. + /// </summary> + internal HpcLinqFileStream(string filePath, FileAccess access, DscCompressionScheme scheme) + { + FileMode mode = (access == FileAccess.Read) ? 
FileMode.Open : FileMode.OpenOrCreate; + Initialize(filePath, mode, access, scheme); + } + + internal HpcLinqFileStream(string filePath, FileAccess access) + : this(filePath, access, DscCompressionScheme.None) + { + } + + internal HpcLinqFileStream(string filePath, FileMode mode, FileAccess access, DscCompressionScheme scheme) + { + Initialize(filePath, mode, access, scheme); + } + + internal HpcLinqFileStream(string filePath, FileMode mode, FileAccess access) + : this(filePath, mode, access, DscCompressionScheme.None) + { + } + + /// <summary> + /// Returns the size of the underlying file as reported by GetFileSizeEx. + /// Throws DryadLinqException (GetFileSizeError) when the native call fails. + /// </summary> + internal override unsafe Int64 GetTotalLength() + { + Int64 totalLen; + bool success = HpcLinqNative.GetFileSizeEx(this.m_fhandle, out totalLen); + if (!success) + { + throw new DryadLinqException(HpcLinqErrorCode.GetFileSizeError, + String.Format(SR.GetFileSizeError, + Marshal.GetLastWin32Error())); + } + return totalLen; + } + + /// <summary> + /// Reads the next block of up to DefaultBuffSize bytes into newly allocated unmanaged memory. + /// Uncompressed files are read with ReadFile straight into the block; compressed files are + /// decompressed through a lazily created GZipStream and then copied into the block. + /// The returned block is freed with ReleaseDataBlock. If the read fails, the block is freed + /// here before the exception propagates. + /// </summary> + internal override unsafe DataBlockInfo ReadDataBlock() + { + DataBlockInfo blockInfo; + blockInfo.dataBlock = (byte*)Marshal.AllocHGlobal(DefaultBuffSize); + blockInfo.itemHandle = (IntPtr)blockInfo.dataBlock; + try + { + if (this.m_compressionScheme == DscCompressionScheme.None) + { + Int32* pBlockSize = &blockInfo.blockSize; + bool success = HpcLinqNative.ReadFile(this.m_fhandle, + blockInfo.dataBlock, + DefaultBuffSize, + (IntPtr)pBlockSize, + null); + if (!success) + { + throw new DryadLinqException(HpcLinqErrorCode.ReadFileError, + String.Format(SR.ReadFileError, + Marshal.GetLastWin32Error())); + } + } + else + { + if (this.m_compressStream == null) + { + if (this.m_compressionScheme == DscCompressionScheme.Gzip) + { + this.m_compressStream = new GZipStream(this.m_fstream, + CompressionMode.Decompress); + } + else + { + throw new DryadLinqException(HpcLinqErrorCode.UnknownCompressionScheme, + SR.UnknownCompressionScheme); + } + } + // YY: Made an extra copy here. Could do better. 
+ byte[] buffer = new byte[DefaultBuffSize]; + blockInfo.blockSize = this.m_compressStream.Read(buffer, 0, DefaultBuffSize); + fixed (byte* pBuffer = buffer) + { + HpcLinqUtil.memcpy(pBuffer, blockInfo.dataBlock, blockInfo.blockSize); + } + } + } + catch + { + // The caller never receives the block on failure, so nobody else can release it. + Marshal.FreeHGlobal(blockInfo.itemHandle); + throw; + } + + return blockInfo; + } + + /// <summary> + /// Writes numBytesToWrite bytes starting at itemHandle. Uncompressed data goes through WriteFile + /// in a loop until every byte is written; compressed data is copied to a managed buffer and + /// written through a lazily created GZipStream. Returns true; a WriteFile failure or an unknown + /// compression scheme throws DryadLinqException. + /// </summary> + internal override unsafe bool WriteDataBlock(IntPtr itemHandle, Int32 numBytesToWrite) + { + byte* dataBlock = (byte*)itemHandle; + if (this.m_compressionScheme == DscCompressionScheme.None) + { + Int32 numBytesWritten = 0; + Int32 remainingBytes = numBytesToWrite; + + while (remainingBytes > 0) + { + Int32* pNumBytesWritten = &numBytesWritten; + bool success = HpcLinqNative.WriteFile(this.m_fhandle, + dataBlock, + (UInt32)remainingBytes, + (IntPtr)pNumBytesWritten, + null); + if (!success) + { + throw new DryadLinqException(HpcLinqErrorCode.WriteFileError, + String.Format(SR.WriteFileError, + Marshal.GetLastWin32Error())); + } + + dataBlock += numBytesWritten; + remainingBytes -= numBytesWritten; + } + } + else + { + if (this.m_compressStream == null) + { + if (this.m_compressionScheme == DscCompressionScheme.Gzip) + { + this.m_compressStream = new GZipStream(this.m_fstream, + CompressionMode.Compress); + } + else + { + throw new DryadLinqException(HpcLinqErrorCode.UnknownCompressionScheme, + SR.UnknownCompressionScheme); + } + } + // YY: Made an extra copy here. Could do better. 
+ byte[] buffer = new byte[numBytesToWrite]; + fixed (byte* pBuffer = buffer) + { + HpcLinqUtil.memcpy(dataBlock, pBuffer, numBytesToWrite); + } + this.m_compressStream.Write(buffer, 0, numBytesToWrite); + } + return true; + } + + /// <summary> + /// Flushes the compression stream (when one has been created) and then the file stream. + /// </summary> + internal override void Flush() + { + if (this.m_compressStream != null) + { + this.m_compressStream.Flush(); + } + this.m_fstream.Flush(); + } + + /// <summary> + /// Closes the compression stream (when one has been created) and the file stream. + /// Calls after the first one do nothing. + /// </summary> + internal override void Close() + { + if (!this.m_isClosed) + { + this.m_isClosed = true; + if (this.m_compressStream != null) + { + this.m_compressStream.Close(); + } + this.m_fstream.Close(); + } + } + + /// <summary> + /// Allocates an unmanaged block of the requested size; free it with ReleaseDataBlock. + /// </summary> + internal override unsafe DataBlockInfo AllocateDataBlock(Int32 size) + { + DataBlockInfo blockInfo; + blockInfo.itemHandle = Marshal.AllocHGlobal((IntPtr)size); + blockInfo.dataBlock = (byte*)blockInfo.itemHandle; + blockInfo.blockSize = size; + return blockInfo; + } + + /// <summary> + /// Frees a block obtained from ReadDataBlock or AllocateDataBlock; IntPtr.Zero is ignored. + /// </summary> + internal override unsafe void ReleaseDataBlock(IntPtr itemHandle) + { + if (itemHandle != IntPtr.Zero) + { + Marshal.FreeHGlobal(itemHandle); + } + } + } +} diff --git a/LinqToDryad/DryadLinqGlobals.cs b/LinqToDryad/DryadLinqGlobals.cs new file mode 100644 index 0000000..c63c9c2 --- /dev/null +++ b/LinqToDryad/DryadLinqGlobals.cs @@ -0,0 +1,205 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. 
+// +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Text; +using System.Reflection; +using System.Linq; + +namespace Microsoft.Research.DryadLinq +{ + /// <summary> + /// This class consists of static properties and fields that control the + /// details of the HpcLinq configuration. + /// </summary> + internal class StaticConfig + { + internal static bool UseLargeBuffer = true; + + // Candidate values for DynamicOptLevel. The non-zero levels are distinct powers of two + // (presumably combinable as bit flags; confirm against the code that reads DynamicOptLevel). + internal const int NoDynamicOpt = 0; + internal const int DynamicBroadcastLevel = 0x1; + internal const int DynamicAggregateLevel = 0x2; + internal const int DynamicRangePartitionLevel = 0x4; + internal const int DynamicHashPartitionLevel = 0x8; + internal const int DynamicCoRangePartitionLevel = 0x10; + + /// <summary> + /// Root directory of the HpcLinq binaries, read from the DRYAD_HOME environment variable. + /// Throws DryadLinqException (DryadHomeMustBeSpecified) when the variable is unset or empty. + /// </summary> + internal static string HpcLinqBinRoot + { + get + { + //TODO -reconcile this with config variable + string dryadHome = Environment.GetEnvironmentVariable("DRYAD_HOME"); + if (string.IsNullOrEmpty(dryadHome)) + { + throw new DryadLinqException(HpcLinqErrorCode.DryadHomeMustBeSpecified, SR.DryadHomeMustBeSpecified); + } + return dryadHome; + } + } + + /// <summary> + /// Extra HpcLinq-related command line arguments for XmlExecHost.exe. + /// </summary> + //public static string XmlExecHostArgsAdditional = "--break"; + internal static string XmlExecHostArgsAdditional = String.Empty; + + /// <summary> + /// Runtime optimization: mostly affects dynamic resource management. + /// </summary> + internal static int DynamicOptLevel = DynamicBroadcastLevel; + + /// <summary> + /// Number of partitions to use when doing static resource management. + /// </summary> + //@@TODO: dynamic re-partitioning of multiple inputs should use CoHashPartition or CoRangePartition + // both are in MSR private repositories, and Yuan recommends the use of CoRangePartition + // as it produces more balanced partitioning. + // If/when this work is done, DefaultPartitionCount should be retired. 
+ // + // Short-term: (assuming dynamic-managers return), using nClusterWidth might be better than "8" + internal static int DefaultPartitionCount = 8; + + /// <summary> + /// Used by concat to determine whether to repartition the data down to fewer partitions. + /// </summary> + //@@TODO: review and probably choose a more sensible value for this. + // Choosing a smaller value should not break anything; it just forces repartitioning rather + // than a plain-old concatenation of the source partitions. + internal static int MaxPartitionCount = 20000; + + /// <summary> + /// Use in memory FIFOs where appropriate. + /// </summary> + internal static bool UseMemoryFIFO = false; + + /// <summary> + /// Stop execution of vertices in a debugger. + /// </summary> + internal static bool LaunchDebugger = false; + + // Specifies whether object fields can have null values. + // Setting AllowNullFields to true allows all object fields to have + // null values. A field can be specified as non-nullable by [Nullable(false)]. + // If AllowNullFields is false, all object fields are treated as non-nullable. + // A field can be specified as nullable by [Nullable(true)]. + internal static bool AllowNullFields = false; + internal static bool AllowNullArrayElements = false; + + /// <summary> + /// Specifies whether records can have null values. + /// </summary> + /// <remarks> + /// Setting AllowNullRecords to true allows all records to have null values. + /// A class can be specified as non-nullable by [Nullable(false)]. + /// If AllowNullRecords is false, all records are treated as non-nullable. + /// A class can be specified as nullable by [Nullable(true)]. + /// </remarks> + internal static bool AllowNullRecords = false; + + /// <summary> + /// Allows records to be serialized and retain their concrete type identity + /// </summary> + /// <remarks> + /// If false, records in an IQueryable{T} will be treated as records of type T regardless of their concrete type. + /// </remarks> + internal static bool AllowAutoTypeInference = false; + + /// <summary> + /// Specifies whether to use aggregation tree to work around some SMB limitation. 
+ /// </summary> + internal static bool UseSMBAggregation = true; + + /// <summary> + /// The maximum number of seconds to wait between polling the cluster to see if a job has completed. + /// </summary> + internal static int JobCompletionMaxPollInterval = 20; + + // The local reduction strategy used by GroupBy. + internal static CombinerKind GroupByReduceStrategy = CombinerKind.PartialHash; + internal static bool GroupByDynamicReduce = false; + internal static bool GroupByLocalAggregationIsPartial = true; + + /// <summary> + /// Actual arguments to pass to the XmlExecHost.exe job manager. + /// </summary> + internal static string XmlExecHostArgs + { + get { + return XmlExecHostArgsAdditional; + } + } + + /// <summary> + /// Path to the HpcLinq main job manager executable (a file called HpcQueryGraphManager.exe) + /// </summary> + internal static string XmlHostPath + { + get { + return StaticConfig.HpcLinqBinRoot + @"\HpcQueryGraphManager.exe"; + } + } + + internal static TimeSpan LeaseDurationForTempFiles = new TimeSpan(1, 0, 0, 0); // 1 day. + } + + /// <summary> + /// Specifies a reduction strategy for GroupBy. + /// </summary> + internal enum CombinerKind + { + /// <summary> + /// Partial sort GroupBy. + /// </summary> + PartialSort, + /// <summary> + /// Partial hash GroupBy. + /// </summary> + PartialHash, + /// <summary> + /// Full hash GroupBy. + /// </summary> + FullHash + } + + /// <summary> + /// Contains references to class and method names that are referenced via reflection. + /// This is intended to assist with refactoring that may break reflection. + /// + /// NOTE: this list will probably never be complete. + /// - A method mentioned here is definitely accessed via reflection + /// - Do not assume that methods not listed here are not accessed via Reflection. 
+ /// </summary> + internal class ReflectedNames + { + internal const string DataProvider_GetPartitionedTable = "GetPartitionedTable"; + internal const string DryadLinqIQueryable_ToDscWorker = "ToDscWorker"; + internal const string DryadLinqIQueryable_ToHdfsWorker = "ToHdfsWorker"; + internal const string DryadLinqIQueryable_AnonymousDscPlaceholder = "AnonymousDscTarget__Placeholder"; + internal const string HpcLinqQueryable_LocalDebug_ProcessToDscExpression = "LocalDebug_ProcessToDscExpression"; + internal const string HpcLinqQueryable_ExecuteLocalExpressionAndIngressToDsc = "ExecuteLocalExpressionAndIngressToDsc"; + } +} diff --git a/LinqToDryad/DryadLinqHelper.cs b/LinqToDryad/DryadLinqHelper.cs new file mode 100644 index 0000000..bcf6e69 --- /dev/null +++ b/LinqToDryad/DryadLinqHelper.cs @@ -0,0 +1,416 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. 
+//
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.IO;
+using System.Globalization;
+using System.Reflection;
+using System.Linq.Expressions;
+using System.Linq;
+
+namespace Microsoft.Research.DryadLinq.Internal
+{
+    // Static helper routines referenced by the operator implementations named in the
+    // per-method comments below (SlidingWindow, Reverse, Zip, SequenceEqual, ...).
+    //
+    // NOTE(review): generic type arguments are missing from the declarations in this
+    // copy of the file (e.g. "IEnumerable" where the bodies clearly use T / TSource /
+    // TKey). They are left exactly as found; restore them from the original source.
+    public static class HpcLinqHelper
+    {
+        // Passes every element of source through unchanged while verifying that the keys
+        // of consecutive elements are ordered according to comparer (reversed when
+        // isDescending is true). Throws DryadLinqException(SR.SourceNotOrdered) on the
+        // first out-of-order pair. Note that the offending element has already been
+        // yielded when the exception is raised, because the check runs after the yield.
+        [Resource(IsStateful = false)]
+        public static IEnumerable
+            CheckSort(IEnumerable source,
+                      Expression> keySelector,
+                      IComparer comparer,
+                      bool isDescending)
+        {
+            Func keySel = keySelector.Compile();
+            // presumably substitutes a default comparer for null -- confirm in TypeSystem
+            comparer = TypeSystem.GetComparer(comparer);
+
+            // Dispose the enumerator deterministically, as Slide/ProcessWindows already do.
+            using (IEnumerator elems = source.GetEnumerator())
+            {
+                if (elems.MoveNext())
+                {
+                    TSource curElem = elems.Current;
+                    yield return curElem;
+
+                    TKey curKey = keySel(curElem);
+                    while (elems.MoveNext())
+                    {
+                        TSource nextElem = elems.Current;
+                        yield return nextElem;
+
+                        TKey nextKey = keySel(nextElem);
+                        int cmp = comparer.Compare(curKey, nextKey);
+                        // Test the sign directly instead of negating cmp: -cmp stays
+                        // negative when a comparer returns Int32.MinValue.
+                        bool outOfOrder = isDescending ? (cmp < 0) : (cmp > 0);
+                        if (outOfOrder)
+                        {
+                            throw new DryadLinqException(SR.SourceNotOrdered);
+                        }
+                        curKey = nextKey;
+                    }
+                }
+            }
+        }
+
+        // Cartesian product of s1 and s2, mapped through procFunc(elem1, elem2).
+        // One side is buffered in memory with ToList() and the other is streamed.
+        // s2 is buffered by default; s1 is buffered instead only when both inputs are
+        // HpcVertexReader instances reporting non-negative total lengths and s1 is the
+        // strictly shorter one. The output order differs between the two cases because
+        // the streamed side always drives the outer loop.
+        public static IEnumerable Cross(IEnumerable s1,
+                                        IEnumerable s2,
+                                        Expression> procFunc)
+        {
+            Func proc = procFunc.Compile();
+            bool useRight = true;
+            if ((s1 is HpcVertexReader) && (s2 is HpcVertexReader))
+            {
+                Int64 leftLen = ((HpcVertexReader)s1).GetTotalLength();
+                Int64 rightLen = ((HpcVertexReader)s2).GetTotalLength();
+                // NOTE(review): a negative length looks like "unknown" -- confirm.
+                if (leftLen >= 0 && rightLen >= 0)
+                {
+                    useRight = rightLen <= leftLen;
+                }
+            }
+            if (useRight)
+            {
+                List elems2 = s2.ToList();
+                foreach (var elem1 in s1)
+                {
+                    foreach (var elem2 in elems2)
+                    {
+                        yield return proc(elem1, elem2);
+                    }
+                }
+            }
+            else
+            {
+                List elems1 = s1.ToList();
+                foreach (var elem2 in s2)
+                {
+                    foreach (var elem1 in elems1)
+                    {
+                        yield return proc(elem1, elem2);
+                    }
+                }
+            }
+        }
+
+        // Returns s2 and ignores s1 entirely (s1 is never enumerated here).
+        public static IEnumerable SelectSecond(IEnumerable s1, IEnumerable s2)
+        {
+            return s2;
+        }
+
+        // Used in SequenceEqual()
+        // Evaluates Enumerable.SequenceEqual eagerly and wraps the boolean result with
+        // HpcLinqVertex.AsEnumerable.
+        public static IEnumerable SequenceEqual(IEnumerable s1,
+                                                IEnumerable s2,
+                                                IEqualityComparer comparer)
+        {
+            return HpcLinqVertex.AsEnumerable(System.Linq.Enumerable.SequenceEqual(s1, s2, comparer));
+        }
+
+        // Used in SlidingWindow()
+        // Consumes source completely and yields exactly one array holding its final
+        // (windowSize - 1) elements in their original order. Throws DryadLinqException
+        // when source has fewer than (windowSize - 1) elements.
+        [Resource(IsStateful = false)]
+        public static IEnumerable Last(IEnumerable source,
+                                       int windowSize)
+        {
+            int count = windowSize - 1;
+            T[] buffer = new T[count];
+            long total = 0;
+            foreach (var x in source)
+            {
+                // buffer is used as a ring; with windowSize == 1 there is nothing to
+                // keep, and the modulo would otherwise divide by zero.
+                if (count > 0)
+                {
+                    buffer[total % count] = x;
+                }
+                total++;
+            }
+
+            if (total < count)
+            {
+                throw new DryadLinqException(String.Format(SR.PartitionTooSmallForSlidingWindow, count));
+            }
+
+            T[] last = new T[count];
+            // Take the modulo in 64 bits and only then narrow: a cast binds tighter than
+            // '%', so "(int)total % count" truncated total first and produced a wrong
+            // (possibly negative) index once total exceeded Int32.MaxValue.
+            int startIdx = (count > 0) ? (int)(total % count) : 0;
+            // Unroll the ring: oldest retained element sits at startIdx.
+            Array.Copy(buffer, startIdx, last, 0, count - startIdx);
+            Array.Copy(buffer, 0, last, count - startIdx, startIdx);
+            yield return last;
+        }
+
+        // Shifts a sequence of arrays one position to the right and tags each with its
+        // position: yields (0, empty array), (1, first input), (2, second input), ...
+        // The final input array is never emitted. Yields nothing for an empty source.
+        public static IEnumerable> Slide(IEnumerable source)
+        {
+            using (IEnumerator sourceEnum = source.GetEnumerator())
+            {
+                if (sourceEnum.MoveNext())
+                {
+                    yield return new IndexedValue(0, new T[0]);
+
+                    int index = 1;
+                    T[] lastVal = sourceEnum.Current;
+                    while (sourceEnum.MoveNext())
+                    {
+                        yield return new IndexedValue(index, lastVal);
+                        index++;
+                        lastVal = sourceEnum.Current;
+                    }
+                }
+            }
+        }
+
+        // Pre-loads a Window of capacity windowSize with the array carried by the single
+        // element of source1, then pushes the elements of source2 through it. procFunc
+        // is invoked once each time the window holds exactly windowSize elements;
+        // nothing is produced if the window never fills. The same Window instance is
+        // passed to every procFunc call.
+        [Resource(IsStateful = false)]
+        public static IEnumerable
+            ProcessWindows(IEnumerable> source1,
+                           IEnumerable source2,
+                           Func, T2> procFunc,
+                           Int32 windowSize)
+        {
+            Window window = new Window(windowSize);
+            T1[] slided = source1.Single().Value;
+            for (int i = 0; i < slided.Length; i++)
+            {
+                window.Add(slided[i]);
+            }
+
+            using (IEnumerator sourceEnum = source2.GetEnumerator())
+            {
+                // Top the window up to full size before emitting the first result.
+                while (window.Count() < windowSize)
+                {
+                    if (!sourceEnum.MoveNext()) break;
+                    window.Add(sourceEnum.Current);
+                }
+                if (window.Count() == windowSize)
+                {
+                    yield return procFunc(window);
+                    while (sourceEnum.MoveNext())
+                    {
+                        window.Add(sourceEnum.Current);
+                        yield return procFunc(window);
+                    }
+                }
+            }
+        }
+
+        // Calculate the sizes of the partitions. Used for example to implement Concat.
+        // Materializes source into one array and yields that same array once per
+        // element, tagged with the indices 0 .. n-1.
+        public static IEnumerable> IndexedCount(IEnumerable source)
+        {
+            T[] elems = source.ToArray();
+            for (int i = 0; i < elems.Length; i++)
+            {
+                yield return new IndexedValue(i, elems);
+            }
+        }
+
+        // Tags every element of source2 with a target partition index in [0, pcount).
+        // source1 must contain exactly one item: Index is this partition's position and
+        // Value holds the element counts of all partitions. Elements are dealt out in
+        // runs of (total / pcount); the last target partition absorbs the remainder.
+        [Resource(IsStateful = false)]
+        public static IEnumerable>
+            AddPartitionIndex(IEnumerable> source1, IEnumerable source2, Int32 pcount)
+        {
+            IndexedValue s1 = source1.Single();
+            long averageCount = s1.Value.Sum() / pcount;
+            // The quotient truncates to 0 when there are fewer elements than target
+            // partitions, which made the division and modulo below throw
+            // DivideByZeroException. One element per partition is the sensible floor.
+            if (averageCount < 1)
+            {
+                averageCount = 1;
+            }
+            // Number of elements held by the partitions that precede this one.
+            long partialCount = 0;
+            for (int i = 0; i < s1.Index; i++)
+            {
+                partialCount += s1.Value[i];
+            }
+            // The loop below treats pcount-1 as the last legal index, so clamp the
+            // starting index to the same bound; a partition that starts inside the
+            // remainder region would otherwise be assigned an index >= pcount.
+            int partIndex = (int)Math.Min(partialCount / averageCount, (long)(pcount - 1));
+            long indexInPart = partialCount % averageCount;
+            foreach (T elem in source2)
+            {
+                if (indexInPart >= averageCount && partIndex != pcount-1)
+                {
+                    partIndex++;
+                    indexInPart = 0;
+                }
+                yield return new IndexedValue(partIndex, elem);
+                indexInPart++;
+            }
+        }
+
+        // Produces one dummy item per partition. Used for example to implement Reverse().
+        // source is not enumerated.
+        [Resource(IsStateful = false)]
+        public static IEnumerable ValueZero(IEnumerable source)
+        {
+            yield return 0;
+        }
+
+        //Used for Reverse()
+        //input: a sequence of n dummy items. eg {0,0,0... } x n
+        //output: { {(0,n), (1,n), (2,n), .., (n-1, n)} }
+        //   item.Index = index
+        //   item.Value = nPartitions
+        public static IEnumerable> MakeIndexCountPairs(IEnumerable source)
+        {
+            int count = source.Count();
+            for (int i = 0; i < count; i++)
+            {
+                yield return new IndexedValue(i, count);
+            }
+        }
+
+        // Used for Reverse()
+        // receives a pair (myIndex, nPartitions) as source1, and a normal sequence as source2.
+        // targetIdx = nPartition-myIndex-1
+        // produces {(targetIdx, item), (targetIdx, item), ...}
+        public static IEnumerable>
+            AddIndexForReverse(IEnumerable> source1, IEnumerable source2)
+        {
+            IndexedValue item = source1.Single();
+            int myIndex = item.Index;
+            int pcount = item.Value;
+            int targetIndex = pcount - myIndex - 1;
+            foreach (T elem in source2)
+            {
+                yield return new IndexedValue(targetIndex, elem);
+            }
+        }
+
+        // Used in Zip()
+        // Materializes both count sequences and yields the same (elems1, elems2) pair
+        // once per entry of source2, tagged with the indices 0 .. elems2.Length-1.
+        public static IEnumerable>>
+            ZipCount(IEnumerable source1, IEnumerable source2)
+        {
+            long[] elems1 = source1.ToArray();
+            long[] elems2 = source2.ToArray();
+            Pair pair = new Pair(elems1, elems2);
+            for (int i = 0; i < elems2.Length; i++)
+            {
+                yield return new IndexedValue>(i, pair);
+            }
+        }
+
+        // Tags the elements of source2 with the index of the elems1 slot that covers
+        // the same global position. source1 must contain exactly one item whose Value
+        // is the (elems1, elems2) pair of count arrays and whose Index is this
+        // partition's position within elems2. Elements that fall beyond the total of
+        // elems1 are dropped.
+        public static IEnumerable>
+            AssignPartitionIndex(IEnumerable>> source1,
+                                 IEnumerable source2)
+        {
+            IndexedValue> s1 = source1.Single();
+            long[] elems1 = s1.Value.Key;
+            long[] elems2 = s1.Value.Value;
+
+            // Global offset of this partition's first element.
+            long partialCount = 0;
+            for (int i = 0; i < s1.Index; i++)
+            {
+                partialCount += elems2[i];
+            }
+            // Walk elems1 until the offset is covered. On exit with partialCount < 0,
+            // -partialCount is the number of positions still free in slot partIndex.
+            int partIndex = 0;
+            for (partIndex = 0; partIndex < elems1.Length; partIndex++)
+            {
+                partialCount -= elems1[partIndex];
+                if (partialCount < 0) break;
+            }
+            if (partialCount < 0)
+            {
+                foreach (T elem in source2)
+                {
+                    yield return new IndexedValue(partIndex, elem);
+                    partialCount++;
+                    if (partialCount == 0)
+                    {
+                        // Current slot is full: advance to the next non-empty slot.
+                        for (partIndex = partIndex + 1; partIndex < elems1.Length; partIndex++)
+                        {
+                            partialCount = -elems1[partIndex];
+                            if (partialCount < 0) break;
+                        }
+                        // No capacity left anywhere: stop consuming source2.
+                        if (partialCount == 0) break;
+                    }
+                }
+            }
+        }
+
+        // Used in SelectWithPartitionIndex()
+        // Yields 0, 1, 2, ... once per input element; the element values are ignored.
+        public static IEnumerable AssignIndex(IEnumerable source)
+        {
+            int index = 0;
+            foreach (int elem in source)
+            {
+                yield return index;
+                index++;
+            }
+        }
+
+        // Applies procFunc to the whole of source1 together with the index carried by
+        // the single element of source2.
+        public static IEnumerable
+            ProcessWithIndex(IEnumerable source1,
+                             IEnumerable source2,
+                             Func, int, IEnumerable> procFunc)
+        {
+            int index = source2.Single();
+            return procFunc(source1, index);
+        }
+
+        // Element-wise variant: maps procFunc(x, index) over source1 through
+        // HpcLinqVertex.Select, where index is the single element of source2.
+        public static IEnumerable
+            ProcessWithIndex(IEnumerable source1,
+                             IEnumerable source2,
+                             Func procFunc)
+        {
+            int index = source2.Single();
+            return HpcLinqVertex.Select(source1, x => procFunc(x, index), true);
+        }
+    }
+
+    // Fixed-capacity circular buffer. Add appends until the capacity given to the
+    // constructor is reached and overwrites the oldest element afterwards.
+    // Enumeration runs from the oldest retained element to the newest.
+    internal class Window : IEnumerable
+    {
+        private T[] m_elems;        // backing storage; its length is the capacity
+        private int m_startIdx;     // position of the oldest retained element
+        private int m_count;        // number of elements currently held
+
+        public Window(int len)
+        {
+            this.m_elems = new T[len];
+            this.m_startIdx = 0;
+            this.m_count = 0;
+        }
+
+        public void Add(T elem)
+        {
+            // Slot just past the newest element, wrapped around the end of the array.
+            // When the window is full this is the slot of the oldest element.
+            int nextIdx = this.m_startIdx + this.m_count;
+            if (nextIdx >= this.m_elems.Length)
+            {
+                nextIdx -= this.m_elems.Length;
+            }
+            this.m_elems[nextIdx] = elem;
+            if (this.m_count < this.m_elems.Length)
+            {
+                this.m_count++;
+            }
+            else
+            {
+                // Full: the oldest element was overwritten, so the start moves on.
+                this.m_startIdx++;
+                if (this.m_startIdx == this.m_elems.Length)
+                {
+                    this.m_startIdx = 0;
+                }
+            }
+        }
+
+        // Number of elements currently held (at most the capacity).
+        public int Count()
+        {
+            return this.m_count;
+        }
+
+        #region IEnumerable and IEnumerable members
+        IEnumerator IEnumerable.GetEnumerator()
+        {
+            return this.GetEnumerator();
+        }
+
+        public IEnumerator GetEnumerator()
+        {
+            int idx = this.m_startIdx;
+            for (int i = 0; i < this.m_count; i++)
+            {
+                yield return this.m_elems[idx];
+                idx++;
+                if (idx == this.m_elems.Length) idx = 0;
+            }
+        }
+        #endregion
+    }
+}
+
diff --git a/LinqToDryad/DryadLinqIEnumerable.cs b/LinqToDryad/DryadLinqIEnumerable.cs
new file mode 100644
index 0000000..a6a7f4e
--- /dev/null
+++ b/LinqToDryad/DryadLinqIEnumerable.cs
@@ -0,0 +1,1346 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// © Microsoft Corporation. All rights reserved. +// +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Globalization; +using System.Reflection; +using System.Linq; +using Microsoft.Research.DryadLinq.Internal; + +namespace Microsoft.Research.DryadLinq.Internal +{ + // This class provides the IEnumerable implementation of the operators + // we introduced in DryadLINQ. This is needed when LocalDebug is set. + // + // Due to the way linq-to-objects locates methods to call, this class must be public visibility. + public static class HpcLinqEnumerable + { + // Operator: HashPartition + public static IEnumerable + HashPartition(this IEnumerable source, + Func keySelector, + IEqualityComparer comparer, + int count) + { + return source; + } + + public static IEnumerable + HashPartition(this IEnumerable source, + Func keySelector, + int count) + { + return source; + } + + public static IEnumerable + HashPartition(this IEnumerable source, + Func keySelector) + { + return source; + } + + public static IEnumerable + HashPartition(this IEnumerable source, + Func keySelector, + IEqualityComparer comparer) + { + return source; + } + + // Operator: RangePartition + public static IEnumerable + RangePartition(this IEnumerable source, + Func keySelector, + bool isDescending) + { + return source; + } + + public static IEnumerable + RangePartition(this IEnumerable source, + Func keySelector, + IComparer comparer, + bool isDescending) + { + return source; + } + + public static IEnumerable + RangePartition(this IEnumerable source, + Func keySelector, + TKey[] rangeSeparators) + { + return source; + } + + public static IEnumerable + RangePartition(this IEnumerable source, + Func keySelector, + TKey[] rangeSeparators, + IComparer comparer) + { + return source; + } + + // Operator: Apply + public 
static IEnumerable + Apply(this IEnumerable source, + Func, IEnumerable> procFunc) + { + return procFunc(source); + } + + public static IEnumerable + Apply(this IEnumerable source1, + IEnumerable source2, + Func, IEnumerable, IEnumerable> procFunc) + { + return procFunc(source1, source2); + } + + public static IEnumerable + Apply(this IEnumerable source, + IEnumerable[] otherSources, + Func[], IEnumerable> procFunc) + { + IEnumerable[] allSources = new IEnumerable[otherSources.Length + 1]; + allSources[0] = source; + for (int i = 0; i < otherSources.Length; i++) + { + allSources[i+1] = otherSources[i]; + } + return procFunc(allSources); + } + + // Operator: SlidingWindow + public static IEnumerable + SlidingWindow(this IEnumerable source, + Func, T2> procFunc, + Int32 windowSize) + { + Window window = new Window(windowSize); + using (IEnumerator sourceEnum = source.GetEnumerator()) + { + while (window.Count() < windowSize) + { + if (!sourceEnum.MoveNext()) break; + window.Add(sourceEnum.Current); + } + if (window.Count() == windowSize) + { + yield return procFunc(window); + while (sourceEnum.MoveNext()) + { + window.Add(sourceEnum.Current); + yield return procFunc(window); + } + } + } + } + + // Operator: SelectWithPartitionIndex + public static IEnumerable + SelectWithPartitionIndex(this IEnumerable source, + Func procFunc) + { + foreach (T1 x in source) + { + yield return procFunc(x, 0); + } + } + + // Operator: ApplyWithPartitionIndex + public static IEnumerable + ApplyWithPartitionIndex(this IEnumerable source, + Func, int, IEnumerable> procFunc) + { + return procFunc(source, 0); + } + + public static IEnumerable AnyAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Any()); + } + + public static IEnumerable + AnyAsQuery(this IEnumerable source, Func predicate) + { + return HpcLinqVertex.AsEnumerable(source.Any(predicate)); + } + + public static IEnumerable + AllAsQuery(this IEnumerable source, Func predicate) + { + return 
HpcLinqVertex.AsEnumerable(source.All(predicate)); + } + + public static IEnumerable + CountAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Count()); + } + + public static IEnumerable + CountAsQuery(this IEnumerable source, Func predicate) + { + return HpcLinqVertex.AsEnumerable(source.Count(predicate)); + } + + public static IEnumerable + LongCountAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.LongCount()); + } + + public static IEnumerable + LongCountAsQuery(this IEnumerable source, + Func predicate) + { + return HpcLinqVertex.AsEnumerable(source.LongCount(predicate)); + } + + public static IEnumerable + ContainsAsQuery(this IEnumerable source, TSource value) + { + return HpcLinqVertex.AsEnumerable(source.Contains(value)); + } + + public static IEnumerable + ContainsAsQuery(this IEnumerable source, + TSource value, + IEqualityComparer comparer) + { + return HpcLinqVertex.AsEnumerable(source.Contains(value, comparer)); + } + + public static IEnumerable + SequenceEqualAsQuery(this IEnumerable first, + IEnumerable second) + { + return HpcLinqVertex.AsEnumerable(first.SequenceEqual(second)); + } + + public static IEnumerable + SequenceEqualAsQuery(this IEnumerable first, + IEnumerable second, + IEqualityComparer comparer) + { + return HpcLinqVertex.AsEnumerable(first.SequenceEqual(second, comparer)); + } + + public static IEnumerable FirstAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.First()); + } + + public static IEnumerable + FirstAsQuery(this IEnumerable source, + Func predicate) + { + return HpcLinqVertex.AsEnumerable(source.First(predicate)); + } + + public static IEnumerable LastAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Last()); + } + + public static IEnumerable + LastAsQuery(this IEnumerable source, + Func predicate) + { + return HpcLinqVertex.AsEnumerable(source.Last(predicate)); + } + + public static IEnumerable 
SingleAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Single()); + } + + public static IEnumerable + SingleAsQuery(this IEnumerable source, + Func predicate) + { + return HpcLinqVertex.AsEnumerable(source.Single(predicate)); + } + + public static IEnumerable SumAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Sum()); + } + + public static IEnumerable SumAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Sum()); + } + + public static IEnumerable SumAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Sum()); + } + + public static IEnumerable SumAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Sum()); + } + + public static IEnumerable SumAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Sum()); + } + + public static IEnumerable SumAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Sum()); + } + + public static IEnumerable SumAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Sum()); + } + + public static IEnumerable SumAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Sum()); + } + + public static IEnumerable SumAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Sum()); + } + + public static IEnumerable SumAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Sum()); + } + + public static IEnumerable + SumAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Sum(selector)); + } + + public static IEnumerable + SumAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Sum(selector)); + } + + public static IEnumerable + SumAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Sum(selector)); + } + + public static IEnumerable + 
SumAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Sum(selector)); + } + + public static IEnumerable + SumAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Sum(selector)); + } + + public static IEnumerable + SumAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Sum(selector)); + } + + public static IEnumerable + SumAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Sum(selector)); + } + + public static IEnumerable + SumAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Sum(selector)); + } + + public static IEnumerable + SumAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Sum(selector)); + } + + public static IEnumerable + SumAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Sum(selector)); + } + + public static IEnumerable MinAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Min()); + } + + public static IEnumerable MinAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Min()); + } + + public static IEnumerable MinAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Min()); + } + + public static IEnumerable MinAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Min()); + } + + public static IEnumerable MinAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Min()); + } + + public static IEnumerable MinAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Min()); + } + + public static IEnumerable MinAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Min()); + } + + public static IEnumerable MinAsQuery(this IEnumerable source) + { + return 
HpcLinqVertex.AsEnumerable(source.Min()); + } + + public static IEnumerable MinAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Min()); + } + + public static IEnumerable MinAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Min()); + } + + public static IEnumerable MinAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Min()); + } + + public static IEnumerable + MinAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Min(selector)); + } + + public static IEnumerable + MinAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Min(selector)); + } + + public static IEnumerable + MinAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Min(selector)); + } + + public static IEnumerable + MinAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Min(selector)); + } + + public static IEnumerable + MinAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Min(selector)); + } + + public static IEnumerable + MinAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Min(selector)); + } + + public static IEnumerable + MinAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Min(selector)); + } + + public static IEnumerable + MinAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Min(selector)); + } + + public static IEnumerable + MinAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Min(selector)); + } + + public static IEnumerable + MinAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Min(selector)); + } + + public static IEnumerable + MinAsQuery(this 
IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Min(selector)); + } + + public static IEnumerable MaxAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Max()); + } + + public static IEnumerable MaxAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Max()); + } + + public static IEnumerable MaxAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Max()); + } + + public static IEnumerable MaxAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Max()); + } + + public static IEnumerable MaxAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Max()); + } + + public static IEnumerable MaxAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Max()); + } + + public static IEnumerable MaxAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Max()); + } + + public static IEnumerable MaxAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Max()); + } + + public static IEnumerable MaxAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Max()); + } + + public static IEnumerable MaxAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Max()); + } + + public static IEnumerable MaxAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Max()); + } + + public static IEnumerable + MaxAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Max(selector)); + } + + public static IEnumerable + MaxAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Max(selector)); + } + + public static IEnumerable + MaxAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Max(selector)); + } + + public static IEnumerable + MaxAsQuery(this IEnumerable 
source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Max(selector)); + } + + public static IEnumerable + MaxAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Max(selector)); + } + + public static IEnumerable + MaxAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Max(selector)); + } + + public static IEnumerable + MaxAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Max(selector)); + } + + public static IEnumerable + MaxAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Max(selector)); + } + + public static IEnumerable + MaxAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Max(selector)); + } + + public static IEnumerable + MaxAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Max(selector)); + } + + public static IEnumerable + MaxAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Max(selector)); + } + + public static IEnumerable AverageAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Average()); + } + + public static IEnumerable AverageAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Average()); + } + + public static IEnumerable AverageAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Average()); + } + + public static IEnumerable AverageAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Average()); + } + + public static IEnumerable AverageAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Average()); + } + + public static IEnumerable AverageAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Average()); + } + + public static IEnumerable 
AverageAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Average()); + } + + public static IEnumerable AverageAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Average()); + } + + public static IEnumerable AverageAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Average()); + } + + public static IEnumerable AverageAsQuery(this IEnumerable source) + { + return HpcLinqVertex.AsEnumerable(source.Average()); + } + + public static IEnumerable + AverageAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Average(selector)); + } + + public static IEnumerable + AverageAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Average(selector)); + } + + public static IEnumerable + AverageAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Average(selector)); + } + + public static IEnumerable + AverageAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Average(selector)); + } + + public static IEnumerable + AverageAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Average(selector)); + } + + public static IEnumerable + AverageAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Average(selector)); + } + + public static IEnumerable + AverageAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Average(selector)); + } + + public static IEnumerable + AverageAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Average(selector)); + } + + public static IEnumerable + AverageAsQuery(this IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Average(selector)); + } + + public static IEnumerable + AverageAsQuery(this 
IEnumerable source, + Func selector) + { + return HpcLinqVertex.AsEnumerable(source.Average(selector)); + } + + public static IEnumerable + AggregateAsQuery(this IEnumerable source, + Func func) + { + return HpcLinqVertex.AsEnumerable(source.Aggregate(func)); + } + + public static IEnumerable + AggregateAsQuery(this IEnumerable source, + TAccumulate seed, + Func func) + { + return HpcLinqVertex.AsEnumerable(source.Aggregate(seed, func)); + } + + public static IEnumerable + AggregateAsQuery(this IEnumerable source, + TAccumulate seed, + Func func, + Func resultSelector) + { + return HpcLinqVertex.AsEnumerable(source.Aggregate(seed, func, resultSelector)); + } + + public static IEnumerable + AssumeHashPartition(this IEnumerable source, + Func keySelector) + { + return source; + } + + // Operator: AssumeHashPartition + public static IEnumerable + AssumeHashPartition(this IEnumerable source, + Func keySelector, + IEqualityComparer comparer) + { + return source; + } + + public static IEnumerable + AssumeRangePartition(this IEnumerable source, + Func keySelector, + bool isDescending) + { + return source; + } + + // Operator: AssumeRangePartition + public static IEnumerable + AssumeRangePartition(this IEnumerable source, + Func keySelector, + IComparer comparer, + bool isDescending) + { + return source; + } + + public static IEnumerable + AssumeRangePartition(this IEnumerable source, + Func keySelector, + TKey[] rangeSeparators) + { + return source; + } + + public static IEnumerable + AssumeRangePartition(this IEnumerable source, + Func keySelector, + TKey[] rangeSeparators, + IComparer comparer) + { + return source; + } + + // Operator: AssumeOrderBy + public static IEnumerable + AssumeOrderBy(this IEnumerable source, + Func keySelector, + bool isDescending) + { + return source; + } + + public static IEnumerable + AssumeOrderBy(this IEnumerable source, + Func keySelector, + IComparer comparer, + bool isDescending) + { + return source; + } + + public static 
IMultiEnumerable + Fork(this IEnumerable source, + Func, IEnumerable>> mapper) + { + List resX = new List(); + List resY = new List(); + + IEnumerable> result = mapper(source); + foreach (ForkTuple item in result) + { + if (item.HasFirst) resX.Add(item.First); + if (item.HasSecond) resY.Add(item.Second); + } + return new MultiEnumerable(resX, resY); + } + + public static IMultiEnumerable + Fork(this IEnumerable source, + Func, IEnumerable>> mapper) + { + List resX = new List(); + List resY = new List(); + List resZ = new List(); + + IEnumerable> result = mapper(source); + foreach (ForkTuple item in result) + { + if (item.HasFirst) resX.Add(item.First); + if (item.HasSecond) resY.Add(item.Second); + if (item.HasThird) resZ.Add(item.Third); + } + return new MultiEnumerable(resX, resY, resZ); + } + + public static IMultiEnumerable + Fork(this IEnumerable source, + Func> mapper) + { + List resX = new List(); + List resY = new List(); + + foreach (TSource elem in source) + { + ForkTuple item = mapper(elem); + if (item.HasFirst) resX.Add(item.First); + if (item.HasSecond) resY.Add(item.Second); + } + return new MultiEnumerable(resX, resY); + } + + public static IMultiEnumerable + Fork(this IEnumerable source, + Func> mapper) + { + List resX = new List(); + List resY = new List(); + List resZ = new List(); + + foreach (TSource elem in source) + { + ForkTuple item = mapper(elem); + if (item.HasFirst) resX.Add(item.First); + if (item.HasSecond) resY.Add(item.Second); + if (item.HasThird) resZ.Add(item.Third); + } + return new MultiEnumerable(resX, resY, resZ); + } + + public static IMultiEnumerable + Fork(this IEnumerable source, + Func keySelector, + TKey[] keys) + { + List[] enumList = new List[keys.Length]; + Dictionary keyMap = new Dictionary(keys.Length); + for (int i = 0; i < keys.Length; i++) + { + enumList[i] = new List(); + keyMap.Add(keys[i], i); + } + foreach (TSource item in source) + { + int index; + if (keyMap.TryGetValue(keySelector(item), out index)) + { + 
enumList[index].Add(item); + } + } + return new MultiEnumerable(enumList); + } + + public static IEnumerable + LongTakeWhile(this IEnumerable source, + Func predicate) + { + long index = 0; + foreach (TSource element in source) + { + if (!predicate(element, index)) + { + yield break; + } + yield return element; + checked { index++; } + } + } + + public static IEnumerable + LongSkipWhile(this IEnumerable source, + Func predicate) + { + long index = -1; + bool yielding = false; + using (IEnumerator sourceEnum = source.GetEnumerator()) + { + while (sourceEnum.MoveNext()) + { + checked { index++; } + if (!predicate(sourceEnum.Current, index)) + { + yielding = true; + break; + } + } + + if (yielding) + { + do + { + yield return sourceEnum.Current; + } + while (sourceEnum.MoveNext()); + } + } + } + + public static IEnumerable FlattenGroups(IEnumerable> groups) + { + foreach (var g in groups) + { + foreach (var x in g) + { + yield return x; + } + } + } + + [FieldMapping("key", "Key")] + public static IGrouping + MakeHpcLinqGroup(K key, IEnumerable> groups) + { + return new HpcLinqGrouping(key, FlattenGroups(groups)); + } + + public static IEnumerable Flatten(IEnumerable> groups) + { + foreach (var g in groups) + { + foreach (var x in g) + { + yield return x; + } + } + } + + public static IEnumerable FlattenDistinct(IEnumerable> groups, + IEqualityComparer comparer) + { + return Flatten(groups).Distinct(comparer).ToArray(); + } + + public static IEnumerable> Offsets(IEnumerable counts, bool isLong) + { + int index = 0; + long offset = 0; + foreach (long count in counts) + { + yield return new IndexedValue(index, offset); + index++; + checked { offset += count; } + } + if (!isLong && (offset > Int32.MaxValue)) + { + throw new OverflowException(SR.IndexTooSmall); + } + } + + public static IEnumerable + WhereWithStartIndex(IEnumerable source, + IEnumerable> startIndex, + Func predicate) + { + int index = (int)startIndex.Single().Value; + foreach (TSource element in source) + 
{ + if (predicate(element, index)) + { + yield return element; + } + checked { index++; } + } + } + + public static IEnumerable + LongWhere(this IEnumerable source, + Func predicate) + { + long index = 0; + foreach (TSource element in source) + { + if (predicate(element, index)) + { + yield return element; + } + checked { index++; } + } + } + + public static IEnumerable + LongWhereWithStartIndex(IEnumerable source, + IEnumerable> startIndex, + Func predicate) + { + long index = startIndex.Single().Value; + foreach (TSource element in source) + { + if (predicate(element, index)) + { + yield return element; + } + checked { index++; } + } + } + + public static IEnumerable + SelectWithStartIndex(IEnumerable source, + IEnumerable> startIndex, + Func selector) + { + int index = (int)startIndex.Single().Value; + foreach (TSource element in source) + { + yield return selector(element, index); + checked { index++; } + } + } + + public static IEnumerable + LongSelect(this IEnumerable source, + Func selector) + { + long index = 0; + foreach (TSource element in source) + { + yield return selector(element, index); + checked { index++; } + } + } + + public static IEnumerable + LongSelectWithStartIndex(IEnumerable source, + IEnumerable> startIndex, + Func selector) + { + long index = startIndex.Single().Value; + foreach (TSource element in source) + { + yield return selector(element, index); + checked { index++; } + } + } + + public static IEnumerable + SelectManyWithStartIndex(IEnumerable source, + IEnumerable> startIndex, + Func> selector) + { + int index = (int)startIndex.Single().Value; + foreach (TSource element in source) + { + foreach (TResult result in selector(element, index)) + { + yield return result; + } + checked { index++; } + } + } + + public static IEnumerable + LongSelectManyWithStartIndex( + IEnumerable source, + IEnumerable> startIndex, + Func> selector) + { + long index = startIndex.Single().Value; + foreach (TSource element in source) + { + foreach (TResult 
result in selector(element, index)) + { + yield return result; + } + checked { index++; } + } + } + + public static IEnumerable + SelectManyResultWithStartIndex( + IEnumerable source, + IEnumerable> startIndex, + Func> collectionSelector, + Func resultSelector) + { + int index = (int)startIndex.Single().Value; + foreach (TSource element in source) + { + foreach (TCollection result in collectionSelector(element, index)) + { + yield return resultSelector(element, result); + } + checked { index++; } + } + } + + public static IEnumerable LongSelectMany( + this IEnumerable source, + Func> selector) + { + long index = 0; + foreach (TSource element in source) + { + foreach (TResult result in selector(element, index)) + { + yield return result; + } + checked { index++; } + } + } + + public static IEnumerable + LongSelectMany( + this IEnumerable source, + Func> selector, + Func resultSelector) + { + long index = 0; + foreach (TSource element in source) + { + foreach (TCollection result in selector(element, index)) + { + yield return resultSelector(element, result); + } + checked { index++; } + } + } + + + public static IEnumerable + LongSelectManyResultWithStartIndex( + IEnumerable source, + IEnumerable> startIndex, + Func> collectionSelector, + Func resultSelector) + { + long index = startIndex.Single().Value; + foreach (TSource element in source) + { + foreach (TCollection result in collectionSelector(element, index)) + { + yield return resultSelector(element, result); + } + checked { index++; } + } + } + + private const int GroupSize = 1000; + public static IEnumerable, bool>> + GroupTakeWhile(IEnumerable source, Func pred) + { + List group = new List(GroupSize); + foreach (T elem in source) + { + if (pred(elem)) + { + if (group.Count == GroupSize) + { + yield return new Pair, bool>(group, true); + group = new List(GroupSize); + } + group.Add(elem); + } + else + { + yield return new Pair, bool>(group, false); + yield break; + } + } + + yield return new Pair, bool>(group, 
true); + } + + public static IEnumerable, bool>> + GroupIndexedTakeWhile(IEnumerable source, + IEnumerable> startIndex, + Func pred) + { + int currIdx = (int)startIndex.Single().Value; + List group = new List(GroupSize); + foreach (T elem in source) + { + if (pred(elem, currIdx)) + { + if (group.Count == GroupSize) + { + yield return new Pair, bool>(group, true); + group = new List(GroupSize); + } + group.Add(elem); + checked { currIdx++; } + } + else + { + yield return new Pair, bool>(group, false); + yield break; + } + } + + yield return new Pair, bool>(group, true); + } + + public static IEnumerable, bool>> + GroupIndexedLongTakeWhile(IEnumerable source, + IEnumerable> startIndex, + Func pred) + { + long currIdx = startIndex.Single().Value; + List group = new List(GroupSize); + foreach (T elem in source) + { + if (pred(elem, currIdx)) + { + if (group.Count == GroupSize) + { + yield return new Pair, bool>(group, true); + group = new List(GroupSize); + } + group.Add(elem); + currIdx++; + } + else + { + yield return new Pair, bool>(group, false); + yield break; + } + } + + yield return new Pair, bool>(group, true); + } + } +} diff --git a/LinqToDryad/DryadLinqIQueryable.cs b/LinqToDryad/DryadLinqIQueryable.cs new file mode 100644 index 0000000..0d7de3f --- /dev/null +++ b/LinqToDryad/DryadLinqIQueryable.cs @@ -0,0 +1,3528 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. +// +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Globalization; +using System.Reflection; +using System.Linq.Expressions; +using System.Linq; +using System.Diagnostics; +using Microsoft.Research.DryadLinq.Internal; +using Microsoft.Research.Dryad.Hdfs; + + +namespace Microsoft.Research.DryadLinq +{ + // This class introduces some new operators into the expression tree. So far, + // there are two classes of new operators: + // 1. HashPartition, RangePartition, Merge + // 2. Apply + public static class HpcLinqQueryable + { + internal static bool IsLocalDebugSource(IQueryable source) + { + return !(source.Provider is DryadLinqProvider); + } + + public static IQueryable + LongWhere(this IQueryable source, + Expression> predicate) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (predicate == null) + { + throw new ArgumentNullException("predicate"); + } + if (IsLocalDebugSource(source)) + { + var result = HpcLinqEnumerable.LongWhere(source, predicate.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, result); + } + + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(predicate) } + )); + } + + public static IQueryable + LongSelect(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var result = HpcLinqEnumerable.LongSelect(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, result); + } + + 
return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TResult)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + LongSelectMany(this IQueryable source, + Expression>> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var result = HpcLinqEnumerable.LongSelectMany(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, result); + } + + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TResult)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + LongSelectMany(this IQueryable source, + Expression>> selector, + Expression> resultSelector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if (resultSelector == null) + { + throw new ArgumentNullException("resultSelector"); + } + if (IsLocalDebugSource(source)) + { + var result = HpcLinqEnumerable.LongSelectMany(source, selector.Compile(), resultSelector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, result); + } + + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TCollection), typeof(TResult)), + new Expression[] { source.Expression, Expression.Quote(selector), Expression.Quote(resultSelector) } + )); + } + + public static IQueryable + LongTakeWhile(this IQueryable source, + Expression> predicate) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + 
if (predicate == null) + { + throw new ArgumentNullException("predicate"); + } + if (IsLocalDebugSource(source)) + { + var result = HpcLinqEnumerable.LongTakeWhile(source, predicate.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, result); + } + + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(predicate) } + )); + } + + public static IQueryable + LongSkipWhile(this IQueryable source, + Expression> predicate) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (predicate == null) + { + throw new ArgumentNullException("predicate"); + } + if (IsLocalDebugSource(source)) + { + var result = HpcLinqEnumerable.LongSkipWhile(source, predicate.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, result); + } + + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(predicate) } + )); + } + + /// + /// Hash partition a dataset. 
+ /// + /// The type of the records in the dataset + /// The type of the key on which the partition is based + /// The dataset to be partitioned + /// The function to extract the key from a record + /// An EqualityComparer on TKey to compare keys + /// The number of partitions to create + /// An IQueryable partitioned according to a key + public static IQueryable + HashPartition(this IQueryable source, + Expression> keySelector, + IEqualityComparer comparer, + int partitionCount) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + if (partitionCount <= 0) + { + throw new ArgumentOutOfRangeException("partitionCount"); + } + + if (IsLocalDebugSource(source)) + { + return source; + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector), + Expression.Constant(comparer, typeof(IEqualityComparer)), + Expression.Constant(partitionCount, typeof(int)) } + )); + } + + /// + /// Hash partition a dataset. 
+ /// + /// The type of the records in the dataset + /// The type of the key on which the partition is based + /// the dataset to be partitioned + /// The funtion to extract the key from a record + /// The number of partitioned to create + /// An IQueryable partitioned according to a key + public static IQueryable + HashPartition(this IQueryable source, + Expression> keySelector, + int partitionCount) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + if (partitionCount <= 0) + { + throw new ArgumentOutOfRangeException("partitionCount"); + } + if (IsLocalDebugSource(source)) + { + return source; + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector), + Expression.Constant(partitionCount, typeof(int)) } + )); + } + + /// + /// Hash partition a dataset. The number of resulting partitions is dynamically determined + /// at the runtime. + /// + /// The type of the records in the dataset + /// The type of the key on which the partition is based + /// the dataset to be partitioned + /// The function to extract the key from a record + /// An IQueryable partitioned according to a key + public static IQueryable + HashPartition(this IQueryable source, + Expression> keySelector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + if (IsLocalDebugSource(source)) + { + return source; + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector) } + )); + } + + /// + /// Hash partition a dataset. 
The number of resulting partitions is dynamically determined + /// at the runtime. + /// + /// The type of the records in the dataset + /// The type of the key on which the partition is based + /// The dataset to be partitioned + /// The function to extract the key from a record + /// An IComparer on TKey to compare keys + /// An IQueryable partitioned according to a key + public static IQueryable + HashPartition(this IQueryable source, + Expression> keySelector, + IEqualityComparer comparer) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + if (IsLocalDebugSource(source)) + { + return source; + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector), + Expression.Constant(comparer, typeof(IEqualityComparer)) } + )); + } + + /// + /// Range partition a dataset. The list of range keys are determined dynamically at + /// runtime. + /// + /// The type of the records in the dataset + /// The type of the key on which the partition is based + /// The dataset to be partitioned + /// The function to extract the key from a record + /// An IQueryable partitioned according to a key + public static IQueryable + RangePartition(this IQueryable source, + Expression> keySelector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + if (IsLocalDebugSource(source)) + { + return source; + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector) } + )); + } + + /// + /// Range partition a dataset. 
The list of range keys are determined dynamically at + /// runtime. + /// + /// The type of the records in the dataset + /// The type of the key on which the partition is based + /// The dataset to be partitioned + /// The function to extract the key from a record + /// Number of partitions in the output dataset + /// An IQueryable partitioned according to a key + public static IQueryable + RangePartition(this IQueryable source, + Expression> keySelector, + int partitionCount) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + if (partitionCount <= 0) + { + throw new ArgumentOutOfRangeException("partitionCount"); + } + if (IsLocalDebugSource(source)) + { + return source; + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector), + Expression.Constant(partitionCount, typeof(int)) } + )); + } + + /// + /// Range partition a dataset. The list of range keys are determined dynamically at + /// runtime. 
+ /// + /// The type of the records in the dataset + /// The type of the key on which the partition is based + /// The dataset to be partitioned + /// The funtion to extract the key from a record + /// true if the partition keys are descending + /// An IQueryable partitioned according to a key + public static IQueryable + RangePartition(this IQueryable source, + Expression> keySelector, + bool isDescending) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + if (IsLocalDebugSource(source)) + { + return source; + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector), + Expression.Constant(isDescending, typeof(bool)) } + )); + } + + public static IQueryable + RangePartition(this IQueryable source, + Expression> keySelector, + TKey[] rangeSeparators) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + if (rangeSeparators == null) + throw new ArgumentNullException("rangeSeparators"); + + // check that the range-keys are consistent. + bool? dummy; + if (!HpcLinqUtil.ComputeIsDescending(rangeSeparators, Comparer.Default, out dummy)) + { + throw new ArgumentException(SR.PartitionKeysAreNotConsistentlyOrdered, "rangeSeparators"); + } + + if (IsLocalDebugSource(source)) + { + return source; + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector), + Expression.Constant(rangeSeparators, typeof(TKey[])) } + )); + } + + /// + /// Range partition a dataset. 
The list of range keys are determined dynamically at + /// runtime. + /// + /// The type of the records in the dataset + /// The type of the key on which the partition is based + /// The dataset to be partitioned + /// The funtion to extract the key from a record + /// true if the partition keys are descending + /// Number of partitions in the output dataset + /// An IQueryable partitioned according to a key + public static IQueryable + RangePartition(this IQueryable source, + Expression> keySelector, + bool isDescending, + int partitionCount ) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + if (partitionCount <= 0) + { + throw new ArgumentOutOfRangeException("partitionCount"); + } + if (IsLocalDebugSource(source)) + { + return source; + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector), + Expression.Constant(isDescending, typeof(bool)), + Expression.Constant(partitionCount, typeof(int)) } + )); + } + + public static IQueryable + RangePartition(this IQueryable source, + Expression> keySelector, + IComparer comparer, + bool isDescending) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + if (IsLocalDebugSource(source)) + { + return source; + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector), + Expression.Constant(comparer, typeof(IComparer)), + Expression.Constant(isDescending, typeof(bool)) } + )); + } + + public static IQueryable + RangePartition(this IQueryable source, + 
Expression> keySelector, + TKey[] rangeSeparators, + IComparer comparer) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + if (rangeSeparators == null) + throw new ArgumentNullException("rangeSeparators"); + + if (comparer == null && !TypeSystem.HasDefaultComparer(typeof(TKey))) + { + throw new DryadLinqException(HpcLinqErrorCode.ComparerMustBeSpecifiedOrKeyTypeMustBeIComparable, + string.Format(SR.ComparerMustBeSpecifiedOrKeyTypeMustBeIComparable, typeof(TKey))); + } + comparer = TypeSystem.GetComparer(comparer); + + // check that the range-keys are consistent. + bool? dummy; + if (!HpcLinqUtil.ComputeIsDescending(rangeSeparators, comparer, out dummy)) + { + throw new ArgumentException(SR.PartitionKeysAreNotConsistentlyOrdered, "rangeSeparators"); + } + + if (IsLocalDebugSource(source)) + { + return source; + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector), + Expression.Constant(rangeSeparators, typeof(TKey[])), + Expression.Constant(comparer, typeof(IComparer)) } + )); + } + + public static IQueryable + RangePartition(this IQueryable source, + Expression> keySelector, + IComparer comparer, + bool isDescending, + int partitionCount) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + if (partitionCount <= 0) + { + throw new ArgumentOutOfRangeException("partitionCount"); + } + + if (IsLocalDebugSource(source)) + { + return source; + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + 
Expression.Quote(keySelector), + Expression.Constant(comparer, typeof(IComparer)), + Expression.Constant(isDescending, typeof(bool)), + Expression.Constant(partitionCount, typeof(int))} + )); + } + + public static IQueryable + RangePartition(this IQueryable source, + Expression> keySelector, + TKey[] rangeSeparators, + IComparer comparer, + bool isDescending) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + + if (rangeSeparators == null) + throw new ArgumentNullException("rangeSeparators"); + + if (comparer == null && !TypeSystem.HasDefaultComparer(typeof(TKey))) + { + throw new DryadLinqException(HpcLinqErrorCode.ComparerMustBeSpecifiedOrKeyTypeMustBeIComparable, + string.Format(SR.ComparerMustBeSpecifiedOrKeyTypeMustBeIComparable, typeof(TKey))); + } + comparer = TypeSystem.GetComparer(comparer); + + // check that the range-keys are consistent. + bool? detectedDescending; + bool keysAreConsistent = HpcLinqUtil.ComputeIsDescending(rangeSeparators, comparer, out detectedDescending); + // Note: detectedDescending==null implies that we couldn't precisely tell (single element, repeated elements, etc). + + if (!keysAreConsistent) + { + throw new ArgumentException(SR.PartitionKeysAreNotConsistentlyOrdered, "rangeSeparators"); + } + + // and check that the actual direction of keys matches what the user said they wanted. 
+ if (detectedDescending != null && detectedDescending != isDescending) + { + throw new ArgumentException(SR.IsDescendingIsInconsistent); + } + + if (IsLocalDebugSource(source)) + { + return source; + } + return source.Provider.CreateQuery( + Expression.Call( + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector), + Expression.Constant(rangeSeparators, typeof(TKey[])), + Expression.Constant(comparer, typeof(IComparer)), + Expression.Constant(isDescending, typeof(bool)) } + )); + + } + + /// + /// Compute applyFunc (source) + /// + /// The type of the records of the input dataset + /// The type of the records of the output dataset + /// The input dataset + /// The function to be applied to the input dataset + /// The result of computing applyFunc(source) + public static IQueryable + Apply(this IQueryable source, + Expression, IEnumerable>> applyFunc) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = applyFunc.Compile()(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(T1), typeof(T2)), + new Expression[] { source.Expression, + Expression.Quote(applyFunc) } + )); + } + + /// + /// Compute applyFunc(source1, source2) + /// + /// The type of the records of the first input dataset + /// The type of the records of the second input dataset + /// he type of the records of the output dataset + /// The first input dataset + /// The second input dataset + /// The function to be applied to the input datasets + /// The result of computing applyFunc(source1, source2) + public static IQueryable + Apply(this IQueryable source1, + IQueryable source2, + Expression, IEnumerable, IEnumerable>> applyFunc) + { + if (source1 == null) + { 
+ throw new ArgumentNullException("source1"); + } + if (source2 == null) + { + throw new ArgumentNullException("source2"); + } + if (IsLocalDebugSource(source1)) + { + var q = applyFunc.Compile()(source1, source2).AsQueryable(); + return new DryadLinqLocalQuery(source1.Provider, q); + } + return source1.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(T1), typeof(T2), typeof(T3)), + new Expression[] { source1.Expression, + source2.Expression, + Expression.Quote(applyFunc) } + )); + } + + /// + /// Compute applyFunc on multiple sources + /// + /// The type of the records of input + /// The type of the records of output + /// The first input dataset + /// Other input datasets + /// The function to be applied to the input datasets + /// The result of computing applyFunc(source,pieces) + public static IQueryable + Apply(this IQueryable source, + IQueryable[] otherSources, + Expression[], IEnumerable>> applyFunc) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + IQueryable[] allSources = new IQueryable[otherSources.Length + 1]; + allSources[0] = source; + for (int i = 0; i < otherSources.Length; ++i) + { + allSources[i + 1] = otherSources[i]; + } + var q = applyFunc.Compile()(allSources).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + + Expression[] others = new Expression[otherSources.Length]; + for (int i = 0; i < otherSources.Length; i++) + { + others[i] = otherSources[i].Expression; + } + + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(T1), typeof(T2)), + new Expression[] { source.Expression, + Expression.NewArrayInit(typeof(IQueryable), others), + Expression.Quote(applyFunc) } + )); + } + + /// + /// Compute applyFunc (source) + /// + /// The type of the records of the input dataset + /// The type of the records of 
the output dataset + /// The input dataset + /// The function to be applied to the input dataset + /// The result of computing applyFunc(source) + public static IQueryable + ApplyPerPartition( + this IQueryable source, + Expression, IEnumerable>> applyFunc) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = applyFunc.Compile()(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(T1), typeof(T2)), + new Expression[] { source.Expression, + Expression.Quote(applyFunc) } + )); + } + + /// + /// Compute applyFunc(source1, source2) + /// + /// The type of the records of the first input dataset + /// The type of the records of the second input dataset + /// he type of the records of the output dataset + /// The first input dataset + /// The second input dataset + /// The function to be applied to the input datasets + /// True if only distributive over the first dataset + /// The result of computing applyFunc(source1, source2) + public static IQueryable + ApplyPerPartition( + this IQueryable source1, + IQueryable source2, + Expression, IEnumerable, IEnumerable>> applyFunc, + bool isFirstOnly) + { + if (source1 == null) + { + throw new ArgumentNullException("source1"); + } + if (source2 == null) + { + throw new ArgumentNullException("source2"); + } + if (IsLocalDebugSource(source1)) + { + var q = applyFunc.Compile()(source1, source2).AsQueryable(); + return new DryadLinqLocalQuery(source1.Provider, q); + } + return source1.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(T1), typeof(T2), typeof(T3)), + new Expression[] { source1.Expression, + source2.Expression, + Expression.Quote(applyFunc), + Expression.Constant(isFirstOnly) } + )); + } + + /// + /// Compute 
applyFunc on multiple sources + /// + /// The type of the records of input + /// The type of the records of output + /// The first input dataset + /// Other input datasets + /// The function to be applied to the input datasets + /// True if only distributive over the first dataset + /// The result of computing applyFunc(source,pieces) + public static IQueryable + ApplyPerPartition( + this IQueryable source, + IQueryable[] otherSources, + Expression[], IEnumerable>> applyFunc, + bool isFirstOnly) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + IQueryable[] allSources = new IQueryable[otherSources.Length + 1]; + allSources[0] = source; + for (int i = 0; i < otherSources.Length; ++i) + { + allSources[i + 1] = otherSources[i]; + } + var q = applyFunc.Compile()(allSources).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + + Expression[] others = new Expression[otherSources.Length]; + for (int i = 0; i < otherSources.Length; i++) + { + others[i] = otherSources[i].Expression; + } + + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(T1), typeof(T2)), + new Expression[] { source.Expression, + Expression.NewArrayInit(typeof(IQueryable), others), + Expression.Quote(applyFunc), + Expression.Constant(isFirstOnly) } + )); + } + + /// + /// Apply a function on every sliding window on the input sequence of records. 
+ /// + /// The type of the input records + /// The type of the output records + /// The input dataset + /// The function to apply to every sliding window + /// The size of the window + /// + public static IQueryable + SlidingWindow(this IQueryable source, + Expression, T2>> procFunc, + Int32 windowSize) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (windowSize < 2) + { + throw new DryadLinqException(SR.WindowSizeMustyBeGTOne); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SlidingWindow(source, procFunc.Compile(), windowSize); + return new DryadLinqLocalQuery(source.Provider, q.AsQueryable()); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(T1), typeof(T2)), + new Expression[] { source.Expression, + Expression.Quote(procFunc), + Expression.Constant(windowSize, typeof(int))} + )); + } + + public static IQueryable + SelectWithPartitionIndex(this IQueryable source, + Expression> procFunc) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SelectWithPartitionIndex(source, procFunc.Compile()); + return new DryadLinqLocalQuery(source.Provider, q.AsQueryable()); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(T1), typeof(T2)), + new Expression[] { source.Expression, + Expression.Quote(procFunc) } + )); + } + + public static IQueryable + ApplyWithPartitionIndex(this IQueryable source, + Expression, int, IEnumerable>> procFunc) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = procFunc.Compile()(source, 0); + return new DryadLinqLocalQuery(source.Provider, q.AsQueryable()); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + 
((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(T1), typeof(T2)), + new Expression[] { source.Expression, + Expression.Quote(procFunc) } + )); + } + + public static IQueryable AnyAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AnyAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression } + )); + } + + public static IQueryable AnyAsQuery(this IQueryable source, + Expression> predicate) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (predicate == null) + { + throw new ArgumentNullException("predicate"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AnyAsQuery(source, predicate.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(predicate) } + )); + } + + public static IQueryable AllAsQuery(this IQueryable source, + Expression> predicate) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (predicate == null) + { + throw new ArgumentNullException("predicate"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AllAsQuery(source, predicate.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(predicate) } + )); + } + + public static 
IQueryable CountAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.CountAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression } + )); + } + + public static IQueryable CountAsQuery(this IQueryable source, + Expression> predicate) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (predicate == null) + { + throw new ArgumentNullException("predicate"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.CountAsQuery(source, predicate.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(predicate) } + )); + } + + public static IQueryable LongCountAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.LongCountAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression } + )); + } + + public static IQueryable LongCountAsQuery(this IQueryable source, + Expression> predicate) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (predicate == null) + { + throw new ArgumentNullException("predicate"); + } + if (IsLocalDebugSource(source)) + { + var q = 
HpcLinqEnumerable.LongCountAsQuery(source, predicate.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(predicate) } + )); + } + + public static IQueryable + ContainsAsQuery(this IQueryable source, TSource item) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.ContainsAsQuery(source, item).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Constant(item, typeof(TSource)) } + )); + } + + public static IQueryable + ContainsAsQuery(this IQueryable source, + TSource item, + IEqualityComparer comparer) { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.ContainsAsQuery(source, item, comparer).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, + Expression.Constant(item, typeof(TSource)), + Expression.Constant(comparer, typeof(IEqualityComparer)) } + )); + } + + private static Expression GetSourceExpression(IEnumerable source) + { + IQueryable q = source as IQueryable; + if (q != null) return q.Expression; + return Expression.Constant(source.ToArray(), typeof(TSource[])); + } + + public static IQueryable + SequenceEqualAsQuery(this IQueryable source1, + IEnumerable source2) + { + if (source1 == null) + { + throw new 
ArgumentNullException("source1"); + } + if (source2 == null) + { + throw new ArgumentNullException("source2"); + } + if (IsLocalDebugSource(source1)) + { + var q = HpcLinqEnumerable.SequenceEqualAsQuery(source1, source2).AsQueryable(); + return new DryadLinqLocalQuery(source1.Provider, q); + } + return source1.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source1.Expression, GetSourceExpression(source2) } + )); + } + + public static IQueryable + SequenceEqualAsQuery(this IQueryable source1, + IEnumerable source2, + IEqualityComparer comparer) + { + if (source1 == null) + { + throw new ArgumentNullException("source1"); + } + if (source2 == null) + { + throw new ArgumentNullException("source2"); + } + if (IsLocalDebugSource(source1)) + { + var q = HpcLinqEnumerable.SequenceEqualAsQuery(source1, source2, comparer).AsQueryable(); + return new DryadLinqLocalQuery(source1.Provider, q); + } + return source1.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { + source1.Expression, + GetSourceExpression(source2), + Expression.Constant(comparer, typeof(IEqualityComparer)) + } + )); + } + + public static IQueryable FirstAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.FirstAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression } + )); + } + + public static IQueryable + FirstAsQuery(this IQueryable source, + Expression> predicate) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (predicate == null) 
+ { + throw new ArgumentNullException("predicate"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.FirstAsQuery(source, predicate.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(predicate) } + )); + } + + public static IQueryable LastAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.LastAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression } + )); + } + + public static IQueryable + LastAsQuery(this IQueryable source, + Expression> predicate) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (predicate == null) + { + throw new ArgumentNullException("predicate"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.LastAsQuery(source, predicate.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(predicate) } + )); + } + + public static IQueryable SingleAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SingleAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + 
((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression } + )); + } + + public static IQueryable + SingleAsQuery(this IQueryable source, + Expression> predicate) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (predicate == null) + { + throw new ArgumentNullException("predicate"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SingleAsQuery(source, predicate.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(predicate) } + )); + } + + public static IQueryable MinAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.MinAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression } + )); + } + + public static IQueryable + MinAsQuery(this IQueryable source, Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.MinAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TResult)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable 
MaxAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.MaxAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression } + )); + } + + public static IQueryable + MaxAsQuery(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.MaxAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TResult)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable SumAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SumAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()), + new Expression[] { source.Expression } + )); + } + + public static IQueryable SumAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SumAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + 
((MethodInfo)MethodBase.GetCurrentMethod()), + new Expression[] { source.Expression } + )); + } + + public static IQueryable SumAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SumAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()), + new Expression[] { source.Expression } + )); + } + + public static IQueryable SumAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SumAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()), + new Expression[] { source.Expression } + )); + } + + public static IQueryable SumAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SumAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()), + new Expression[] { source.Expression } + )); + } + + public static IQueryable SumAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SumAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()), + new Expression[] { source.Expression } + )); + } + + public static IQueryable SumAsQuery(this 
IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SumAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()), + new Expression[] { source.Expression } + )); + } + + public static IQueryable SumAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SumAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()), + new Expression[] { source.Expression } + )); + } + + public static IQueryable SumAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SumAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()), + new Expression[] { source.Expression } + )); + } + + public static IQueryable SumAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SumAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()), + new Expression[] { source.Expression } + )); + } + + public static IQueryable + SumAsQuery(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == 
null) + { + throw new ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SumAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + SumAsQuery(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SumAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + SumAsQuery(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SumAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + SumAsQuery(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new 
ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SumAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + SumAsQuery(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SumAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + SumAsQuery(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SumAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + SumAsQuery(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if 
(IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SumAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + SumAsQuery(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SumAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + SumAsQuery(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.SumAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + SumAsQuery(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = 
HpcLinqEnumerable.SumAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable AverageAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AverageAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()), + new Expression[] { source.Expression } + )); + } + + public static IQueryable AverageAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AverageAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()), + new Expression[] { source.Expression } + )); + } + + public static IQueryable AverageAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AverageAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()), + new Expression[] { source.Expression } + )); + } + + public static IQueryable AverageAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = 
HpcLinqEnumerable.AverageAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()), + new Expression[] { source.Expression } + )); + } + + public static IQueryable AverageAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AverageAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()), + new Expression[] { source.Expression } + )); + } + + public static IQueryable AverageAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AverageAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()), + new Expression[] { source.Expression } + )); + } + + public static IQueryable AverageAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AverageAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()), + new Expression[] { source.Expression } + )); + } + + public static IQueryable AverageAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AverageAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + 
} + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()), + new Expression[] { source.Expression } + )); + } + + public static IQueryable AverageAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AverageAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()), + new Expression[] { source.Expression } + )); + } + + public static IQueryable AverageAsQuery(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AverageAsQuery(source).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()), + new Expression[] { source.Expression } + )); + } + + public static IQueryable + AverageAsQuery(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AverageAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + AverageAsQuery(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new 
ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AverageAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + AverageAsQuery(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AverageAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + AverageAsQuery(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AverageAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + AverageAsQuery(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new 
ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AverageAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + AverageAsQuery(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AverageAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + AverageAsQuery(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AverageAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + AverageAsQuery(this IQueryable source, + Expression> selector) { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new 
ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AverageAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + AverageAsQuery(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AverageAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + AverageAsQuery(this IQueryable source, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AverageAsQuery(source, selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(selector) } + )); + } + + public static IQueryable + AggregateAsQuery(this IQueryable source, + Expression> func) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (func == null) + { + throw new 
ArgumentNullException("func"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AggregateAsQuery(source, func.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, Expression.Quote(func) } + )); + } + + public static IQueryable + AggregateAsQuery(this IQueryable source, + TAccumulate seed, + Expression> func) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (func == null) + { + throw new ArgumentNullException("func"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AggregateAsQuery(source, seed, func.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TAccumulate)), + new Expression[] { source.Expression, Expression.Constant(seed), Expression.Quote(func) } + )); + } + + public static IQueryable + AggregateAsQuery(this IQueryable source, + TAccumulate seed, + Expression> func, + Expression> selector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (func == null) + { + throw new ArgumentNullException("func"); + } + if (selector == null) + { + throw new ArgumentNullException("selector"); + } + if (IsLocalDebugSource(source)) + { + var q = HpcLinqEnumerable.AggregateAsQuery(source, seed, func.Compile(), selector.Compile()).AsQueryable(); + return new DryadLinqLocalQuery(source.Provider, q); + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TAccumulate), typeof(TResult)), + new Expression[] { source.Expression, Expression.Constant(seed), 
Expression.Quote(func), Expression.Quote(selector) } + )); + } + + /// + /// Instruct DryadLINQ to assume that the dataset is hash partitioned. + /// + /// The type of the records of the dataset + /// The type of the keys on which the partition is based + /// The dataset + /// The function to extract the key from a record + /// The same dataset as input + public static IQueryable + AssumeHashPartition(this IQueryable source, + Expression> keySelector) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + if (IsLocalDebugSource(source)) + { + return source; + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector) } + )); + } + + public static IQueryable + AssumeHashPartition(this IQueryable source, + Expression> keySelector, + IEqualityComparer comparer) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + if (IsLocalDebugSource(source)) + { + return source; + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector), + Expression.Constant(comparer, typeof(IEqualityComparer)) } + )); + } + + /// + /// Instruct DryadLINQ to assume that the dataset is range partitioned. 
+ /// + /// The type of the records of the dataset + /// The type of the key on which partition is based + /// The dataset + /// The function to extract the key from a record + /// true if the partition keys are in descending order + /// The same dataset as input + public static IQueryable + AssumeRangePartition(this IQueryable source, + Expression> keySelector, + bool isDescending) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + // When IsLocalDebugSource(source) is true the input is returned as-is and no query node is added. + if (IsLocalDebugSource(source)) + { + return source; + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector), + Expression.Constant(isDescending) } + )); + } + + // Overload that additionally passes a key comparer to the query node as a constant. + public static IQueryable + AssumeRangePartition(this IQueryable source, + Expression> keySelector, + IComparer comparer, + bool isDescending) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + if (IsLocalDebugSource(source)) + { + return source; + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector), + Expression.Constant(comparer, typeof(IComparer)), + Expression.Constant(isDescending) } + )); + } + + // Overload that takes explicit range separator keys instead of an ordering flag; rangeSeparators must not be null. + public static IQueryable + AssumeRangePartition(this IQueryable source, + Expression> keySelector, + TKey[] rangeSeparators) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + if (rangeSeparators == null) + { + throw new ArgumentNullException("rangeSeparators"); + } + if
(IsLocalDebugSource(source)) + { + return source; + } + + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector), + Expression.Constant(rangeSeparators) } + )); + } + + // Overload that takes explicit range separator keys plus a key comparer; rangeSeparators must not be null. + public static IQueryable + AssumeRangePartition(this IQueryable source, + Expression> keySelector, + TKey[] rangeSeparators, + IComparer comparer) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + if (rangeSeparators == null) + { + throw new ArgumentNullException("rangeSeparators"); + } + if (IsLocalDebugSource(source)) + { + return source; + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector), + Expression.Constant(rangeSeparators), + Expression.Constant(comparer, typeof(IComparer)) } + )); + } + + /// + /// Instruct DryadLINQ to assume that each partition of the dataset is ordered. A dataset + /// is ordered if it is range partitioned and each partition of it is ordered on the same + /// key.
+ /// + /// The type of the records of the dataset + /// The type of the key on which partition is based + /// The dataset + /// The function to extract the key from a record + /// true if the order is descending + /// The same dataset as input + public static IQueryable + AssumeOrderBy(this IQueryable source, + Expression> keySelector, + bool isDescending) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + // When IsLocalDebugSource(source) is true the input is returned as-is and no query node is added. + if (IsLocalDebugSource(source)) + { + return source; + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector), + Expression.Constant(isDescending) } + )); + } + + // Overload that additionally passes a key comparer to the query node as a constant. + public static IQueryable + AssumeOrderBy(this IQueryable source, + Expression> keySelector, + IComparer comparer, + bool isDescending) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + if (IsLocalDebugSource(source)) + { + return source; + } + return source.Provider.CreateQuery( + Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector), + Expression.Constant(comparer, typeof(IComparer)), + Expression.Constant(isDescending) } + )); + } + + // Fork with a mapper applied to the whole input sequence, producing two outputs (generic arguments T, R1, R2). + // NOTE(review): mapper is not null-checked, unlike source; a null mapper reaches mapper.Compile() or + // Expression.Quote(mapper) below. The same applies to every Fork overload that takes a mapper -- confirm + // whether an explicit ArgumentNullException("mapper") is wanted. + public static IMultiQueryable + Fork(this IQueryable source, + Expression, IEnumerable>>> mapper) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + IMultiEnumerable enumerables = HpcLinqEnumerable.Fork(source, mapper.Compile()); + return new MultiQueryable(source, enumerables); + } + + Expression expr = Expression.Call( + null, +
((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(T), typeof(R1), typeof(R2)), + new Expression[] { source.Expression, + Expression.Quote(mapper) } + ); + return new MultiQueryable(source, expr); + } + + // Same as the overload above, but producing three outputs (generic arguments T, R1, R2, R3). + public static IMultiQueryable + Fork(this IQueryable source, + Expression, IEnumerable>>> mapper) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + IMultiEnumerable enumerables = HpcLinqEnumerable.Fork(source, mapper.Compile()); + return new MultiQueryable(source, enumerables); + } + + Expression expr = Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(T), typeof(R1), typeof(R2), typeof(R3)), + new Expression[] { source.Expression, + Expression.Quote(mapper) } + ); + return new MultiQueryable(source, expr); + } + + /// + /// Compute two output datasets from one input dataset. + /// + /// The type of records of input dataset + /// The type of records of first output dataset + /// The type of records of second output dataset + /// The input dataset + /// The function applied to each record of the input + /// An IMultiQueryable for the two output datasets + public static IMultiQueryable + Fork(this IQueryable source, + Expression>> mapper) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (IsLocalDebugSource(source)) + { + IMultiEnumerable enumerables = HpcLinqEnumerable.Fork(source, mapper.Compile()); + return new MultiQueryable(source, enumerables); + } + + Expression expr = Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(T), typeof(R1), typeof(R2)), + new Expression[] { source.Expression, + Expression.Quote(mapper) } + ); + return new MultiQueryable(source, expr); + } + + // Three-output variant of the per-record Fork above (generic arguments T, R1, R2, R3). + public static IMultiQueryable + Fork(this IQueryable source, + Expression>> mapper) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if
(IsLocalDebugSource(source)) + { + IMultiEnumerable enumerables = HpcLinqEnumerable.Fork(source, mapper.Compile()); + return new MultiQueryable(source, enumerables); + } + + Expression expr = Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(T), typeof(R1), typeof(R2), typeof(R3)), + new Expression[] { source.Expression, + Expression.Quote(mapper) } + ); + return new MultiQueryable(source, expr); + } + + // Fork variant driven by a key selector and an explicit key array; source, keySelector and keys must all be non-null. + public static IKeyedMultiQueryable + Fork(this IQueryable source, + Expression> keySelector, + TKey[] keys) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (keySelector == null) + { + throw new ArgumentNullException("keySelector"); + } + if (keys == null) + { + throw new ArgumentNullException("keys"); + } + if (IsLocalDebugSource(source)) + { + IMultiEnumerable enumerables = HpcLinqEnumerable.Fork(source, keySelector.Compile(), keys); + return new MultiQueryable(source, keys, enumerables); + } + + Expression expr = Expression.Call( + null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource), typeof(TKey)), + new Expression[] { source.Expression, + Expression.Quote(keySelector), + Expression.Constant(keys, typeof(TKey[])) } + ); + return new MultiQueryable(source, keys, expr); + } + + // Builds a query node that calls this method with source.Expression and the constant index as arguments. + internal static IQueryable ForkChoose(this IMultiQueryable source, int index) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + return source.Provider.CreateQuery( + Expression.Call(null, + ((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(T)), + new Expression[] { source.Expression, + Expression.Constant(index) })); + } + + /// + /// Specifies a DSC stream to be populated with data during query execution. + /// + /// The type of the records of the table + /// The data source + /// A DSC service + /// A DSC stream name + /// A query representing the output data.
+ + // Note: for both cluster & LocalDebug, we add a node to the query-tree + // Submit/Materialize will process the ToDsc call in both cases. + // This is good for consistency of LocalDebug & Cluster modes -- ie ToDsc is lazy in both cases. + public static IQueryable + ToDsc(this IQueryable source, string streamName) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (streamName == null) + { + throw new ArgumentNullException("streamName"); + } + if (!(source.Provider is DryadLinqProviderBase)) + { + //@@TODO[p2]: a "single-input" resource string should be used + throw new ArgumentException(String.Format(SR.NotAHpcLinqQuery, 0), "source"); + } + HpcLinqContext context = HpcLinqContext.GetContext(source.Provider as DryadLinqProviderBase); + + // The query node targets the non-public worker method (found by reflection), not this method; + // the context is passed to it as an extra constant argument. + MethodInfo nongenericMethod = + typeof(HpcLinqQueryable) + .GetMethod(ReflectedNames.DryadLinqIQueryable_ToDscWorker, BindingFlags.Static | BindingFlags.NonPublic); + MethodInfo targetMethod = nongenericMethod.MakeGenericMethod(typeof(TSource)); + + return source.Provider.CreateQuery( + Expression.Call( + null, + targetMethod, + //((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, + Expression.Constant(context , typeof(HpcLinqContext)), + Expression.Constant(streamName, typeof(string)) } + )); + } + + /// + /// Specifies a HDFS stream to be populated with data during query execution. + /// + /// The type of the records of the table + /// The data source + /// A HDFS service + /// A HDFS stream name + /// A query representing the output data. + + // Note: for both cluster & LocalDebug, we add a node to the query-tree + // Submit/Materialize will process the ToHdfs call in both cases. + // This is good for consistency of LocalDebug & Cluster modes -- ie ToHdfs is lazy in both cases.
+ public static IQueryable ToHdfs(this IQueryable source, string streamName) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + if (streamName == null) + { + throw new ArgumentNullException("streamName"); + } + if (!(source.Provider is DryadLinqProviderBase)) + { + //@@TODO[p2]: a "single-input" resource string should be used + throw new ArgumentException(String.Format(SR.NotAHpcLinqQuery, 0), "source"); + } + HpcLinqContext context = HpcLinqContext.GetContext(source.Provider as DryadLinqProviderBase); + + // The query node targets the non-public worker method (found by reflection), not this method; + // the context is passed to it as an extra constant argument. + MethodInfo nongenericMethod = + typeof(HpcLinqQueryable) + .GetMethod(ReflectedNames.DryadLinqIQueryable_ToHdfsWorker, BindingFlags.Static | BindingFlags.NonPublic); + MethodInfo targetMethod = nongenericMethod.MakeGenericMethod(typeof(TSource)); + + + return source.Provider.CreateQuery( + Expression.Call( + null, + targetMethod, + //((MethodInfo)MethodBase.GetCurrentMethod()).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, + Expression.Constant(context , typeof(HpcLinqContext)), + Expression.Constant(streamName, typeof(string)) } + )); + } + + /// + /// Specifies a DSC stream to be populated with data during query execution. + /// + /// + /// This method is not intended for direct use. To prepare a plain enumerable for + /// use with DryadLinq, call AsDryadQuery or AsDryadQueryPartitions. + /// + /// The type of the records of the table + /// The data source + /// A DSC service + /// A DSC stream name + /// A query representing the output data. + // @@TODO[P1]: The above comments require that AsDryadQuery/AsDryadQueryPartitions exists.. that is a P1 work item. + // In the meantime, a basic version of those methods is available in DryadLinqUnitTests. + // + // *visited by "LocalDebug-mode GetEnumerator()" + // -- in cluster mode, we never enter this method..
DryadQueryGen just inspects the MethodCallExpression + // -- in localDebug mode, we do enter this method, particularly via Linq-to-objects GetEnumerator. + // + // Behavior: + // (LocalDebug mode) q.ToDsc().GetEnumerator() -> throws, as execution never occurred, so the output was not created. + // (LocalDebug mode) var q1 = q.ToDsc();q1.Submit(); foreach(var x in q1) -> succeeds (as submit creates the data) + // + internal static IQueryable + ToDscWorker(this IEnumerable source, HpcLinqContext context, string streamName) + { + // We want the following query to succeed reliably for both cluster and LocalDebug : + // var q1 = q.ToDsc(".."); + // q1.Submit(runtime); + // foreach(var x in q1){ ... } // calls q1.GetEnumerator() + // + // For cluster execution, q1 becomes data-backed and so the call to GetEnumerator succeeds. + // For local execution, we don't attach a "databacking DLQ" to the source + // - if we did find a "data-backing DLQ" for the DSC node we would return that node. + // - but we don't, so we fake it by looking for the existence of the output and if it exists, we + // assume that we created it and treat it as though we had tracked it as the data-backing. + // + // @@BUG(low-severity): This logic could be spoofed if the dsc output already exists but wasn't created via + // the expected call to Submit()/Materialize().. A possible cause is a previous run of almost exactly the same query. + // Mitigation: User-education to not rely on auto-delete of output streams will avoid the issue and is a good idea anyway. + // + // (fix idea) reimplement all the Linq-to-objects operators ourselves so that we have more control + // during LocalDebug-mode. + // + // (fix idea) Introduce a new class DryadLinqQueryLocal. It uses AsQueryable.Provider as query provider + // and has a data-backing field. May only be feasible if DryadLinqQuery is internal. + // + // Note: the source argument is not read in this method; only context and streamName are used.
+ IQueryable fakedDataBackingDLQ = null; + try + { + if (context.DscService.FileSetExists(streamName)) + { + // if so, try to make a DLQ from it. + fakedDataBackingDLQ = context.FromDsc(streamName); + } + } + catch (DryadLinqException) + { + //suppress.. we expect this to occur if the dsc stream does exist, but cannot be loaded as a DLQ + } + + if (fakedDataBackingDLQ != null) + return fakedDataBackingDLQ; + + throw new DryadLinqException(HpcLinqErrorCode.ToDscUsedIncorrectly, String.Format(SR.ToDscUsedIncorrectly)); + } + + + /// + /// Specifies a HDFS stream to be populated with data during query execution. + /// + /// + /// This method is not intended for direct use. To prepare a plain enumerable for + /// use with DryadLinq, call AsDryadQuery or AsDryadQueryPartitions. + /// + /// The type of the records of the table + /// The data source + /// A HDFS service + /// A HDFS stream name + /// A query representing the output data. + // @@TODO[P1]: The above comments require that AsDryadQuery/AsDryadQueryPartitions exists.. that is a P1 work item. + // In the meantime, a basic version of those methods is available in DryadLinqUnitTests. + // + // *visited by "LocalDebug-mode GetEnumerator()" + // -- in cluster mode, we never enter this method.. DryadQueryGen just inspects the MethodCallExpression + // -- in localDebug mode, we do enter this method, particularly via Linq-to-objects GetEnumerator. + // + // Behavior: + // (LocalDebug mode) q.ToHdfs().GetEnumerator() -> throws, as execution never occurred, so the output was not created. + // (LocalDebug mode) var q1 = q.ToHdfs();q1.Submit(); foreach(var x in q1) -> succeeds (as submit creates the data) + // + internal static IQueryable ToHdfsWorker(this IEnumerable source, HpcLinqContext context, string streamName) + { + // We want the following query to succeed reliably for both cluster and LocalDebug : + // var q1 = q.ToHdfs(".."); + // q1.Submit(runtime); + // foreach(var x in q1){ ...
} // calls q1.GetEnumerator() + // + // For cluster execution, q1 becomes data-backed and so the call to GetEnumerator succeeds. + // For local execution, we don't attach a "databacking DLQ" to the source + // - if we did find a "data-backing DLQ" for the HDFS node we would return that node. + // - but we don't, so we fake it by looking for the existence of the output and if it exists, we + // assume that we created it and treat it as though we had tracked it as the data-backing. + // + // @@BUG(low-severity): This logic could be spoofed if the HDFS output already exists but wasn't created via + // the expected call to Submit()/Materialize().. A possible cause is a previous run of almost exactly the same query. + // Mitigation: User-education to not rely on auto-delete of output streams will avoid the issue and is a good idea anyway. + // + // (fix idea) reimplement all the Linq-to-objects operators ourselves so that we have more control + // during LocalDebug-mode. + // + // (fix idea) Introduce a new class DryadLinqQueryLocal. It uses AsQueryable.Provider as query provider + // and has a data-backing field. May only be feasible if DryadLinqQuery is internal. + // + // Note: the source argument is not read in this method; only context and streamName are used. + IQueryable fakedDataBackingDLQ = null; + try + { + bool exists; + using (HdfsInstance hdfs = new HdfsInstance(context.HdfsService)) + { + exists = hdfs.IsFileExists(streamName); + } + + if (exists) + { + // if so, try to make a DLQ from it. + fakedDataBackingDLQ = context.FromHdfs(streamName); + } + } + catch (DryadLinqException) + { + //suppress.. we expect this to occur if the HDFS stream does exist, but cannot be loaded as a DLQ + } + + if (fakedDataBackingDLQ != null) + return fakedDataBackingDLQ; + + throw new DryadLinqException(HpcLinqErrorCode.ToHdfsUsedIncorrectly, String.Format(SR.ToHdfsUsedIncorrectly)); + } + + + + /// + /// Submits the query and then waits for the job to complete + /// + /// If the job completes in error or is cancelled. + /// If repeated errors occur while polling for status.
+ /// The type of the records of the table + /// The data source + /// Information about the execution job. + public static HpcLinqJobInfo SubmitAndWait(this IQueryable source) + { + // Submit() performs the null and provider checks on source. + HpcLinqJobInfo info = source.Submit(); + info.Wait(); + return info; + } + + /// + /// Submits the query for asynchronous execution. + /// + /// The type of the records of the table + /// The data source + /// The URI of the DSC file set to be created + /// Information about the execution job. + // @@TODO[P1]. Try to unify q.Submit() and the static Submit. They duplicate a fair bit of code. + // Changing this method to work with untyped IQueryable (and not mention TSource) should + // make it easy to forward the call to submit( new [] {source} ) + public static HpcLinqJobInfo Submit(this IQueryable source) + { + if (source == null) + { + throw new ArgumentNullException("source"); + } + + //Extract the context. + if (!(source.Provider is DryadLinqProviderBase)) + { + //@@TODO[p2]: a "single-input" resource string should be used + throw new ArgumentException(String.Format(SR.NotAHpcLinqQuery, 0), "source"); + } + HpcLinqContext context = HpcLinqContext.GetContext(source.Provider as DryadLinqProviderBase); + + if (IsLocalDebugSource(source)) + { + try + { + // if ToDsc is present, extract out the target, otherwise make an anonymous target + // then ingress the data directly to DSC. + // NOTE(review): only a trailing ToDscWorker node is recognized here, whereas the non-local path + // further down also recognizes ToHdfsWorker; a trailing ToHdfsWorker node takes the else branch -- confirm intended. + + MethodCallExpression mcExpr1 = source.Expression as MethodCallExpression; + if (mcExpr1 != null && mcExpr1.Method.Name == ReflectedNames.DryadLinqIQueryable_ToDscWorker) + { + // visited by LocalDebug:: q2.ToDsc(...).Submit(...), eg Test3.
+ LocalDebug_ProcessToDscExpression(context, mcExpr1); + } + else + { + // visited by (LocalDebug mode) q-nonToDsc.Submit(); + string fileSetName = DataPath.MakeUniqueTemporaryDscFileSetName(); + + DscCompressionScheme outputScheme = context.Configuration.IntermediateDataCompressionScheme; + DryadLinqMetaData metadata = DryadLinqMetaData.ForLocalDebug(context, typeof(TSource), fileSetName, outputScheme ); + DataProvider.IngressTemporaryDataDirectlyToDsc(context, source, fileSetName, metadata, outputScheme); + } + + return new HpcLinqJobInfo(HpcLinqJobInfo.JOBID_LOCALDEBUG, null, null, null); + } + catch (Exception e) + { + // Every failure in the local-debug path is wrapped; the original exception is kept as the inner exception. + throw new DryadLinqException(HpcLinqErrorCode.CreatingDscDataFromLocalDebugFailed, + String.Format(SR.CreatingDscDataFromLocalDebugFailed), e); + } + } + + if (!(source is DryadLinqQuery)) + { + //@@TODO[p2]: a "single-input" resource string should be used + throw new ArgumentException(String.Format(SR.NotAHpcLinqQuery, 0), "source"); + } + + // handle repeat submissions. + if (((DryadLinqQuery)source).IsDataBacked) + { + // this query has already been submitted.
+ throw new ArgumentException(string.Format(SR.AlreadySubmitted), "source"); + } + + // sanity check that we have a normal DryadLinqQuery + MethodCallExpression mcExpr = source.Expression as MethodCallExpression; + if (mcExpr == null) + { + throw new ArgumentException(String.Format(SR.AtLeastOneOperatorRequired, 0), "source"); + } + + if (!mcExpr.Method.IsStatic || + !TypeSystem.IsQueryOperatorCall(mcExpr)) + { + //@@TODO[p2]: a "single-input" resource string should be used + throw new ArgumentException(String.Format(SR.NotAHpcLinqQuery, 0), "source"); + } + + if (mcExpr.Method.Name != ReflectedNames.DryadLinqIQueryable_ToDscWorker && mcExpr.Method.Name != ReflectedNames.DryadLinqIQueryable_ToHdfsWorker) + { + // Support for non-ToDscQuery.Submit(): wrap the query in the anonymous placeholder node with a unique temporary file set URI. + string tableName = DataPath.MakeUniqueTemporaryDscFileSetUri(context); + mcExpr = (MethodCallExpression) source.Provider.CreateQuery( + Expression.Call( + null, + typeof(HpcLinqQueryable).GetMethod(ReflectedNames.DryadLinqIQueryable_AnonymousDscPlaceholder, + BindingFlags.Static | BindingFlags.NonPublic).MakeGenericMethod(typeof(TSource)), + new Expression[] { source.Expression, + Expression.Constant(tableName, typeof(string))} + )).Expression; + } + + // Execute the queries + HpcLinqQueryGen dryadGen = new HpcLinqQueryGen(context, + ((DryadLinqQuery)source).GetVertexCodeGen(), + new Expression[] { mcExpr }); + DryadLinqQuery[] tables = dryadGen.InvokeDryad(); + + + // Record the first result table as the backing data of the submitted query. + tables[0].IsTemp = false; + ((DryadLinqQuery)source).BackingDataDLQ = tables[0]; + + int jobId = tables[0].QueryExecutor.JobSubmission.GetJobId(); + string[] targetUris = new [] {tables[0].DataSourceUri}; + return new HpcLinqJobInfo(jobId, context.Configuration.HeadNode, tables[0].QueryExecutor, targetUris); + } + + // inspect and process the trailing ToDsc() node in local-debug scenarios. + private static void LocalDebug_ProcessToDscExpression(HpcLinqContext context, MethodCallExpression mcExpr1) + { + // get the dsc target out of the ToDsc node.
+ DscService dsc = context.DscService; + // NOTE(review): dsc is assigned above but not referenced again in this method. + // Arguments[2] is the stream-name constant: ToDsc builds the call with (source, context, streamName). + string fileSetName = (string)((ConstantExpression)mcExpr1.Arguments[2]).Value; + + // evaluate the source query before the ToDsc node. + // WAS MethodCallExpression mce = ((MethodCallExpression)mcExpr1.Arguments[0]); + // That version didn't handle a constant expression for plain-data. + Expression e = mcExpr1.Arguments[0]; + + DscCompressionScheme compressionScheme = context.Configuration.OutputDataCompressionScheme; + ExecuteLocalExpressionAndIngressToDsc(context, e, fileSetName, compressionScheme); + } + + // process the main data and push into DSC, used in local-debug scenarios. + private static void ExecuteLocalExpressionAndIngressToDsc(HpcLinqContext context, + Expression mce, + string fileSetName, + DscCompressionScheme compressionScheme) + { + ExpressionSimplifier> simplifier = new ExpressionSimplifier>(); + IEnumerable sourceData = simplifier.Eval(mce); + + // stuff the data into DSC + DryadLinqMetaData metadata = DryadLinqMetaData.ForLocalDebug(context, typeof(TSource), fileSetName, compressionScheme); + DataProvider.IngressDataDirectlyToDsc(context, sourceData, fileSetName, metadata, compressionScheme); + } + + // this method is never directly executed + // its only use is as a placeholder in an expression tree for the scenario: query-nonDsc.Submit() + // This appears in expression trees in the same place that "ToDscWorker" would otherwise appear. + internal static DryadLinqQuery + AnonymousDscTarget__Placeholder(this IEnumerable source, string dscFileSetUri) + { + throw new InvalidOperationException(); + } + + /// + /// Submits a collection of HPC LINQ queries for execution. + /// + /// Queries to execute. + /// Job information for tracking the execution. + /// + /// Every item in sources must be an HPC LINQ IQueryable object that terminates with ToDsc() + /// Only one job will be executed, but the job will produce the output associated with each item in sources.
+        /// 
+        public static HpcLinqJobInfo Submit(params IQueryable[] sources)
+        {
+            if (sources == null)
+            {
+                throw new ArgumentNullException("sources");
+            }
+
+            if (sources.Length == 0)
+            {
+                //@@TODO[p2]: localize message.
+                throw new ArgumentException("sources is empty", "sources");
+            }
+
+            HpcLinqContext commonContext = CheckSourcesAndGetCommonContext(sources);
+            return HpcLinqQueryable.Materialize(commonContext, sources);
+        }
+
+        /// <summary>
+        /// Submits a collection of HPC LINQ queries for execution and waits for the job to complete.
+        /// </summary>
+        /// <param name="sources">Queries to execute.</param>
+        /// <returns>Information about the execution job.</returns>
+        /// <remarks>
+        /// Every item in sources must be an HPC LINQ IQueryable object that terminates with ToDsc()
+        /// Only one job will be executed, but the job will produce the output associated with each item in sources.
+        /// An exception is raised if the job completes in error or is cancelled, or if repeated errors
+        /// occur while polling for status.
+        /// </remarks>
+        public static HpcLinqJobInfo SubmitAndWait(params IQueryable[] sources)
+        {
+            // Submit performs exactly the argument validation this method needs, then starts the job.
+            HpcLinqJobInfo info = Submit(sources);
+            info.Wait();
+            return info;
+        }
+
+        // Validates every query in sources and returns the single HpcLinqContext they all share.
+        // Throws ArgumentException when an entry is null, is not backed by a DryadLinqProviderBase,
+        // appears more than once, or when the queries were created from different contexts.
+        private static HpcLinqContext CheckSourcesAndGetCommonContext(IQueryable[] sources)
+        {
+            Debug.Assert(sources != null && sources.Length > 0);
+            for (int i = 0; i < sources.Length; i++)
+            {
+                if (sources[i] == null)
+                {
+                    //@@TODO[p1]: localize
+                    throw new ArgumentException(string.Format("An item in sources[] was null. sources[{0}]", i), "sources");
+                }
+
+                // sanity check that we have normal DryadLinqQuery objects
+                if (!(sources[i].Provider is DryadLinqProviderBase))
+                {
+                    throw new ArgumentException(String.Format(SR.NotAHpcLinqQuery, i), "sources");
+                }
+            }
+
+            // check for duplicate query objects (kept as a second pass so that null/provider
+            // errors for any index are always reported before a duplicate is reported)
+            HashSet<IQueryable> repeatedQueryDetector = new HashSet<IQueryable>();
+            for (int i = 0; i < sources.Length; i++)
+            {
+                // Add returns false when the query is already in the set.
+                if (!repeatedQueryDetector.Add(sources[i]))
+                {
+                    throw new ArgumentException(string.Format(SR.SameQuerySubmittedMultipleTimesInMaterialize),
+                                                string.Format("sources[{0}]", i));
+                }
+            }
+
+            // Check the queries all use the same context
+            HpcLinqContext[] contexts = sources.Select(src => HpcLinqContext.GetContext(src.Provider as DryadLinqProviderBase)).ToArray();
+            if (contexts.Distinct().Count() != 1)
+            {
+                //@@TODO[p1]: localize message
+                throw new ArgumentException("Each query must be created from the same HpcLinqContext object", "sources");
+            }
+
+            HpcLinqContext commonContext = contexts[0];
+            return commonContext;
+        }
+
+        /// <summary>
+        /// Force the execution of a list of queries and create the partitioned tables for the
+        /// results of the queries.
+        /// </summary>
+        /// <param name="context">The context shared by all of the queries.</param>
+        /// <param name="sources">The list of lazy tables to be materialized</param>
+        /// <returns>Job information for tracking the execution.</returns>
+        internal static HpcLinqJobInfo Materialize(HpcLinqContext context, params IQueryable[] sources)
+        {
+            if (context.Configuration.LocalDebug)
+            {
+                //@@TODO[P1]: force streams to be temporary if automatically named. Not doing this might be considered a bug.
+                foreach (var source in sources)
+                {
+                    MethodCallExpression mcExpr1 = source.Expression as MethodCallExpression;
+                    if (mcExpr1 == null)
+                    {
+                        // normally we have a method call, but IDryadLinqQueryable.HashPartition() will just return source in
+                        // localDebug mode. hence ctx.Submit(src.HashPartition()) will just present src here.
+                        Debug.Assert(source.Expression is ConstantExpression,
+                                     "mcExpr1 should be MethodCallExpression or ConstantExpression");
+                    }
+
+                    if (mcExpr1 != null && mcExpr1.Method.Name == ReflectedNames.DryadLinqIQueryable_ToDscWorker)
+                    {
+                        // visited by LocalDebug:: Materialize( {q2.ToDsc(...)} ), eg Test5.
+                        InvokeLocalDebugWorker(ReflectedNames.HpcLinqQueryable_LocalDebug_ProcessToDscExpression,
+                                               source.ElementType,
+                                               new object[] { context, mcExpr1 });
+                    }
+                    else
+                    {
+                        // visited by LocalDebug:: Materialize( {q2.Non-ToDsc(...)} ), eg Test6.
+                        string tableName = DataPath.MakeUniqueTemporaryDscFileSetName();
+                        DscCompressionScheme compressionScheme = context.Configuration.IntermediateDataCompressionScheme;
+                        InvokeLocalDebugWorker(ReflectedNames.HpcLinqQueryable_ExecuteLocalExpressionAndIngressToDsc,
+                                               source.ElementType,
+                                               new object[] { context, source.Expression, tableName, compressionScheme });
+                    }
+                }
+
+                return new HpcLinqJobInfo(HpcLinqJobInfo.JOBID_LOCALDEBUG, null, null, null);
+            }
+            else
+            {
+                // If the sources were not terminated with ToDsc(), do this now and provide a temporary output name
+                Expression[] qList = new Expression[sources.Length];
+                bool[] isTemps = new bool[sources.Length];
+                for (int i = 0; i < sources.Length; i++)
+                {
+                    isTemps[i] = false;
+                    MethodCallExpression mcExpr = sources[i].Expression as MethodCallExpression;
+
+                    // sanity check that we have a normal DryadLinqQuery
+                    if (!(sources[i] is DryadLinqQuery))
+                    {
+                        throw new ArgumentException(String.Format(SR.NotAHpcLinqQuery, i),
+                                                    string.Format("sources[{0}]", i));
+                    }
+
+                    // check that the query has not already been submitted.
+                    if (((DryadLinqQuery)sources[i]).IsDataBacked)
+                    {
+                        throw new ArgumentException(string.Format(SR.AlreadySubmittedInMaterialize),
+                                                    string.Format("sources[{0}]", i));
+                    }
+
+                    // more checks that we have a normal, unsubmitted, non-trivial DryadLinqQuery
+                    // (the DryadLinqQuery type test was already done above, so it is not repeated)
+                    if (mcExpr == null ||
+                        !mcExpr.Method.IsStatic ||
+                        !TypeSystem.IsQueryOperatorCall(mcExpr))
+                    {
+                        throw new ArgumentException(String.Format(SR.NotAHpcLinqQuery, i),
+                                                    string.Format("sources[{0}]", i));
+                    }
+
+                    if (mcExpr.Method.Name != ReflectedNames.DryadLinqIQueryable_ToDscWorker)
+                    {
+                        // Wrap the query in the anonymous-target placeholder so it has an output.
+                        isTemps[i] = true;
+                        string tableName = DataPath.MakeUniqueTemporaryDscFileSetUri(context);
+                        MethodInfo minfo =
+                            typeof(HpcLinqQueryable)
+                                .GetMethod(ReflectedNames.DryadLinqIQueryable_AnonymousDscPlaceholder,
+                                           BindingFlags.Static | BindingFlags.NonPublic);
+                        Type elemType = mcExpr.Type.GetGenericArguments()[0];
+                        minfo = minfo.MakeGenericMethod(elemType);
+                        mcExpr = Expression.Call(minfo, mcExpr, Expression.Constant(tableName));
+                    }
+                    qList[i] = mcExpr;
+                }
+
+                // Normal cluster execution -- prepare and submit the query to Dryad.
+                // Execute the queries, which are all now terminated by ToDsc()
+                VertexCodeGen vertexCodeGen = ((DryadLinqQuery)sources[0]).GetVertexCodeGen();
+                HpcLinqQueryGen queryGen = new HpcLinqQueryGen(context, vertexCodeGen, qList);
+                DryadLinqQuery[] tables = queryGen.InvokeDryad();
+
+                // Store the results in the queries
+                for (int j = 0; j < sources.Length; j++)
+                {
+                    tables[j].IsTemp = isTemps[j];
+                    ((DryadLinqQuery)sources[j]).BackingDataDLQ = tables[j];
+                }
+
+                // return a runtimeInfo.
+                int jobId = tables[0].QueryExecutor.JobSubmission.GetJobId();
+                string[] targetUris = new string[tables.Length];
+                for (int i = 0; i < tables.Length; i++)
+                {
+                    targetUris[i] = tables[i].DataSourceUri;
+                }
+                return new HpcLinqJobInfo(jobId, context.Configuration.HeadNode, tables[0].QueryExecutor, targetUris);
+            }
+        }
+
+        // Invokes the named non-public static generic method of HpcLinqQueryable, closed over
+        // elementType, with the given arguments. Reflection wraps anything the target throws in a
+        // TargetInvocationException; the original exception is unwrapped and rethrown so callers
+        // see the same exception type a direct call would produce.
+        private static void InvokeLocalDebugWorker(string methodName, Type elementType, object[] arguments)
+        {
+            MethodInfo method =
+                typeof(HpcLinqQueryable)
+                    .GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Static)
+                    .MakeGenericMethod(elementType);
+            try
+            {
+                method.Invoke(null, arguments);
+            }
+            catch (TargetInvocationException tie)
+            {
+                if (tie.InnerException != null)
+                    throw tie.InnerException; // unwrap and rethrow original exception
+                else
+                    throw; // this shouldn't occur.. but just in case.
+            }
+        }
+    }
+}
diff --git a/LinqToDryad/DryadLinqJobSubmission.cs b/LinqToDryad/DryadLinqJobSubmission.cs
new file mode 100644
index 0000000..3817604
--- /dev/null
+++ b/LinqToDryad/DryadLinqJobSubmission.cs
@@ -0,0 +1,505 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//
+// © Microsoft Corporation. All rights reserved.
+//
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.IO;
+using System.IO.Compression;
+using System.Linq;
+using System.Net;
+using System.Security.Principal;
+using System.Text;
+using System.Xml;
+using System.Text.RegularExpressions;
+using Microsoft.Research.DryadLinq.Internal;
+
+namespace Microsoft.Research.DryadLinq
+{
+    /// <summary>
+    /// Where is the JM process executed?
+    /// </summary>
+    internal enum ExecutionKind
+    {
+        /// <summary>
+        /// The JM has been submitted to the job scheduler.
+        /// </summary>
+        JobScheduler,
+
+        // (disabled) The JM is run on the local machine.
+//        LocalJM,
+
+        /// <summary>
+        /// The query is run using local debug.
+        /// </summary>
+        LocalDebug
+    }
+
+    /// <summary>
+    /// This class encapsulates the means to execute in the background a collection of queries.
+    /// </summary>
+    internal class JobExecutor
+    {
+        /// <summary>
+        /// The jobSubmission is only used if we use a job scheduler.
+        /// </summary>
+        private IHpcLinqJobSubmission jobSubmission;
+        // (disabled) The Process is used only if we don't use a job scheduler.
+//        private Process dryadProc;
+        /// <summary>
+        /// Where is the query being executed?
+        /// </summary>
+        private ExecutionKind executionKind;
+        /// <summary>
+        /// Where is the query being executed?
+        /// </summary>
+        internal ExecutionKind ExecutionKind
+        {
+            get { return this.executionKind; }
+        }
+        // (disabled) Error message when job fails.
+//        private string errorMsg;
+        /// <summary>
+        /// Keep status here when it no longer changes.
+        /// </summary>
+        private JobStatus currentStatus;
+        private HpcLinqContext m_context;
+
+        /// 
+        /// Create a new job executor object.
+        /// 
+        public JobExecutor(HpcLinqContext context)
+        {
+            // use a new job submission object for each query
+//            this.errorMsg = "";
+            this.m_context = context;
+            this.currentStatus = JobStatus.NotSubmitted;
+            if (context.Runtime is HpcQueryRuntime)
+            {
+                YarnJobSubmission job = new YarnJobSubmission(context);
+//                job.LocalJM = false;
+                job.Initialize();
+                this.executionKind = ExecutionKind.JobScheduler;
+                this.jobSubmission = job;
+            }
+            else
+            {
+                throw new DryadLinqException(HpcLinqErrorCode.UnsupportedSchedulerType,
+                                             String.Format(SR.UnsupportedSchedulerType, context.Runtime));
+            }
+
+#if REMOVE
+                case SchedulerKind.LocalJM:
+                    {
+                        HpcJobSubmission job = new HpcJobSubmission(runtime);
+                        job.LocalJM = true;
+                        job.Initialize();
+                        this.executionKind = ExecutionKind.JobScheduler;
+                        this.jobSubmission = job;
+                        DryadLinq.SchedulerType = SchedulerKind.Hpc;
+                        break;
+                    }
+#endif
+        }
+
+        /// <summary>
+        /// Add a specified resource to the dryad computation.
+        /// </summary>
+        /// <param name="resource">Resource to add.</param>
+        /// <returns>The pathname to the resource to use in the xml plan.</returns>
+        public string AddResource(string resource)
+        {
+            switch (this.executionKind)
+            {
+                case ExecutionKind.JobScheduler:
+                    {
+                        this.jobSubmission.AddLocalFile(resource);
+                        // The plan refers to the resource by its file name only.
+                        FileInfo resourceInfo = new FileInfo(resource);
+                        return resourceInfo.Name;
+                    }
+#if REMOVE
+                case ExecutionKind.LocalJM:
+                    {
+                        return resource;
+                    }
+#endif
+                default:
+                    {
+                        throw new DryadLinqException(HpcLinqErrorCode.UnsupportedExecutionKind,
+                                                     SR.UnsupportedExecutionKind);
+                    }
+            }
+        }
+
+        /// <summary>
+        /// Add a resource to a job. No check for clashes between resource names is performed here.
+        /// </summary>
+        /// <param name="jobSubmission">JobSubmission object which will hold the resources.</param>
+        /// <param name="file">Pathname to file to add as a resource.</param>
+        private void AddResource(IHpcLinqJobSubmission jobSubmission, string file)
+        {
+            // Use the submission that was passed in (previously the parameter was ignored in
+            // favour of this.jobSubmission, which every current caller passes anyway).
+            jobSubmission.AddLocalFile(file);
+        }
+
+        /// <summary>
+        /// Start executing the dryad job in the background.
+        /// </summary>
+        /// <param name="queryPlanPath">Full path to query plan XML file.</param>
+        public void ExecuteAsync(string queryPlanPath)
+        {
+            switch (this.executionKind)
+            {
+                case ExecutionKind.JobScheduler:
+                    {
+                        lock (this)
+                        {
+                            FileInfo xmlHostInfo = new FileInfo(StaticConfig.XmlHostPath);
+
+                            // Construct the Graph Manager cmd line.
+                            //string queryPlanFileName = Path.GetFileName(queryPlanPath);
+                            this.jobSubmission.AddJobOption("cmdline", xmlHostInfo.Name + " " + queryPlanPath);
+
+                            AddResource(this.jobSubmission, StaticConfig.XmlHostPath);
+                            AddResource(this.jobSubmission, queryPlanPath);
+
+                            // Add black and white list file, additional resources if passed as a command line
+                            string[] args = StaticConfig.XmlExecHostArgs.Split(new char[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);
+                            for (int i = 0; i < args.Length; ++i)
+                            {
+                                string arg = args[i];
+                                if (arg.Equals("-bw") || arg.Equals("-r"))
+                                {
+                                    // The option consumes the next token as its file path. Fail with a
+                                    // clear message (rather than an IndexOutOfRangeException) when the
+                                    // option is the last token.
+                                    if (i + 1 >= args.Length)
+                                    {
+                                        throw new InvalidOperationException(
+                                            "The option '" + arg + "' in the XML exec host arguments must be followed by a file path.");
+                                    }
+                                    AddResource(this.jobSubmission, args[++i]);
+                                }
+                            }
+
+                            this.jobSubmission.SubmitJob();
+                            this.currentStatus = JobStatus.Waiting;
+                        }
+                        break;
+                    }
+#if REMOVE
+                case ExecutionKind.LocalJM:
+                    {
+                        lock (this)
+                        {
+                            // Invoking Dryad as a separate process:
+                            if (DryadLinq.APEnvironmentPath != null)
+                            {
+                                Environment.SetEnvironmentVariable("APENVIRONMENTPATH", DryadLinq.APEnvironmentPath);
+                            }
+                            if (DryadLinq.APConfigPath != null)
+                            {
+                                Environment.SetEnvironmentVariable("APCONFIGPATH", DryadLinq.APConfigPath);
+                            }
+                            ProcessStartInfo procStartInfo = new ProcessStartInfo();
+                            procStartInfo.FileName = DryadLinq.XmlHostPath;
+                            procStartInfo.Arguments = dryadProgram;
+                            procStartInfo.RedirectStandardOutput = true;
+                            procStartInfo.RedirectStandardError = false;
+                            procStartInfo.UseShellExecute = false;
+
+                            this.dryadProc = Process.Start(procStartInfo);
+                            this.currentStatus = JobStatus.Running;
+                        }
+                        break;
+                    }
+#endif
+                default:
+                    {
+                        throw new NotImplementedException();
+                    }
+            }
+        }
+
+        /// <summary>
+        /// Wait for job completion.
+        /// </summary>
+        /// <returns>The status of the job.</returns>
+        public JobStatus WaitForCompletion()
+        {
+            if (this.Terminated())
+            {
+                return this.currentStatus;
+            }
+
+            // Polling policy: start at 3s, grow by 10% per poll up to 20s; give up after more
+            // than 5 consecutive failures to reach the status service.
+            const int InitialSleepMs = 3000;
+            const int MaxSleepMs = 20000;
+            const int MaxConsecutiveErrors = 5;
+
+            JobStatus status;
+            switch (this.executionKind)
+            {
+                case ExecutionKind.JobScheduler:
+                    {
+                        int sleep = InitialSleepMs;
+                        int retries = 0;
+                        while (true)
+                        {
+                            try
+                            {
+                                string msg = null;
+                                switch (status = this.GetStatus())
+                                {
+                                    case JobStatus.Success:
+                                    case JobStatus.Failure:
+                                    case JobStatus.Cancelled:
+                                        return status;
+                                    case JobStatus.NotSubmitted:
+                                        msg = "The job to create this table has not been submitted yet. Waiting ...";
+                                        break;
+                                    case JobStatus.Running:
+                                        msg = "The job to create this table is still running. Waiting ...";
+                                        break;
+                                    case JobStatus.Waiting:
+                                        msg = "The job to create this table is still queued. Waiting ...";
+                                        break;
+                                    default:
+                                        throw new DryadLinqException(HpcLinqErrorCode.UnexpectedJobStatus,
+                                                                     String.Format(SR.UnexpectedJobStatus, status.ToString()));
+                                }
+
+                                // A successful poll resets the consecutive-error counter.
+                                retries = 0;
+                                HpcClientSideLog.Add(msg);
+                            }
+                            catch (System.Net.WebException)
+                            {
+                                // Back off to the maximum interval as soon as the service is unreachable.
+                                retries++;
+                                sleep = MaxSleepMs;
+                                HpcClientSideLog.Add("Error contacting web server while querying job status. Waiting ...");
+                            }
+
+                            if (retries > MaxConsecutiveErrors)
+                            {
+                                throw new DryadLinqException(HpcLinqErrorCode.JobStatusQueryError,
+                                                             SR.JobStatusQueryError);
+                            }
+                            if (sleep < MaxSleepMs)
+                            {
+                                sleep = Math.Min(MaxSleepMs, (int)(sleep * 1.1));
+                            }
+                            System.Threading.Thread.Sleep(sleep);
+                        }
+                    }
+#if REMOVE
+                case ExecutionKind.LocalJM:
+                    {
+                        StreamReader dryadProcOut = this.dryadProc.StandardOutput;
+                        StreamWriter stdout = null;
+                        try
+                        {
+                            if (DryadLinq.LocalJMStdout != null)
+                            {
+                                stdout = new StreamWriter(DryadLinq.LocalJMStdout);
+                            }
+                            string outLine;
+                            while ((outLine = dryadProcOut.ReadLine()) != null)
+                            {
+                                if (DryadLinq.Verbose)
+                                {
+                                    Console.WriteLine(outLine);
+                                }
+                                if (stdout != null)
+                                {
+                                    stdout.WriteLine(outLine);
+                                }
+                            }
+                        }
+                        finally
+                        {
+                            if (stdout != null) stdout.Close();
+                        }
+
+                        this.dryadProc.WaitForExit();
+                        status = this.GetStatus();
+                        this.dryadProc.Close();
+                        return status;
+                    }
+#endif
+                default:
+                    {
+                        throw new NotImplementedException();
+                    }
+            }
+        }
+
+        /// <summary>
+        /// True if the background execution has terminated.
+        /// </summary>
+        /// <returns>
+        /// true when the cached status is Cancelled, Failure, Success or NotSubmitted;
+        /// false when it is Running or Waiting.
+        /// </returns>
+        public bool Terminated()
+        {
+            // First check whether the status is finalized
+            switch (this.currentStatus)
+            {
+                case JobStatus.Cancelled:
+                case JobStatus.Failure:
+                case JobStatus.Success:
+                case JobStatus.NotSubmitted:
+                    // this status can't change any more
+                    return true;
+                case JobStatus.Running:
+                case JobStatus.Waiting:
+                    // re-evaluate status
+                    return false;
+                default:
+                    throw new DryadLinqException(HpcLinqErrorCode.UnexpectedJobStatus,
+                                                 String.Format(SR.UnexpectedJobStatus, this.currentStatus.ToString()));
+            }
+        }
+
+        /// <summary>
+        /// Find out the status of the job.
+        /// </summary>
+        /// <returns>The job status.</returns>
+        internal JobStatus GetStatus()
+        {
+            // A finalized status is cached and never re-queried.
+            if (this.Terminated())
+            {
+                return this.currentStatus;
+            }
+
+            lock (this)
+            {
+                if (this.executionKind == ExecutionKind.JobScheduler)
+                {
+                    this.currentStatus = this.jobSubmission.GetStatus();
+                }
+#if REMOVE
+                else if (this.executionKind == ExecutionKind.LocalJM)
+                {
+                    if (!dryadProc.HasExited)
+                    {
+                        this.currentStatus = JobStatus.Running;
+                    }
+                    else
+                    {
+                        if (this.dryadProc.ExitCode != 0)
+                        {
+                            this.currentStatus = JobStatus.Failure;
+                            this.errorMsg = "HpcLinq graph manager process failed with exit code " + dryadProc.ExitCode.ToString();
+                        }
+                        else
+                            this.currentStatus = JobStatus.Success;
+                    }
+                }
+#endif
+                else
+                {
+                    throw new DryadLinqException(HpcLinqErrorCode.UnsupportedExecutionKind,
+                                                 SR.UnsupportedExecutionKind);
+                }
+                return currentStatus;
+            }
+        }
+
+        // Overwrites the cached job status with the given value.
+        internal void SetStatus(JobStatus js)
+        {
+            this.currentStatus = js;
+        }
+
+        /// <summary>
+        /// Cancel the job computation.
+        /// </summary>
+        /// <returns>The actual status of the job. This may be 'Cancelled', but if the job has completed it may be 'Success' as well.</returns>
+        internal JobStatus Cancel()
+        {
+            // Nothing to cancel once the status is final.
+            if (this.Terminated())
+            {
+                return currentStatus;
+            }
+
+            lock (this)
+            {
+                switch (this.executionKind)
+                {
+                    case ExecutionKind.JobScheduler:
+                        {
+                            this.currentStatus = this.jobSubmission.TerminateJob();
+                            return this.currentStatus;
+                        }
+#if REMOVE
+                    case ExecutionKind.LocalJM:
+                        {
+                            if (this.dryadProc == null)
+                            {
+                                return JobStatus.NotSubmitted;
+                            }
+
+                            if (!this.dryadProc.HasExited)
+                            {
+                                this.dryadProc.Kill();
+                                this.errorMsg = "Job Manager was cancelled";
+                                this.currentStatus = JobStatus.Cancelled;
+                            }
+                            return this.GetStatus();
+                        }
+#endif
+                    default:
+                        {
+                            throw new DryadLinqException(HpcLinqErrorCode.UnsupportedExecutionKind,
+                                                         SR.UnsupportedExecutionKind);
+                        }
+                }
+            }
+        }
+
+        /// 
+        /// For failed jobs this contains an error message.
+        /// 
+        internal string ErrorMsg
+        {
+            get
+            {
+                switch (this.executionKind)
+                {
+                    case ExecutionKind.JobScheduler:
+                        {
+                            return this.jobSubmission.ErrorMsg;
+                        }
+#if REMOVE
+                    case ExecutionKind.LocalJM:
+                        {
+                            return this.errorMsg;
+                        }
+#endif
+                    default:
+                        {
+                            throw new DryadLinqException(HpcLinqErrorCode.UnsupportedExecutionKind,
+                                                         SR.UnsupportedExecutionKind);
+                        }
+                }
+            }
+        }
+
+        // The job submission object created by the constructor for this executor.
+        internal IHpcLinqJobSubmission JobSubmission
+        {
+            get { return this.jobSubmission; }
+        }
+    }
+}
diff --git a/LinqToDryad/DryadLinqLog.cs b/LinqToDryad/DryadLinqLog.cs
new file mode 100644
index 0000000..bcecfbc
--- /dev/null
+++ b/LinqToDryad/DryadLinqLog.cs
@@ -0,0 +1,120 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//
+// © Microsoft Corporation. All rights reserved.
+//
+using System;
+using System.Collections.Generic;
+using System.Text;
+using System.IO;
+
+namespace Microsoft.Research.DryadLinq.Internal
+{
+    // The client-log is a file produced in the current working directory (next to generated .cs/.dll and queryPlan.xml)
+    //
+    // Note: concurrent queries in one AppDomain will get interleaved logs
+    //       concurrent queries from different AppDomains will not all be able to write logs.
+    //
+    // Once the log file is opened, it generally will remain open for the duration of the AppDomain.
+    internal static class HpcClientSideLog
+    {
+        private static bool _enabled = true; // whether to log or not.
+
+        internal const string CLIENT_LOG_FILENAME = "HpcLinq.log";
+        private static bool s_IOErrorOccurred = false;
+        private static StreamWriter s_writer;
+
+        // Guards s_writer and s_IOErrorOccurred: queries running concurrently in one AppDomain
+        // share the writer, and StreamWriter instances are not safe for concurrent use.
+        private static readonly object s_lock = new object();
+
+        // Appends msg as one line to the client log (no formatting is applied).
+        public static void Add(string msg)
+        {
+            if (!_enabled)
+                return;
+
+            Add(msg, null);
+        }
+
+        // Appends a line to the client log. When args is null, msg is written verbatim;
+        // otherwise msg is treated as a composite format string for args.
+        // Logging is best-effort: after the first I/O failure the log is disabled for the
+        // rest of the AppDomain's lifetime and later calls return immediately.
+        public static void Add(string msg, params object[] args)
+        {
+            if (!_enabled)
+                return;
+
+            lock (s_lock)
+            {
+                if (s_IOErrorOccurred)
+                    return;
+
+                try
+                {
+                    if (s_writer == null)
+                    {
+                        string path = HpcLinqCodeGen.GetPathForGeneratedFile(CLIENT_LOG_FILENAME, null);
+                        s_writer = new StreamWriter(path);
+                    }
+
+                    if (args == null)
+                    {
+                        s_writer.WriteLine(msg);
+                    }
+                    else
+                    {
+                        s_writer.WriteLine(msg, args);
+                    }
+                    s_writer.Flush();
+                }
+                catch (IOException)
+                {
+                    DisableAfterWriteFailure();
+                }
+                catch (UnauthorizedAccessException)
+                {
+                    // e.g. the log location is not writable; a logging problem must not fail the query.
+                    DisableAfterWriteFailure();
+                }
+            }
+        }
+
+        // Marks the log as unusable and releases the writer, if one was ever opened.
+        // Must be called with s_lock held.
+        private static void DisableAfterWriteFailure()
+        {
+            s_IOErrorOccurred = true;
+
+            // The writer is still null when opening the file was what failed.
+            StreamWriter writer = s_writer;
+            s_writer = null;
+            if (writer == null)
+                return;
+
+            try
+            {
+                writer.Dispose();
+            }
+            catch
+            {
+                // suppress exceptions that occur during cleanup.
+            }
+        }
+    }
+
+    // Diagnostic log that echoes messages to the console and appends them to "LinqLog.txt"
+    // in the current working directory.
+    public static class DryadLinqLog
+    {
+        // Serializes writes so that concurrent callers do not interleave within a line.
+        private static readonly object s_lock = new object();
+
+        // Null when the log file could not be opened; console output still works in that case.
+        private static StreamWriter sw;
+
+        public static bool IsOn { get; set; }
+
+        static DryadLinqLog()
+        {
+            // A failure here must not escape: an exception thrown from a static constructor
+            // makes every later use of the type throw TypeInitializationException.
+            try
+            {
+                sw = new StreamWriter("LinqLog.txt", true);
+                sw.AutoFlush = true;
+            }
+            catch (IOException)
+            {
+                sw = null;
+            }
+            catch (UnauthorizedAccessException)
+            {
+                sw = null;
+            }
+        }
+
+        // Writes "DryadLinq: " followed by the message when either isOn or IsOn is true.
+        // msg is a composite format string only when arguments are supplied; with no
+        // arguments it is written verbatim, so literal braces in it no longer throw.
+        public static void Add(bool isOn, string msg, params object[] args)
+        {
+            if (!(isOn || IsOn))
+                return;
+
+            bool hasArgs = (args != null && args.Length > 0);
+            string line = "DryadLinq: " + (hasArgs ? String.Format(msg, args) : msg);
+
+            lock (s_lock)
+            {
+                Console.WriteLine(line);
+                if (sw != null)
+                {
+                    sw.WriteLine(line);
+                }
+            }
+        }
+
+        // NOTE(review): this overload forces logging on regardless of IsOn; the original
+        // comment ("Debug - was false") suggests a debugging leftover -- confirm before release.
+        public static void Add(string msg, params object[] args)
+        {
+            Add(true, msg, args); //Debug - was false
+        }
+    }
+}
diff --git a/LinqToDryad/DryadLinqMetaData.cs b/LinqToDryad/DryadLinqMetaData.cs
new file mode 100644
index 0000000..331a52e
--- /dev/null
+++ b/LinqToDryad/DryadLinqMetaData.cs
@@ -0,0 +1,210 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//
+// © Microsoft Corporation. All rights reserved.
+//
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Linq.Expressions;
+using System.Runtime.Serialization.Formatters.Binary;
+using System.Runtime.Serialization;
+using Microsoft.Research.DryadLinq.Internal;
+using System.Xml;
+using System.Text;
+
+namespace Microsoft.Research.DryadLinq
+{
+    //@@TODO[P1]: read/write meta-data to the DSC stream attribute.
+    //            use the query-plan to pass the meta-data to the JM which will set the attribute on each output stream.
+    internal class DryadLinqMetaData
+    {
+        // Serialization flag bits; currently referenced only by the commented-out InitializeFlags code.
+        const int FLAG_ALLOW_NULL_RECORDS = 0x1;
+        const int FLAG_ALLOW_NULL_FIELDS = 0x2;
+        const int FLAG_ALLOW_NULL_ARRAY_ELEMENTS = 0x4;
+        const int FLAG_ALLOW_AUTO_TYPE_INFERENCE = 0x8;
+
+        private HpcLinqContext m_context;
+        private string m_dscStreamName;
+        private Type m_elemType;
+        private DscCompressionScheme m_compressionScheme;
+        //private Version m_version;
+        //private int m_serializationFlags;
+        //private UInt64 m_fp;
+        //private DataSetInfo m_dataSetInfo;
+
+        // Name of the file-set metadata entry that stores the record type name.
+        internal const string RECORD_TYPE_NAME = "__LinqToHPC__ 364C7B59-08C0-44ED-B8A2-7E224ED5B7ED";
+        //internal const string COMPRESSION_SCHEME_NAME = "__LinqToHpc__8D22DD19-FA86-45DB-A0D7-C3C3A1440C90";
+        //NOTE: although Compression is stored as a DSC attribute, it can only be retrieved via dscFileSet.CompressionScheme
+        //      fs.GetMetaData("name") will only return the byte[] payload, for which there is none.
+        //      Also, fs.CompressionScheme can only be set via dsc.CreateFileSet(.., scheme)
+        private DryadLinqMetaData()
+        {
+        }
+
+        // Creates the metadata describing a file set written in local-debug mode.
+        internal static DryadLinqMetaData ForLocalDebug(HpcLinqContext context,
+                                                        Type recordType,
+                                                        string dscStreamName,
+                                                        DscCompressionScheme compressionScheme)
+        {
+            DryadLinqMetaData metaData = new DryadLinqMetaData();
+
+            metaData.m_context = context;
+            metaData.m_dscStreamName = dscStreamName;
+            metaData.m_elemType = recordType;
+            metaData.m_compressionScheme = compressionScheme;
+            //metaData.m_version = context.ClientVersion;
+            //metaData.InitializeFlags();
+
+            //metaData.m_fp = 0UL;
+            //metaData.m_dataSetInfo = node.OutputDataSetInfo;
+
+            return metaData;
+        }
+
+        // create DryadLinqMetaData from a query OutputNode
+        // Throws InvalidOperationException when the node's metadata URI is neither DSC nor HDFS.
+        internal static DryadLinqMetaData FromOutputNode(HpcLinqContext context, DryadOutputNode node)
+        {
+            DryadLinqMetaData metaData = new DryadLinqMetaData();
+
+            if (! (DataPath.IsDsc(node.MetaDataUri) || DataPath.IsHdfs(node.MetaDataUri)) )
+            {
+                throw new InvalidOperationException("The output node's metadata URI must be a DSC or HDFS URI.");
+            }
+
+            metaData.m_context = context;
+            metaData.m_dscStreamName = node.MetaDataUri;
+            metaData.m_elemType = node.OutputTypes[0];
+            metaData.m_compressionScheme = node.OutputCompressionScheme;
+            //metaData.m_version = context.ClientVersion;
+            //metaData.InitializeFlags();
+
+            //metaData.m_fp = 0UL;
+            //metaData.m_dataSetInfo = node.OutputDataSetInfo;
+
+            return metaData;
+        }
+
+        // Load a DryadLinqMetaData from an existing dsc stream.
+        internal static DryadLinqMetaData FromDscStream(HpcLinqContext context, string dscStreamName)
+        {
+            DryadLinqMetaData metaData;
+            try
+            {
+                DscFileSet fs = context.DscService.GetFileSet(dscStreamName);
+                metaData = new DryadLinqMetaData();
+                metaData.m_context = context;
+                metaData.m_dscStreamName = dscStreamName;
+                //metaData.m_fp = 0L;
+                //metaData.m_dataSetInfo = null;
+
+                byte[] metaDataBytes;
+
+                //record-type
+                metaDataBytes = fs.GetMetadata(DryadLinqMetaData.RECORD_TYPE_NAME);
+                if (metaDataBytes != null)
+                {
+                    string recordTypeString = Encoding.UTF8.GetString(metaDataBytes);
+                    // NOTE(review): Type.GetType returns null when the name cannot be resolved,
+                    // which leaves ElemType null without any error -- confirm callers handle that.
+                    metaData.m_elemType = Type.GetType(recordTypeString);
+                }
+
+                //Compression-scheme
+                metaData.m_compressionScheme = fs.CompressionScheme;
+            }
+            catch (Exception e)
+            {
+                // Any failure while reading is reported as a single metadata error, with the cause attached.
+                throw new DryadLinqException(HpcLinqErrorCode.ErrorReadingMetadata,
+                                             String.Format(SR.ErrorReadingMetadata), e);
+            }
+
+            return metaData;
+        }
+
+        //private void InitializeFlags()
+        //{
+        //    this.m_serializationFlags = 0;
+
+        //    if (StaticConfig.AllowNullRecords)
+        //    {
+        //        this.m_serializationFlags |= FLAG_ALLOW_NULL_RECORDS;
+        //    }
+        //    if (StaticConfig.AllowNullFields)
+        //    {
+        //        this.m_serializationFlags |= FLAG_ALLOW_NULL_FIELDS;
+        //    }
+        //    if (StaticConfig.AllowNullArrayElements)
+        //    {
+        //        this.m_serializationFlags |= FLAG_ALLOW_NULL_ARRAY_ELEMENTS;
+        //    }
+        //    if (StaticConfig.AllowAutoTypeInference)
+        //    {
+        //        this.m_serializationFlags |= FLAG_ALLOW_AUTO_TYPE_INFERENCE;
+        //    }
+        //}
+
+        // The stream name / URI this metadata was created for.
+        internal string MetaDataUri
+        {
+            get { return this.m_dscStreamName; }
+        }
+
+        // The record type of the data, or null if none was recorded.
+        internal Type ElemType
+        {
+            get { return this.m_elemType; }
+        }
+
+        //internal Version Version
+        //{
+        //    get { return this.m_version; }
+        //}
+
+        internal DscCompressionScheme CompressionScheme
+        {
+            get { return this.m_compressionScheme; }
+        }
+
+        //internal bool AllowNullRecords
+        //{
+        //    get { return false; } //(this.m_serializationFlags & FLAG_ALLOW_NULL_RECORDS) != 0;
+        //}
+
+        //internal bool AllowNullFields
+        //{
+        //    get { return false; } // (this.m_serializationFlags & FLAG_ALLOW_NULL_FIELDS) != 0;
+        //}
+
+        //internal bool AllowNullArrayElements
+        //{
+        //    get { return false; } // (this.m_serializationFlags & FLAG_ALLOW_NULL_ARRAY_ELEMENTS) != 0;
+        //}
+
+        //internal bool AllowAutoTypeInference
+        //{
+        //    get { return false; }// (this.m_serializationFlags & FLAG_ALLOW_AUTO_TYPE_INFERENCE) != 0;
+        //}
+
+        //internal DataSetInfo DataSetInfo
+        //{
+        //    get { return this.m_dataSetInfo; }
+        //}
+    }
+}
diff --git a/LinqToDryad/DryadLinqNative.cs b/LinqToDryad/DryadLinqNative.cs
new file mode 100644
index 0000000..5b540ab
--- /dev/null
+++ b/LinqToDryad/DryadLinqNative.cs
@@ -0,0 +1,148 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//
+// © Microsoft Corporation. All rights reserved.
+//
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Security;
+using System.Threading;
+using System.Runtime.InteropServices;
+using Microsoft.Win32.SafeHandles;
+
+namespace Microsoft.Research.DryadLinq.Internal
+{
+    // Buffer passed to GlobalMemoryStatusEx (declared below); field order and widths are
+    // part of the native contract and must not be changed.
+    internal struct MEMORYSTATUSEX
+    {
+        public UInt32 dwLength;
+        public UInt32 dwMemoryLoad;
+        public UInt64 ullTotalPhys;
+        public UInt64 ullAvailPhys;
+        public UInt64 ullTotalPageFile;
+        public UInt64 ullAvailPageFile;
+        public UInt64 ullTotalVirtual;
+        public UInt64 ullAvailVirtual;
+        public UInt64 ullAvailExtendedVirtual;
+    }
+
+    // This class contains the Win32 and Dryad native API.
+    // Security Warning: We suppressed unmanaged code security check,
+    // which saves a stack walk for each call into unmanaged code.
+    [SuppressUnmanagedCodeSecurity]
+    internal static class HpcLinqNative
+    {
+        /* Win32 native API */
+        [DllImport("kernel32.dll", SetLastError = true)]
+        [return: MarshalAs(UnmanagedType.Bool)]
+        internal unsafe static extern bool ReadFile(SafeFileHandle handle,
+                                                    byte* pBuffer,
+                                                    UInt32 numBytesToRead,
+                                                    IntPtr pNumBytesRead,
+                                                    NativeOverlapped* overlapped);
+
+        [DllImport("kernel32.dll", SetLastError = true)]
+        [return: MarshalAs(UnmanagedType.Bool)]
+        internal unsafe static extern bool WriteFile(SafeFileHandle handle,
+                                                     byte* pBuffer,
+                                                     UInt32 numBytesToWrite,
+                                                     IntPtr pNumBytesWritten,
+                                                     NativeOverlapped* overlapped);
+
+        [DllImport("kernel32.dll", SetLastError = true)]
+        [return: MarshalAs(UnmanagedType.Bool)]
+        internal static extern bool GetFileSizeEx(SafeFileHandle handle, out Int64 fsize);
+
+        [DllImport("kernel32.dll", SetLastError=true)]
+        [return: MarshalAs(UnmanagedType.Bool)]
+        internal static extern bool GlobalMemoryStatusEx(ref MEMORYSTATUSEX lpBuffer);
+
+        /* Dryad native API */
+        [DllImport("DryadLINQNativeChannels.dll", SetLastError=true)]
+        internal static extern UInt32 GetNumOfInputs(IntPtr vertexInfo);
+
+        [DllImport("DryadLINQNativeChannels.dll", SetLastError=true)]
+        internal static extern UInt32 GetNumOfOutputs(IntPtr vertexInfo);
+
+        [DllImport("DryadLINQNativeChannels.dll", SetLastError=true)]
+        internal static extern void Flush(IntPtr vertexInfo, UInt32 portNum);
+
+        [DllImport("DryadLINQNativeChannels.dll", SetLastError=true)]
+        internal static extern void Close(IntPtr vertexInfo, UInt32 portNum);
+
+        // Get the expected size in bytes of the input channel of the given port.
+        // It returns -1 if the size is unknown.
+        [DllImport("DryadLINQNativeChannels.dll", SetLastError=true)]
+        internal static extern Int64 GetExpectedLength(IntPtr vertexInfo, UInt32 portNum);
+
+        // Get the global vertex id which is unique.
+        [DllImport("DryadLINQNativeChannels.dll", SetLastError = true)]
+        internal static extern Int64 GetVertexId(IntPtr vertexInfo);
+
+        // Set the hint size for the output channel of the given port.
+        [DllImport("DryadLINQNativeChannels.dll", SetLastError = true)]
+        internal static extern void SetInitialSizeHint(IntPtr vertexInfo, UInt32 portNum, UInt64 hint);
+
+        // Get the URI of the input channel of the given port.
+        [DllImport("DryadLINQNativeChannels.dll", SetLastError=true)]
+        internal static extern IntPtr GetInputChannelURI(IntPtr vertexInfo, UInt32 portNum);
+
+        // Get the URI of the output channel of the given port.
+        [DllImport("DryadLINQNativeChannels.dll", SetLastError=true)]
+        internal static extern IntPtr GetOutputChannelURI(IntPtr vertexInfo, UInt32 portNum);
+
+        // Read the data block from the channel of the specified port number.
+        // *pDataBlockSize is the number of bytes read, and return 0 if the
+        // channel reaches the end. In this case, *pDataBlock should be null.
+        //
+        // The caller is considered to be the exclusive owner of this data
+        // block. This data block will not be reclaimed until the caller
+        // explicitly releases it.
+        [DllImport("DryadLINQNativeChannels.dll", SetLastError=true)]
+        internal unsafe static extern IntPtr ReadDataBlock(IntPtr vertexInfo,
+                                                           UInt32 portNum,
+                                                           byte** pDataBlock,
+                                                           Int32* pDataBlockSize,
+                                                           Int32* pErrorCode);
+
+        // Write the data block on the channel with the specified port number.
+        //
+        // The data block should be considered read-only after WriteDataBlock
+        // has been called. This data block will not be reclaimed until the
+        // client explicitly releases it.
+        [DllImport("DryadLINQNativeChannels.dll", SetLastError = true)]
+        [return: MarshalAs(UnmanagedType.Bool)]
+        internal unsafe static extern bool WriteDataBlock(IntPtr vertexInfo,
+                                                          UInt32 portNum,
+                                                          IntPtr itemHandle,
+                                                          Int32 numBytesToWrite);
+
+        // Allocate a native Dryad data block with specified size. This data
+        // block will not be reclaimed until the client explicitly releases it.
+        [DllImport("DryadLINQNativeChannels.dll", SetLastError=true)]
+        internal unsafe static extern IntPtr AllocateDataBlock(IntPtr vertexInfo, Int32 size,
+                                                               byte** pDataBlock);
+
+        // Release the data block. The client should not access it again after releasing.
+        [DllImport("DryadLINQNativeChannels.dll", SetLastError=true)]
+        internal unsafe static extern void ReleaseDataBlock(IntPtr vertexInfo, IntPtr itemHandle);
+    }
+}
diff --git a/LinqToDryad/DryadLinqObjectStore.cs b/LinqToDryad/DryadLinqObjectStore.cs
new file mode 100644
index 0000000..dcca548
--- /dev/null
+++ b/LinqToDryad/DryadLinqObjectStore.cs
@@ -0,0 +1,171 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. +// +using System; +using System.Collections; +using System.Collections.Generic; +using System.Text; +using System.IO; +using System.Reflection; +using System.Runtime.Serialization; +using System.Runtime.Serialization.Formatters.Binary; +using System.Diagnostics; +using Microsoft.Research.DryadLinq; + +namespace Microsoft.Research.DryadLinq.Internal +{ + // This class implements an object store that is used to store objects + // needed for remote execution of managed vertex code. All objects put + // in the store must have the .NET Serializable attribute. + // Note: this class is not thread safe + public sealed class HpcLinqObjectStore + { + private const string ObjectStoreFileName = "HpcLinqObjectStore.bin"; + + internal static string GetClientSideObjectStorePath() + { + return HpcLinqCodeGen.GetPathForGeneratedFile(ObjectStoreFileName, null); + } + + private static ArrayList s_objectList = null; + + public static bool IsEmpty + { + get { + return (s_objectList == null || s_objectList.Count == 0); + } + } + + public static void Clear() + { + s_objectList = null; + } + + // this method is only used by the generated vertex code, and always + // assumes "HpcLinqObjectStore.bin" to be in the current directory + public static object Get(int idx) + { + if (s_objectList == null) + { + // Try to open the object store. First look in the parent directory + // (job directory when running normally), then try opening it from the path. 
+ FileStream fs; + try + { + fs = new FileStream(Path.Combine(Directory.GetParent(Directory.GetCurrentDirectory()).FullName, ObjectStoreFileName), FileMode.Open, FileAccess.Read, FileShare.Read); + } + catch (FileNotFoundException) + { + fs = new FileStream(ObjectStoreFileName, FileMode.Open, FileAccess.Read, FileShare.Read); + } + + BinaryFormatter bfm = new BinaryFormatter(); + try + { + s_objectList = (ArrayList)bfm.Deserialize(fs); + } + catch (SerializationException e) + { + throw new DryadLinqException(HpcLinqErrorCode.FailedToDeserialize, + SR.FailedToDeserialize, e); + } + finally + { + if (fs != null) fs.Close(); + } + } + + if (idx >= s_objectList.Count) + { + throw new DryadLinqException(HpcLinqErrorCode.IndexOutOfRange, + SR.IndexOutOfRange); + } + return s_objectList[idx]; + } + + public static int Put(object obj) + { + if (s_objectList != null) + { + for (int idx = 0; idx < s_objectList.Count; idx++) + { + if (Object.ReferenceEquals(obj, s_objectList[idx])) + { + return idx; + } + } + } + + if (s_objectList == null) + { + s_objectList = new ArrayList(4); + } + s_objectList.Add(obj); + return (s_objectList.Count - 1); + } + + // This method is only used by the client process to save the object store before submitting a job + // Like other generated files we need to save this file in the temp directory + public static void Save() + { + if (IsEmpty) return; + string objectStorePath = GetClientSideObjectStorePath(); + + FileStream fs = new FileStream(objectStorePath, FileMode.Create); + BinaryFormatter bfm = new BinaryFormatter(); + try + { + bfm.Serialize(fs, s_objectList); + } + catch (SerializationException e) + { + foreach (object obj in s_objectList) + { + Type badType = TypeSystem.GetNonserializable(obj); + if (badType != null) + { + if (badType.IsGenericType && + badType.GetGenericTypeDefinition() == typeof(DryadLinqQuery<>)) + { + throw new DryadLinqException(HpcLinqErrorCode.CannotSerializeHpcLinqQuery, + SR.CannotSerializeHpcLinqQuery); + } + 
else + { + throw new DryadLinqException(HpcLinqErrorCode.CannotSerializeObject, + string.Format(SR.CannotSerializeObject, obj)); + } + } + } + + throw new DryadLinqException(HpcLinqErrorCode.GeneralSerializeFailure, + SR.GeneralSerializeFailure, e); + } + finally + { + if (fs != null) fs.Close(); + } + } + } + +} diff --git a/LinqToDryad/DryadLinqQuery.cs b/LinqToDryad/DryadLinqQuery.cs new file mode 100644 index 0000000..ad08f0b --- /dev/null +++ b/LinqToDryad/DryadLinqQuery.cs @@ -0,0 +1,889 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. +// +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Text; +using System.Linq; +using System.Diagnostics; +using System.Linq.Expressions; +using System.Reflection; +using Microsoft.Research.Dryad.Hdfs; +using Microsoft.Research.DryadLinq.Internal; +using System.Globalization; + +namespace Microsoft.Research.DryadLinq +{ + // The base provider for all DryadLinq queries. 
+    // Any IQueryable that we are handling should satisfy ((queryable.Provider is DryadLinqProviderBase) == true)
+    //
+    // For example:
+    // - all IQueryable extension methods ask for (queryable.Provider) and then call provider.CreateQuery(expr)
+    internal abstract class DryadLinqProviderBase : IQueryProvider
+    {
+        private HpcLinqContext m_context;
+        internal HpcLinqContext Context { get { return m_context; } }
+
+        internal DryadLinqProviderBase(HpcLinqContext context)
+        {
+            m_context = context;
+        }
+
+        // The four IQueryProvider members: a typed (generic) and an untyped
+        // variant of CreateQuery and of Execute.
+        public abstract IQueryable<TElement> CreateQuery<TElement>(Expression expression);
+        public abstract IQueryable CreateQuery(Expression expression);
+        public abstract TResult Execute<TResult>(Expression expression);
+        public abstract object Execute(Expression expression);
+    }
+
+    // The provider for DryadLinq queries that will be executed by the LocalDebug infrastructure.
+    internal sealed class DryadLinqLocalProvider : DryadLinqProviderBase
+    {
+        // Provider that the typed calls are forwarded to.
+        private IQueryProvider m_linqToObjectsProvider;
+
+        public DryadLinqLocalProvider(IQueryProvider linqToObjectsProvider, HpcLinqContext context)
+            : base(context)
+        {
+            this.m_linqToObjectsProvider = linqToObjectsProvider;
+        }
+
+        //Always throw for untyped call.
+        // Throws DryadLinqException: ExpressionMustBeMethodCall when the expression is not a
+        // method call, otherwise UntypedProviderMethodsNotSupported naming the called method.
+        public override IQueryable CreateQuery(Expression expression)
+        {
+            MethodCallExpression callExpr = expression as MethodCallExpression;
+            if (callExpr == null)
+            {
+                throw new DryadLinqException(HpcLinqErrorCode.ExpressionMustBeMethodCall, SR.ExpressionMustBeMethodCall);
+            }
+            string methodName = callExpr.Method.Name;
+            throw new DryadLinqException(HpcLinqErrorCode.UntypedProviderMethodsNotSupported,
+                                         String.Format(SR.UntypedProviderMethodsNotSupported, methodName));
+        }
+
+        //Always throw for untyped call.
+        public override object Execute(Expression expression)
+        {
+            return this.CreateQuery(expression);    // the untyped CreateQuery always throws
+        }
+
+        // Typed query creation: builds the query with the wrapped provider and wraps the
+        // result so that the returned IQueryable reports this provider.
+        public override IQueryable<TElement> CreateQuery<TElement>(Expression expression)
+        {
+            ThrowIfUnsupported(expression);
+            var localQuery = this.m_linqToObjectsProvider.CreateQuery<TElement>(expression);
+            return new DryadLinqLocalQuery<TElement>(this, localQuery);
+        }
+
+        // Typed execution: evaluated directly by the wrapped provider.
+        public override TResult Execute<TResult>(Expression expression)
+        {
+            ThrowIfUnsupported(expression);
+            return this.m_linqToObjectsProvider.Execute<TResult>(expression);
+        }
+
+        // Hook for rejecting unsupported operators. Currently a no-op: the only
+        // check (SequenceEqual) is commented out.
+        internal void ThrowIfUnsupported(Expression expression)
+        {
+            var mcexpr = expression as MethodCallExpression;
+            if (mcexpr != null)
+            {
+                // if (mcexpr.Method.Name == "SequenceEqual")
+                // {
+                //     throw new NotSupportedException(SR.SequenceEqualNotSupported);
+                // }
+            }
+        }
+    }
+
+    // The IQueryable that is used for LocalDebug queries.
+    // This is much simpler than DryadLinqQuery as it only has to support fallback to LINQ-to-objects.
+    internal sealed class DryadLinqLocalQuery<T> : IOrderedQueryable<T>, IEnumerable<T>, IOrderedQueryable
+    {
+        private IQueryProvider m_queryProvider;     // reported through IQueryable.Provider
+        private IQueryable<T> m_localQuery;         // supplies the expression and the enumerator
+
+        public DryadLinqLocalQuery(IQueryProvider queryProvider, IQueryable<T> localQuery)
+        {
+            this.m_queryProvider = queryProvider;
+            this.m_localQuery = localQuery;
+        }
+
+        public Expression Expression
+        {
+            get { return this.m_localQuery.Expression; }
+        }
+
+        Type IQueryable.ElementType
+        {
+            get { return typeof(T); }
+        }
+
+        IQueryProvider IQueryable.Provider
+        {
+            get { return this.m_queryProvider; }
+        }
+
+        IEnumerator IEnumerable.GetEnumerator()
+        {
+            return this.GetEnumerator();
+        }
+
+        public IEnumerator<T> GetEnumerator()
+        {
+            return this.m_localQuery.GetEnumerator();
+        }
+    }
+
+
+    // The provider for DryadLinq queries that will be executed by the cluster infrastructure.
+    internal sealed class DryadLinqProvider : DryadLinqProviderBase
+    {
+        internal DryadLinqProvider(HpcLinqContext context)
+            : base(context)
+        {
+        }
+
+        //It is exercised by unit test "Bug11782_LowLevelQueryableManipulation"
+        // which can now expect to receive an exception
+        // Always throw for the untyped calls.
+        public override IQueryable CreateQuery(Expression expression)
+        {
+            MethodCallExpression callExpr = expression as MethodCallExpression;
+            if (callExpr == null)
+            {
+                throw new DryadLinqException(HpcLinqErrorCode.ExpressionMustBeMethodCall,
+                                             SR.ExpressionMustBeMethodCall);
+            }
+            string methodName = callExpr.Method.Name;
+            throw new DryadLinqException(HpcLinqErrorCode.UntypedProviderMethodsNotSupported,
+                                         String.Format(SR.UntypedProviderMethodsNotSupported, methodName));
+        }
+
+        // Typed query creation: wraps the expression in a new query node bound to this provider.
+        public override IQueryable<TElement> CreateQuery<TElement>(Expression expression)
+        {
+            return new DryadLinqQuery<TElement>(this, expression);
+        }
+
+        //This is the IQueryProvider.Execute() method used for execution when a single value is produced (rather than an enumerable)
+        //This non-generic method delegates to the untyped CreateQuery, so it
+        //always throws for the untyped calls
+        public override object Execute(Expression expression)
+        {
+            return this.CreateQuery(expression); // which will throw.
+        }
+
+        // Instantiates DryadLinqQuery<elementType> through its non-public
+        // (provider, expression) constructor.
+        private object CreateQueryInstance(Type elementType, Expression expression)
+        {
+            Type qType = typeof(DryadLinqQuery<>).MakeGenericType(elementType);
+            return Activator.CreateInstance(
+                       qType,
+                       BindingFlags.NonPublic | BindingFlags.Instance,
+                       null,
+                       new object[] { this, expression },
+                       CultureInfo.CurrentCulture);
+        }
+
+        //This is the IQueryProvider.Execute() method used for execution when a single value is produced (rather than an enumerable)
+        // Throws ArgumentException when the expression is not a method call.
+        public override TResult Execute<TResult>(Expression expression)
+        {
+            MethodCallExpression callExpr = expression as MethodCallExpression;
+            if (callExpr == null)
+            {
+                throw new ArgumentException(String.Format(SR.ExpressionMustBeMethodCall,
+                                                          HpcLinqExpression.Summarize(expression)), "expression");
+            }
+            string methodName = callExpr.Method.Name;
+            if (methodName == "FirstOrDefault" ||
+                methodName == "SingleOrDefault" ||
+                methodName == "LastOrDefault")
+            {
+                // The *OrDefault operators are enumerated as a single AggregateValue;
+                // a Count of zero means "no element", which maps to default(TResult).
+                Type aggType = typeof(AggregateValue<>).MakeGenericType(expression.Type);
+                AggregateValue<TResult> res =
+                    ((IEnumerable<AggregateValue<TResult>>)CreateQueryInstance(aggType, expression)).Single();
+                if (res.Count == 0) return default(TResult);
+                return res.Value;
+            }
+            else
+            {
+                // Every other operator yields exactly one record holding the result.
+                return ((IEnumerable<TResult>)CreateQueryInstance(expression.Type, expression)).Single();
+            }
+        }
+    }
+
+
+    //Note: cannot be sub-classed by users as they cannot provide overrides for the internal abstract properties.
+    internal abstract class DryadLinqQuery
+    {
+        protected DryadLinqProviderBase m_queryProvider;
+        private DataProvider m_dataProvider;
+        private bool m_isTemp;
+        private JobExecutor m_queryExecutor;
+
+        internal DryadLinqQuery(DryadLinqProviderBase queryProvider, DataProvider dataProvider)
+        {
+            this.m_queryProvider = queryProvider;
+            this.m_dataProvider = dataProvider;
+            this.m_queryExecutor = null;
+        }
+
+        //if non-null, this provided a data-backed DLQ that should be used in place of (this).
+        //query-execution causes a _backingData field to be set for the DLQ nodes that were specifically "executed".
+        //(used to be called _table/Table)
+        internal abstract DryadLinqQuery BackingDataDLQ { set; }
+        internal abstract bool IsDataBacked { get; }
+
+        internal abstract Type Type { get; }
+        internal abstract string DataSourceUri { get; }
+        internal abstract bool IsDynamic { get; }
+        internal abstract int PartitionCount { get; }
+        internal abstract DataSetInfo DataSetInfo { get; }
+
+        internal DryadLinqProviderBase QueryProvider
+        {
+            get { return this.m_queryProvider; }
+        }
+
+        internal DataProvider DataProvider
+        {
+            get { return this.m_dataProvider; }
+        }
+
+        internal bool IsTemp
+        {
+            set { this.m_isTemp = value; }
+        }
+
+        internal JobExecutor QueryExecutor
+        {
+            get { return this.m_queryExecutor; }
+            set { this.m_queryExecutor = value; }
+        }
+
+        // Copies the base-class state (provider, data provider, temp flag, executor) into otherQuery.
+        protected void CopyToInternal(DryadLinqQuery otherQuery)
+        {
+            otherQuery.m_queryProvider = this.m_queryProvider;
+            otherQuery.m_dataProvider = this.m_dataProvider;
+            otherQuery.m_isTemp = this.m_isTemp;
+            otherQuery.m_queryExecutor = this.m_queryExecutor;
+        }
+
+        internal virtual VertexCodeGen GetVertexCodeGen()
+        {
+            return new VertexCodeGen();
+        }
+    }
+
+    // The IQueryable that is used for cluster-execution queries.
+    internal class DryadLinqQuery<T> : DryadLinqQuery, IOrderedQueryable<T>, IEnumerable<T>, IOrderedQueryable
+    {
+        // If _backingDataDLQ is set, this is a normal query node that was executed and now has a
+        // "PlainData" DLQ available with the results.
+        private DryadLinqQuery<T> m_backingDataDLQ;
+        private Expression m_queryExpression;
+        private string m_dataSourceUri;
+        private DataSetInfo m_dataSetInfo;
+        private bool m_isDynamic;
+        private DryadLinqQueryEnumerable<T> m_tableEnumerable;
+
+        //ctor:
+        // 1. used by IQueryProvider. eg IQueryable<>.Select()
+        // 2. used by lazy queries which come in through here via
+        //       DryadLinqIQueryable.ToPartitionedTableLazy()
+        //         -> DryadLinqProvider.CreateQuery()
+        //
+        internal DryadLinqQuery(DryadLinqProviderBase provider, Expression expression)
+            : base(provider, null)
+        {
+            this.m_queryExpression = expression;
+            this.m_isDynamic = false;
+            this.m_tableEnumerable = null;
+        }
+
+        //ctor:
+        //[ML]: combined from MSR-DL ctors for PartitionedTable<> and DryadLinqQuery<>
+        // This ctor is used by DryadLinqQuery.Get(uri)
+        // Throws DryadLinqException (UnrecognizedDataSource) when dataUri is neither a DSC nor an HDFS URI.
+        internal DryadLinqQuery(Expression queryExpression,
+                                DryadLinqProvider queryProvider,
+                                DataProvider dataProvider,
+                                string dataUri)
+            : base(queryProvider, dataProvider)
+        {
+            if(!DataPath.IsDsc(dataUri) && !DataPath.IsHdfs(dataUri))
+            {
+                throw new DryadLinqException(HpcLinqErrorCode.UnrecognizedDataSource,
+                                             String.Format(SR.UnrecognizedDataSource, dataUri));
+            }
+
+            this.m_queryExpression = queryExpression;
+            this.m_dataSourceUri = DataPath.GetDataPath(dataUri);
+            this.m_isDynamic = false;
+            this.m_tableEnumerable = null;
+        }
+
+        // Copies every field of this query (base-class state included) into otherQuery.
+        internal void CopyTo(DryadLinqQuery<T> otherQuery)
+        {
+            this.CopyToInternal(otherQuery);
+            otherQuery.m_backingDataDLQ = this.m_backingDataDLQ;
+            otherQuery.m_queryExpression = this.m_queryExpression;
+            otherQuery.m_dataSourceUri = this.m_dataSourceUri;
+            otherQuery.m_dataSetInfo = this.m_dataSetInfo;
+            otherQuery.m_isDynamic = this.m_isDynamic;
+            otherQuery.m_tableEnumerable = this.m_tableEnumerable;
+        }
+
+        // returns true for DLQ that are pointing directly at plain data.
+        // Note: plain-data DLQ might also have an executor associated with it.. the data won't be
+        //       available unless the executor completes successfully.
+        internal bool IsPlainData
+        {
+            get { return (this.m_dataSourceUri != null); }
+        }
+
+        // returns true for DLQ that are not themselves pointing directly at plain data, eg query-operators.
+        internal bool IsNormalQuery
+        {
+            get { return (this.m_dataSourceUri == null); }
+        }
+
+        // Records the DLQ that holds this query's results; the value is cast to DryadLinqQuery<T>.
+        internal override DryadLinqQuery BackingDataDLQ
+        {
+            set { m_backingDataDLQ = (DryadLinqQuery<T>)value; }
+        }
+
+        // returns true for a normal query that was executed and now has a backing data DLQ available.
+        internal override bool IsDataBacked
+        {
+            get { return (this.m_backingDataDLQ != null); }
+        }
+
+        // returns true if an executor is associated with the DLQ.
+        // Throws DryadLinqException if an executor is attached to a DLQ that is not plain data.
+        internal bool HasExecutor
+        {
+            get {
+                bool hasExec = (this.QueryExecutor != null);
+                if (hasExec && !IsPlainData)
+                {
+                    throw new DryadLinqException("An executor should only be associated with a DLQ that is plain data");
+                }
+                return hasExec;
+            }
+        }
+
+        public HpcLinqContext Context
+        {
+            get { return m_queryProvider.Context; }
+        }
+
+        #region IQueryable members
+        Type IQueryable.ElementType
+        {
+            get { return typeof(T); }
+        }
+
+        //@@Comment-required: (bit unclear what the intended behavior is for localDebug)
+        //ML: combined from PT and DLQ..
+        IQueryProvider IQueryable.Provider
+        {
+            get
+            {
+                this.CheckAndInitialize();
+                return this.m_queryProvider;
+            }
+        }
+        #endregion
+
+        // Executes a query to a named Dsc URI. The query should _not_ be terminated with ToDsc().
+        // targetUri must be a DSC or HDFS URI, otherwise ArgumentException is thrown.
+        // The first result table becomes this query's backing data and (this) is returned.
+        internal DryadLinqQuery<T> ToTemporaryTable(HpcLinqContext context, string targetUri)
+        {
+            if ((!DataPath.IsDsc(targetUri)) && (!DataPath.IsHdfs(targetUri)))
+            {
+                throw new ArgumentException(String.Format(SR.UnrecognizedDataSource, targetUri));
+            }
+
+#if REMOVE_FOR_YARN
+            if (IsPlainData) // was if (this.m_queryExpression is ConstantExpression)
+            {
+                //@@TODO: I think this is dead code. See if it can be exercised.
+
+                // the input is a Plain-data DLQ.
+                // the output-target has been set
+                // We expect both to be DSC -- so we just use the DSC API to perform a copy rather than invoke dryad.
+                string inputUri = DataSourceUri;
+
+                Debug.Assert(DataPath.IsDsc(inputUri) && DataPath.IsDsc(targetUri), "both uris should be to Dsc");
+
+                using (DscInstance inputService = new DscInstance(new Uri(inputUri)))
+                using (DscInstance outputService = new DscInstance(new Uri(targetUri)))
+                {
+                    DscStream inputStream = inputService.GetStream(new Uri(inputUri));
+                    try
+                    {
+                        DscStream outputStream = outputService.GetStream(new Uri(targetUri));
+                        outputStream.Delete();
+                    }
+                    catch (DscException)
+                    { }
+                    inputStream.Copy(new Uri(targetUri));
+                    ////this.m_table = DryadLinqQuery.Get(tableUri);
+                    ////return this.m_table;
+                    //// [ML] part of deleting this.m_table. We just return (this) rather than (this.m_table)
+
+                    DryadLinqQuery<T> databackedDLQ = DataProvider.GetPartitionedTable<T>(Context, targetUri);
+                    this.m_backingDataDLQ = databackedDLQ; //ML: we set the new table as backing data for the source. (not sure what this gains)
+                    return databackedDLQ;
+                }
+
+            }
+            else if (IsDataBacked)
+            {
+                // @@TODO: I think this is dead code. See if it can be exercised.
+                // if taken, we should be able to just recurse with the backing data (after doing _backingDataDLQ.CheckAndInitialize())
+                throw new NotImplementedException();
+            }
+#endif
+            // Invoke Dryad
+            Debug.Assert(IsNormalQuery, "execution should only occur for a normal query");
+            HpcLinqQueryGen dryadGen = new HpcLinqQueryGen(context, this.GetVertexCodeGen(), this.Expression, targetUri, true);
+            DryadLinqQuery[] resultTables = dryadGen.InvokeDryad();
+            this.m_backingDataDLQ = (DryadLinqQuery<T>) resultTables[0];
+
+            return this;
+        }
+
+        // Generate the query plan as an XML file and return the file name.
+        // provided for test-support. Access via reflection.
+        // returns the queryPlan xml path.
+        internal string ToDryadLinqProgram()
+        {
+            string tableUri = DataPath.DSC_URI_PREFIX + @"dummy/dummy";
+            HpcLinqQueryGen dryadGen = new HpcLinqQueryGen(Context, this.GetVertexCodeGen(), this.m_queryExpression, tableUri, true);
+            return dryadGen.GenerateDryadProgram();
+        }
+
+        // ML: complex ToString is problematic for the debugger -- eg Expression.ToString() leads to infinite recursion
+        // for PlainData which has Expression=ConstantExpression(this) and ConstantExpression(x).ToString() ==> x.ToString()
+        // Also, a rich ToString risks leaking internal details.
+        // @@TODO[P2]: this override should not be necessary.. however the debugger was acting up without it..
+        //             eg timing out when inspecting DryadLinqQuery objects in watch window etc.
+        public override string ToString()
+        {
+            return base.ToString();
+        }
+
+        // One-time setup for a plain-data DLQ: reads the partition count of the data source,
+        // builds the DataSetInfo and the enumerable over the file set, and fixes up the query
+        // expression and provider for LocalDebug or cluster mode. Does nothing for a DLQ that
+        // is not plain data, or once m_tableEnumerable has been set.
+        // Throws DryadLinqException (FailedToGetStreamProps) if the source properties cannot be read.
+        internal void Initialize()
+        {
+            //Detailed initialize behavior is only for plain-old-data.
+            //This was previously implicit (as only defined on the PartitionedData<> type)
+            if (this.IsPlainData)
+            {
+                // short-circuit if already initialized
+                if (this.m_tableEnumerable != null)
+                {
+                    return;
+                }
+
+                Int32 parCount = 0;
+                Int64 estSize = -1;
+                this.m_isDynamic = false;
+
+                try
+                {
+                    // YY: TODO: This could just be set to -1 if the xmlexechost will create the correct number of partitions
+                    // YY: We need the partition count here: it is used in plan optimization.
+                    if (DataPath.IsHdfs((this.m_dataSourceUri)))
+                    {
+                        //hdfs
+                        /*
+                        using (HdfsInstance hdfs = new HdfsInstance(this.m_dataSourceUri))
+                        {
+                            string path = hdfs.FromInternalUri(this.m_dataSourceUri);
+                            HdfsFileInfo dataStream = hdfs.GetFileInfo(path, true);
+                            estSize = (long)dataStream.totalSize;
+                            parCount = (Int32)dataStream.blockArray.Length;
+                        }
+                         */
+                        WebHdfsClient.GetContentSummary(this.m_dataSourceUri, ref estSize, ref parCount);
+
+                    }
+                    else
+                    {
+                        //dsc
+                        using (DscInstance dataService = new DscInstance(new Uri(this.m_dataSourceUri)))
+                        {
+                            DscStream dataStream = dataService.GetStream(new Uri(this.m_dataSourceUri));
+                            estSize = (long)dataStream.Length;
+                            parCount = (Int32)dataStream.PartitionCount;
+
+                        }
+                    }
+                }
+                catch (Exception e)
+                {
+                    throw new DryadLinqException(HpcLinqErrorCode.FailedToGetStreamProps,
+                                                 String.Format(SR.FailedToGetStreamProps, this.m_dataSourceUri), e);
+                }
+
+                // --- start metadata processing --- //
+                // Finally load any stored metadata to check settings, extract compression-setting and initialize the DataInfo for this Query.
+                string streamName = DataPath.GetFilesetNameFromUri(this.m_dataSourceUri); // we converted to uri.. now must go back to stream-name.
+                DryadLinqMetaData meta = null;
+                if (DataPath.IsDsc(this.m_dataSourceUri))
+                {
+                    meta = DryadLinqMetaData.FromDscStream(Context, streamName);
+                }
+                if (meta != null)
+                {
+                    //check the record-type matches meta-data.  (disabled until final API is determined)
+                    //if (meta.ElemType != typeof(T))
+                    //{
+                    //    throw new HpcLinqException(HpcLinqErrorCode.MetadataRecordType,
+                    //                               String.Format(SR.MetadataRecordType, typeof(T), meta.ElemType));
+                    //}
+
+                    //check the serialization flags match meta-data. (disabled as serialization flags are fixed. re-consider when flags become user-settable again.)
+                    //if (StaticConfig.AllowNullFields != meta.AllowNullFields ||
+                    //    StaticConfig.AllowNullArrayElements != meta.AllowNullArrayElements ||
+                    //    StaticConfig.AllowNullRecords != meta.AllowNullRecords)
+                    //{
+                    //    HpcClientSideLog.Add("Warning: Table was generated with AllowNullFields=" + meta.AllowNullFields +
+                    //                           ", AllowNullRecords=" + meta.AllowNullRecords +
+                    //                           ", and AllowNullArrayElements=" + meta.AllowNullArrayElements);
+                    //}
+
+                }
+
+                // Initialize the DataInfo -- currently we always initialize to the "nothing" datainfo.
+                PartitionInfo pinfo = new RandomPartition(parCount);
+                OrderByInfo oinfo = DataSetInfo.NoOrderBy;
+                DistinctInfo dinfo = DataSetInfo.NoDistinct;
+                this.m_dataSetInfo = new DataSetInfo(pinfo, oinfo, dinfo);
+
+                // --- end metadata processing --- //
+
+                // The file-set name is the stream name computed above.
+                this.m_tableEnumerable = new DryadLinqQueryEnumerable<T>(this.Context, streamName);
+
+                //YY: query expression and provider are at least set consistently
+                if (Context.Configuration.LocalDebug)
+                {
+                    // One queryable serves as both the constant expression and the provider source.
+                    var localQueryable = this.m_tableEnumerable.AsQueryable();
+                    this.m_queryExpression = Expression.Constant(localQueryable);
+                    IQueryProvider linqToObjectProvider = localQueryable.Provider; // this should be an instance of "EnumerableQuery"
+                    this.m_queryProvider = new DryadLinqLocalProvider(linqToObjectProvider, Context);
+                }
+                else
+                {
+                    this.m_queryExpression = Expression.Constant(this);
+                    this.m_queryProvider = new DryadLinqProvider(Context);
+                }
+            }
+        }
+
+        internal override Type Type
+        {
+            get { return typeof(T); }
+        }
+
+        // only legal/valid for plainData and data-backed DLQ.
+        internal override string DataSourceUri
+        {
+            get {
+
+                if (this.IsPlainData)
+                {
+                    // no need to CheckAndInitialize() as m_dataSourceUri should already be set.
+                    // also, performing checkAndInitialize causes infinite recursion due to it accessing DataSourceUri
+                    return this.m_dataSourceUri;
+                }
+                else if (this.IsDataBacked)
+                {
+                    // as above, regarding CheckAndInitialize()
+                    return (this.m_backingDataDLQ).m_dataSourceUri;
+                }
+
+                throw new DryadLinqException(HpcLinqErrorCode.OnlyAvailableForPhysicalData,
+                                             SR.OnlyAvailableForPhysicalData);
+            }
+        }
+
+        // combination of old approaches. return either the full expression, or just an expression for the table, if available.
+        // * Fundamental part of IQueryable system.
+        //   Most IQueryable operators will access (source.Expression) and form a new IQueryable
+        //   which is a MethodCall('method',{src.Expression,params})
+        //
+        // Plain data: we create an expression to represent plain-data
+        // Data-backed query: we behave as if the IQueryable were just the backing data (ie a simple expression to plain-data)
+        // Normal query: a normal query node will already have a m_queryExpression
+        public Expression Expression
+        {
+            get
+            {
+                if (this.IsDataBacked)
+                {
+                    // if this is a data-backed-query, (recursively) return the expression for the backingDLQ
+                    Debug.Assert(this.m_backingDataDLQ.IsPlainData, "backing data is expected to always be plain data");
+                    return (this.m_backingDataDLQ).Expression;
+                }
+                this.CheckAndInitialize();
+                return this.m_queryExpression;
+            }
+        }
+
+        // Partition count of the plain data, or of the backing data for a data-backed query.
+        // Throws DryadLinqException (OnlyAvailableForPhysicalData) for any other query.
+        internal override int PartitionCount
+        {
+            get
+            {
+                if (IsPlainData)
+                {
+                    this.CheckAndInitialize();
+                    return this.m_dataSetInfo.partitionInfo.Count;
+                }
+
+                if (IsDataBacked)
+                {
+                    this.m_backingDataDLQ.CheckAndInitialize();
+                    return this.m_backingDataDLQ.PartitionCount;
+                }
+
+                throw new DryadLinqException(HpcLinqErrorCode.OnlyAvailableForPhysicalData,
+                                             SR.OnlyAvailableForPhysicalData);
+            }
+        }
+
+
+
+        internal override bool IsDynamic
+        {
+            get {
+                this.CheckAndInitialize();
+                return this.m_isDynamic;  // possible issue: if(IsDataBacked) then using the value from backing data may be more appropriate.
+            }
+        }
+
+        internal override DataSetInfo DataSetInfo
+        {
+            get
+            {
+                //even if data-backed, the DataSetInfo for the normal-query is the best available.
+                //hence this._m_dataSetInfo is always best.
+
+                this.CheckAndInitialize();
+                return this.m_dataSetInfo;
+            }
+        }
+
+        // If an executor is attached, waits for its job to finish and throws DryadLinqException
+        // when the job failed or was cancelled; then runs Initialize().
+        internal void CheckAndInitialize()
+        {
+            if (HasExecutor)
+            {
+                Debug.Assert(IsPlainData, "We expect a DLQ with an executor to be a plain-data DLQ");
+
+                JobStatus status = this.QueryExecutor.WaitForCompletion();
+                if (status == JobStatus.Failure)
+                {
+                    throw new DryadLinqException(HpcLinqErrorCode.JobToCreateTableFailed,
+                                                 String.Format(SR.JobToCreateTableFailed, this.QueryExecutor.ErrorMsg));
+                }
+                if (status == JobStatus.Cancelled)
+                {
+                    throw new DryadLinqException(HpcLinqErrorCode.JobToCreateTableWasCanceled,
+                                                 SR.JobToCreateTableWasCanceled);
+                }
+                if (status == JobStatus.Success)
+                {
+                    HpcClientSideLog.Add("Table " + this.DataSourceUri + " was created successfully.");
+                }
+            }
+            this.Initialize();
+        }
+
+
+
+        #region IEnumerable and IEnumerable<T> members
+        IEnumerator IEnumerable.GetEnumerator()
+        {
+            return this.GetEnumerator();
+        }
+
+        // combined GetEnumerator from PT and DLQ... use table if present, else start query to generate anonymous output table.
+        public IEnumerator<T> GetEnumerator()
+        {
+            // Process:
+            // 1. if this is plain-data, return an enumerator over the data.
+            // 2. if this is a data-backed-query, return an enumerator over the backing data
+            // 3. otherwise, start an anonymous query execution (which will produce a data-backed-query), and call GetEnumerator() again to hit the second path.
+
+            if (this.IsPlainData)
+            {
+                this.CheckAndInitialize();
+                return this.m_tableEnumerable.GetEnumerator();
+            }
+            else if (this.IsDataBacked)
+            {
+                m_backingDataDLQ.CheckAndInitialize();
+                return m_backingDataDLQ.m_tableEnumerable.GetEnumerator();
+            }
+            else
+            {
+                Debug.Assert(IsNormalQuery);
+
+                // if terminated in ToDsc, eg query.ToDsc("path").GetEnumerator();
+                //     currently: treat this as an error.  We throw in both cluster and LocalDebug modes.
+                //     @@TODO[P2]: we could execute the query, producing data as specified by ToDsc().  (see DryadQueryGen ctors)
+                //
+                // otherwise, we create a temporary stream to hold the data and get an enumerator.
+                //     - cluster mode will run a dryad query.
+                //     - LocalDebug mode will write the data directly into DSC.
+
+                string hdfsPath = DataPath.MakeUniqueTemporaryHdfsFileSetUri(Context);
+                return this.ToTemporaryTable(Context, hdfsPath).GetEnumerator();
+            }
+        }
+        #endregion
+    }
+
+
+    //From PartitionedTableEnumerable
+    // Enumerable over the records of one file set; every GetEnumerator() call opens a fresh reader.
+    internal class DryadLinqQueryEnumerable<T> : IEnumerable<T>, IEnumerable
+    {
+        internal string m_fileSetName;
+        private HpcLinqContext m_context;
+
+        public DryadLinqQueryEnumerable(HpcLinqContext context, string fileSetName)
+        {
+            m_context = context;
+            m_fileSetName = fileSetName;
+        }
+
+        IEnumerator IEnumerable.GetEnumerator()
+        {
+            return this.GetEnumerator();
+        }
+
+        IEnumerator<T> IEnumerable<T>.GetEnumerator()
+        {
+            return this.GetEnumerator();
+        }
+
+        // Looks up the read paths and compression scheme of the file set and returns an
+        // enumerator over its records.
+        // Throws DryadLinqException (FailedToGetReadPathsForStream) if that lookup fails.
+        public IEnumerator<T> GetEnumerator()
+        {
+            List<string[]> filePathList;  // a list of dsc files, each of which is represented by an array holding the replica paths
+            DscCompressionScheme compressionScheme;
+            try
+            {
+                DscFileSet fileSet = m_context.DscService.GetFileSet(m_fileSetName);
+                filePathList = fileSet.GetFiles().Select(file => file.ReadPaths).ToList();
+                // NOTE(review): Initialize() null-checks the result of FromDscStream; here a null
+                // result would surface as FailedToGetReadPathsForStream -- confirm that is intended.
+                DryadLinqMetaData metaData = DryadLinqMetaData.FromDscStream(m_context, m_fileSetName);
+                compressionScheme = metaData.CompressionScheme;
+
+            }
+            catch (Exception e)
+            {
+                throw new DryadLinqException(HpcLinqErrorCode.FailedToGetReadPathsForStream,
+                                             String.Format(SR.FailedToGetReadPathsForStream, this.m_fileSetName), e);
+            }
+
+            return new TableEnumerator(m_context, filePathList, m_fileSetName,compressionScheme);
+        }
+
+        // Internal enumerator class
+        private class TableEnumerator : IEnumerator<T>
+        {
+            private HpcLinqContext m_context;
+            private T m_current;
+            private List<string[]> m_filePathList;        // a list of dsc files, each of which is represented by an array holding the replica paths
+            private string m_associatedDscStreamName;     // stored here only to provide a better exception message in case of IO errors
+            private DscCompressionScheme m_compressionScheme;
+            private HpcLinqFactory<T> m_factory;
+            private HpcRecordReader<T> m_reader;
+
+            internal TableEnumerator(HpcLinqContext context,
+                                     List<string[]> filePathList,
+                                     string associatedDscStreamName,
+                                     DscCompressionScheme scheme)
+            {
+                this.m_context = context;
+                this.m_current = default(T);
+                this.m_filePathList = filePathList;
+                this.m_associatedDscStreamName = associatedDscStreamName;
+                this.m_compressionScheme = scheme;
+                this.m_factory = (HpcLinqFactory<T>)HpcLinqCodeGen.GetFactory(context, typeof(T));
+                this.OpenReader();
+            }
+
+            // Creates a reader positioned at the start of the file set and, when the
+            // configuration enables concurrent user delegates, starts its worker.
+            private void OpenReader()
+            {
+                bool appendNewLinesToFiles = (typeof(T) == typeof(LineRecord));
+                NativeBlockStream nativeStream = new MultiBlockStream(this.m_filePathList, this.m_associatedDscStreamName,
+                                                                      FileAccess.Read, this.m_compressionScheme,
+                                                                      appendNewLinesToFiles);
+                this.m_reader = this.m_factory.MakeReader(nativeStream);
+
+                if (this.m_context.Configuration.AllowConcurrentUserDelegatesInSingleProcess)
+                {
+                    this.m_reader.StartWorker();
+                }
+            }
+
+            public bool MoveNext()
+            {
+                if (m_context.Configuration.AllowConcurrentUserDelegatesInSingleProcess)
+                {
+                    return this.m_reader.ReadRecordAsync(ref this.m_current);
+                }
+                else
+                {
+                    return this.m_reader.ReadRecordSync(ref this.m_current);
+                }
+            }
+
+            object IEnumerator.Current
+            {
+                get { return this.m_current; }
+            }
+
+            public T Current
+            {
+                get { return this.m_current; }
+            }
+
+            // Restarts enumeration from the first record.
+            public void Reset()
+            {
+                this.m_current = default(T);
+                // Close the reader being replaced so it is not leaked.
+                this.m_reader.Close();
+                this.OpenReader();
+            }
+
+            void IDisposable.Dispose()
+            {
+                this.m_reader.Close();
+            }
+        }
+    }
+}
diff --git a/LinqToDryad/DryadLinqSampler.cs b/LinqToDryad/DryadLinqSampler.cs
new file mode 100644
index 0000000..363d8ac
--- /dev/null
+++ b/LinqToDryad/DryadLinqSampler.cs
@@ -0,0 +1,249 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//
+// © Microsoft Corporation. All rights reserved.
/// <summary>
/// Sampling helpers used to choose range-partition separator keys.
/// Phase 1 samples keys from each input; the range sampler then reduces the
/// collected samples and picks (pcount - 1) separators from the sorted result.
/// </summary>
public static class HpcLinqSampler
{
    internal const double SAMPLE_RATE = 0.001;
    private const int MAX_SECOND_PHASE_SAMPLES = 1024*1024;

    /// <summary>
    /// Emits a random sample (rate SAMPLE_RATE) of the keys of <paramref name="source"/>.
    /// If fewer than 10 keys were sampled by the time the source is exhausted,
    /// every key is emitted instead.
    /// </summary>
    /// <param name="source">The records to sample.</param>
    /// <param name="keySelector">Extracts the key from a record.</param>
    /// <param name="denv">The vertex environment; its vertex id seeds the random generator.</param>
    /// <returns>The sampled keys.</returns>
    [Resource(IsStateful=false)]
    public static IEnumerable<K> Phase1Sampling<T, K>(IEnumerable<T> source,
                                                      Func<T, K> keySelector,
                                                      HpcLinqVertexEnv denv)
    {
        // note: vertexID is constant for each repetition of a specific vertex (eg in fail-and-retry scenarios)
        //       this is very good as it ensures the sampling is idempotent w.r.t. retries.

        long vertexID = HpcLinqNative.GetVertexId(denv.NativeHandle);
        int seed = unchecked((int)(vertexID));
        long nEmitted = 0;

        Random rdm = new Random(seed);

        List<K> allSoFar = new List<K>();
        List<K> samples = new List<K>();

        // The enumerator is disposed when iteration completes or the consumer stops early.
        using (IEnumerator<T> sourceEnumerator = source.GetEnumerator())
        {
            // try to collect 10 samples, but keep all the records just in case
            while (sourceEnumerator.MoveNext())
            {
                T elem = sourceEnumerator.Current;
                K key = keySelector(elem);
                allSoFar.Add(key);
                if (rdm.NextDouble() < SAMPLE_RATE)
                {
                    samples.Add(key);
                    if (samples.Count >= 10)
                        break;
                }
            }

            if (samples.Count >= 10)
            {
                // we have lots of samples.. emit them and continue sampling
                allSoFar = null; // not needed.
                foreach (K key in samples)
                {
                    yield return key;
                    nEmitted++;
                }
                while (sourceEnumerator.MoveNext())
                {
                    T elem = sourceEnumerator.Current;
                    if (rdm.NextDouble() < SAMPLE_RATE)
                    {
                        yield return keySelector(elem);
                        nEmitted++;
                    }
                }
            }
            else
            {
                // sampling didn't produce much, so emit all the records instead.
                DryadLinqLog.Add("Sampling produced only {0} records. Emitting all records instead.", samples.Count);
                Debug.Assert(sourceEnumerator.MoveNext() == false, "The source enumerator wasn't finished");
                samples = null; // the samples list is not needed.
                foreach (K key in allSoFar)
                {
                    yield return key;
                    nEmitted++;
                }
            }
        }

        DryadLinqLog.Add("Stage1 sampling: num keys emitted = {0}", nEmitted);
    }

    //------------------------------------
    //Range-sampler
    // 1. Secondary sampling
    // 2. sort, and select separator values.

    //This method is only used for dynamic inputs. Not required in RTM
    //public static IEnumerable<K> RangeSampler_Dynamic<T, K>(IEnumerable<T> source,
    //                                                        Func<T, K> keySelector,
    //                                                        IComparer<K> comparer,
    //                                                        bool isDescending,
    //                                                        HpcLinqVertexEnv denv)
    //{
    //    if (denv.NumberOfArguments < 2)
    //    {
    //        throw new HpcLinqException(SR.Sampler_NotEnoughArgumentsForVertex);
    //    }
    //    Int32 pcount = Int32.Parse(denv.GetArgument(denv.NumberOfArguments-1));
    //    return RangeSamplerCore(source, keySelector, comparer, isDescending, pcount);
    //}

    // used for static plan (ie pcount is determined on client-side and baked into vertex code)
    public static IEnumerable<K>
        RangeSampler_Static<K>(IEnumerable<K> firstPhaseSamples,
                               IComparer<K> comparer,
                               bool isDescending,
                               int pcount)
    {
        return RangeSamplerCore(firstPhaseSamples, comparer, isDescending, pcount);
    }

    /// <summary>
    /// Reduces the first-phase samples to at most MAX_SECOND_PHASE_SAMPLES keys by
    /// reservoir sampling, sorts them with <paramref name="comparer"/>, and yields the
    /// separator keys. Nothing is yielded when there are no samples. With fewer samples
    /// than partitions, every sample is yielded and the output is padded with the
    /// extreme key up to (pcount - 1) entries.
    /// </summary>
    /// <param name="firstPhaseSamples">Keys produced by the first sampling phase.</param>
    /// <param name="comparer">Orders the keys (ascending).</param>
    /// <param name="isDescending">True to emit separators from largest to smallest.</param>
    /// <param name="pcount">The number of partitions.</param>
    public static IEnumerable<K>
        RangeSamplerCore<K>(IEnumerable<K> firstPhaseSamples,
                            IComparer<K> comparer,
                            bool isDescending,
                            int pcount)
    {
        //Reservoir sampling to produce at most MAX_SECOND_PHASE_SAMPLES records.
        K[] samples = new K[MAX_SECOND_PHASE_SAMPLES];
        int inputCount = 0;      // number of keys consumed before the current one
        int reservoirCount = 0;

        // fixed-seed is ok here as second-phase-sampler is a singleton vertex.  Idempotency is important.
        Random r = new Random(314159);

        foreach (K key in firstPhaseSamples) // this completely enumerates each source in turn.
        {
            if (reservoirCount < MAX_SECOND_PHASE_SAMPLES)
            {
                samples[reservoirCount] = key;
                reservoirCount++;
            }
            else
            {
                // The current key is item number (inputCount + 1), so the slot must be drawn
                // from 0..inputCount inclusive. Random.Next's upper bound is exclusive; drawing
                // from Next(inputCount) made the first overflow key always evict a sample.
                int idx = r.Next(inputCount + 1);
                if (idx < MAX_SECOND_PHASE_SAMPLES)
                {
                    samples[idx] = key;
                }
            }
            inputCount++;
        }

        // Sort and Emit the keys
        Array.Sort(samples, 0, reservoirCount, comparer);

        DryadLinqLog.Add("Range-partition separator keys: ");
        DryadLinqLog.Add("samples: {0}", reservoirCount);
        DryadLinqLog.Add("pCount:  {0}", pcount);

        if (reservoirCount == 0)
        {
            //DryadLinqLog.Add("  case: cnt==0.  No separators produced.");
            yield break;
        }

        if (reservoirCount < pcount)
        {
            //DryadLinqLog.Add("  case: cnt < pcount");
            if (isDescending)
            {
                //DryadLinqLog.Add("  case: isDescending=true");
                for (int i = reservoirCount - 1; i >= 0; i--)
                {
                    //DryadLinqLog.Add("  [{0}]", samples[i]);
                    yield return samples[i];
                }
                // pad with the smallest key
                K first = samples[0];
                for (int i = reservoirCount; i < pcount - 1; i++)
                {
                    //DryadLinqLog.Add("  [{0}]", first);
                    yield return first;
                }
            }
            else
            {
                //DryadLinqLog.Add("  case: isDescending=false");
                for (int i = 0; i < reservoirCount; i++)
                {
                    //DryadLinqLog.Add("  [{0}]", samples[i]);
                    yield return samples[i];
                }
                // pad with the largest key
                K last = samples[reservoirCount - 1];
                for (int i = reservoirCount; i < pcount - 1; i++)
                {
                    //DryadLinqLog.Add("  [{0}]", last);
                    yield return last;
                }
            }
        }
        else
        {
            //DryadLinqLog.Add("  case: cnt >= pcount");
            int intv = reservoirCount / pcount;
            if (isDescending)
            {
                //DryadLinqLog.Add("  case: isDescending=true");
                int idx = reservoirCount - intv;
                for (int i = 0; i < pcount-1; i++)
                {
                    //DryadLinqLog.Add("  [{0}]", samples[idx]);
                    yield return samples[idx];
                    idx -= intv;
                }
            }
            else
            {
                //DryadLinqLog.Add("  case: isDescending=false");
                int idx = intv;
                for (int i = 0; i < pcount-1; i++)
                {
                    //DryadLinqLog.Add("  [{0}]", samples[idx]);
                    yield return samples[idx];
                    idx += intv;
                }
            }
        }
    }
}
// A workaround to deal with some limitation of C# generics.
/// <summary>
/// Read/Write helpers for the standard collection types whose elements are
/// serialized by <typeparamref name="S"/>. Every collection is written as an
/// element count followed by the elements.
/// </summary>
/// <typeparam name="T">The element type.</typeparam>
/// <typeparam name="S">The serializer used for each element.</typeparam>
public static class HpcSerialization<T, S>
    where S : HpcSerializer<T>, new()
{
    private static S serializer = new S();

    public static void Read(HpcBinaryReader reader, out List<T> list)
    {
        int cnt = reader.ReadInt32();
        list = new List<T>(cnt);
        for (int i = 0; i < cnt; i++)
        {
            list.Add(serializer.Read(reader));
        }
    }

    public static void Write(HpcBinaryWriter writer, List<T> list)
    {
        writer.Write(list.Count);
        foreach (T elem in list)
        {
            serializer.Write(writer, elem);
        }
    }

    public static void Read(HpcBinaryReader reader, out LinkedList<T> list)
    {
        int cnt = reader.ReadInt32();
        list = new LinkedList<T>();
        for (int i = 0; i < cnt; i++)
        {
            list.AddLast(serializer.Read(reader));
        }
    }

    public static void Write(HpcBinaryWriter writer, LinkedList<T> list)
    {
        writer.Write(list.Count);
        foreach (T elem in list)
        {
            serializer.Write(writer, elem);
        }
    }

    public static void Read(HpcBinaryReader reader, out Queue<T> queue)
    {
        int cnt = reader.ReadInt32();
        queue = new Queue<T>(cnt);
        for (int i = 0; i < cnt; i++)
        {
            queue.Enqueue(serializer.Read(reader));
        }
    }

    public static void Write(HpcBinaryWriter writer, Queue<T> queue)
    {
        writer.Write(queue.Count);
        foreach (T elem in queue)
        {
            serializer.Write(writer, elem);
        }
    }

    public static void Read(HpcBinaryReader reader, out Stack<T> stack)
    {
        // Write emits the elements top-first (Stack<T> enumerates in LIFO order).
        // Pushing them in arrival order would rebuild the stack upside down, so
        // buffer them and push from the bottom element up.
        int cnt = reader.ReadInt32();
        T[] topFirst = new T[cnt];
        for (int i = 0; i < cnt; i++)
        {
            topFirst[i] = serializer.Read(reader);
        }
        stack = new Stack<T>(cnt);
        for (int i = cnt - 1; i >= 0; i--)
        {
            stack.Push(topFirst[i]);
        }
    }

    public static void Write(HpcBinaryWriter writer, Stack<T> stack)
    {
        writer.Write(stack.Count);
        foreach (T elem in stack)
        {
            serializer.Write(writer, elem);
        }
    }

    public static void Read(HpcBinaryReader reader, out HashSet<T> set)
    {
        int cnt = reader.ReadInt32();
        set = new HashSet<T>();
        for (int i = 0; i < cnt; i++)
        {
            set.Add(serializer.Read(reader));
        }
    }

    public static void Write(HpcBinaryWriter writer, HashSet<T> set)
    {
        writer.Write(set.Count);
        foreach (T elem in set)
        {
            serializer.Write(writer, elem);
        }
    }

    public static void Read(HpcBinaryReader reader, out Collection<T> set)
    {
        int cnt = reader.ReadInt32();
        set = new Collection<T>();
        for (int i = 0; i < cnt; i++)
        {
            set.Add(serializer.Read(reader));
        }
    }

    public static void Write(HpcBinaryWriter writer, Collection<T> set)
    {
        writer.Write(set.Count);
        foreach (T elem in set)
        {
            serializer.Write(writer, elem);
        }
    }

    public static void Read(HpcBinaryReader reader, out ReadOnlyCollection<T> set)
    {
        int cnt = reader.ReadInt32();
        List<T> lst = new List<T>(cnt);
        for (int i = 0; i < cnt; i++)
        {
            lst.Add(serializer.Read(reader));
        }
        set = new ReadOnlyCollection<T>(lst);
    }

    public static void Write(HpcBinaryWriter writer, ReadOnlyCollection<T> set)
    {
        writer.Write(set.Count);
        foreach (T elem in set)
        {
            serializer.Write(writer, elem);
        }
    }

    public static void Read(HpcBinaryReader reader, out IEnumerable<T> seq)
    {
        int cnt = reader.ReadInt32();
        T[] elems = new T[cnt];
        for (int i = 0; i < cnt; i++)
        {
            elems[i] = serializer.Read(reader);
        }
        seq = new HpcLinqSequence<T>(elems);
    }

    public static void Write(HpcBinaryWriter writer, IEnumerable<T> seq)
    {
        // Take the count and the elements from a single pass: calling Count() and then
        // enumerating again walks a lazy sequence twice, and the written count can
        // disagree with the elements if the two passes differ.
        ICollection<T> items = seq as ICollection<T> ?? seq.ToList();
        writer.Write(items.Count);
        foreach (T elem in items)
        {
            serializer.Write(writer, elem);
        }
    }

    public static void Read(HpcBinaryReader reader, out IList<T> seq)
    {
        int cnt = reader.ReadInt32();
        seq = new List<T>(cnt);
        for (int i = 0; i < cnt; i++)
        {
            seq.Add(serializer.Read(reader));
        }
    }

    public static void Write(HpcBinaryWriter writer, IList<T> seq)
    {
        writer.Write(seq.Count);
        foreach (T elem in seq)
        {
            serializer.Write(writer, elem);
        }
    }

    // ForkValue: a presence flag, followed by the value only when present.
    public static void Read(HpcBinaryReader reader, out ForkValue<T> val)
    {
        val = new ForkValue<T>();
        if (reader.ReadBool())
        {
            val.Value = serializer.Read(reader);
        }
    }

    public static void Write(HpcBinaryWriter writer, ForkValue<T> val)
    {
        writer.Write(val.HasValue);
        if (val.HasValue)
        {
            serializer.Write(writer, val.Value);
        }
    }

    // AggregateValue: a 64-bit count, followed by the value only when count > 0.
    public static void Read(HpcBinaryReader reader, out AggregateValue<T> aggVal)
    {
        long cnt = reader.ReadInt64();
        T val = default(T);
        if (cnt > 0)
        {
            val = serializer.Read(reader);
        }
        aggVal = new AggregateValue<T>(val, cnt);
    }

    public static void Write(HpcBinaryWriter writer, AggregateValue<T> aggVal)
    {
        writer.Write(aggVal.Count);
        if (aggVal.Count > 0)
        {
            serializer.Write(writer, aggVal.Value);
        }
    }

    // IndexedValue: a 32-bit index followed by the value.
    public static void Read(HpcBinaryReader reader, out IndexedValue<T> indexedVal)
    {
        int index = reader.ReadInt32();
        T val = serializer.Read(reader);
        indexedVal = new IndexedValue<T>(index, val);
    }

    public static void Write(HpcBinaryWriter writer, IndexedValue<T> indexedVal)
    {
        writer.Write(indexedVal.Index);
        serializer.Write(writer, indexedVal.Value);
    }
}
serializer1.Read(reader); + T2 value = serializer2.Read(reader); + list.Add(key, value); + } + } + + public static void Write(HpcBinaryWriter writer, SortedList list) + { + writer.Write(list.Count); + foreach (KeyValuePair elem in list) + { + serializer1.Write(writer, elem.Key); + serializer2.Write(writer, elem.Value); + } + } + + public static void Read(HpcBinaryReader reader, out IGrouping group) + { + T1 key = serializer1.Read(reader); + int len = reader.ReadInt32(); + Grouping realGroup = new Grouping(key, len); + + for (int i = 0; i < len; i++) + { + realGroup.AddItem(serializer2.Read(reader)); + } + group = realGroup; + } + + public static void Write(HpcBinaryWriter writer, IGrouping group) + { + serializer1.Write(writer, group.Key); + writer.Write(group.Count()); + + foreach (T2 elem in group) + { + serializer2.Write(writer, elem); + } + } + + public static void Read(HpcBinaryReader reader, out KeyValuePair kv) + { + T1 key = serializer1.Read(reader); + T2 val = serializer2.Read(reader); + kv = new KeyValuePair(key, val); + } + + public static void Write(HpcBinaryWriter writer, KeyValuePair kv) + { + serializer1.Write(writer, kv.Key); + serializer2.Write(writer, kv.Value); + } + + public static void Read(HpcBinaryReader reader, out Pair pair) + { + T1 x = serializer1.Read(reader); + T2 y = serializer2.Read(reader); + pair = new Pair(x, y); + } + + public static void Write(HpcBinaryWriter writer, Pair pair) + { + serializer1.Write(writer, pair.Key); + serializer2.Write(writer, pair.Value); + } + + public static void Read(HpcBinaryReader reader, out ForkTuple val) + { + val = new ForkTuple(); + if (reader.ReadBool()) + { + val.First = serializer1.Read(reader); + } + if (reader.ReadBool()) + { + val.Second = serializer2.Read(reader); + } + } + + public static void Write(HpcBinaryWriter writer, ForkTuple val) + { + writer.Write(val.HasFirst); + if (val.HasFirst) + { + serializer1.Write(writer, val.First); + } + + writer.Write(val.HasSecond); + if (val.HasSecond) 
+ { + serializer2.Write(writer, val.Second); + } + } + + public static void Read(HpcBinaryReader reader, out HpcLinqGrouping group) + { + T1 key = serializer1.Read(reader); + int cnt = reader.ReadInt32(); + T2[] elems = new T2[cnt]; + for (int i = 0; i < cnt; i++) + { + elems[i] = serializer2.Read(reader); + } + group = new HpcLinqGrouping(key, elems); + } + + public static void Write(HpcBinaryWriter writer, HpcLinqGrouping group) + { + serializer1.Write(writer, group.Key); + writer.Write(group.Count()); + foreach (T2 elem in group) + { + serializer2.Write(writer, elem); + } + } + } + + public sealed class ByteHpcSerializer : HpcSerializer + { + public override byte Read(HpcBinaryReader reader) + { + return reader.ReadUByte(); + } + + public override void Write(HpcBinaryWriter writer, byte x) + { + writer.Write(x); + } + } + + public sealed class SByteHpcSerializer : HpcSerializer + { + public override sbyte Read(HpcBinaryReader reader) + { + return reader.ReadSByte(); + } + + public override void Write(HpcBinaryWriter writer, sbyte x) + { + writer.Write(x); + } + } + + public sealed class BoolHpcSerializer : HpcSerializer + { + public override bool Read(HpcBinaryReader reader) + { + return reader.ReadBool(); + } + + public override void Write(HpcBinaryWriter writer, bool x) + { + writer.Write(x); + } + } + + public sealed class CharHpcSerializer : HpcSerializer + { + public override char Read(HpcBinaryReader reader) + { + return reader.ReadChar(); + } + + public override void Write(HpcBinaryWriter writer, char x) + { + writer.Write(x); + } + } + + public sealed class Int16HpcSerializer : HpcSerializer + { + public override Int16 Read(HpcBinaryReader reader) + { + return reader.ReadInt16(); + } + + public override void Write(HpcBinaryWriter writer, Int16 x) + { + writer.Write(x); + } + } + + public sealed class UInt16HpcSerializer : HpcSerializer + { + public override UInt16 Read(HpcBinaryReader reader) + { + return reader.ReadUInt16(); + } + + public override void 
Write(HpcBinaryWriter writer, UInt16 x) + { + writer.Write(x); + } + } + + public sealed class Int32HpcSerializer : HpcSerializer + { + public override Int32 Read(HpcBinaryReader reader) + { + return reader.ReadInt32(); + } + + public override void Write(HpcBinaryWriter writer, Int32 x) + { + writer.Write(x); + } + } + + public sealed class UInt32HpcSerializer : HpcSerializer + { + public override UInt32 Read(HpcBinaryReader reader) + { + return reader.ReadUInt32(); + } + + public override void Write(HpcBinaryWriter writer, UInt32 x) + { + writer.Write(x); + } + } + + public sealed class Int64HpcSerializer : HpcSerializer + { + public override Int64 Read(HpcBinaryReader reader) + { + return reader.ReadInt64(); + } + + public override void Write(HpcBinaryWriter writer, Int64 x) + { + writer.Write(x); + } + } + + public sealed class UInt64HpcSerializer : HpcSerializer + { + public override UInt64 Read(HpcBinaryReader reader) + { + return reader.ReadUInt64(); + } + + public override void Write(HpcBinaryWriter writer, UInt64 x) + { + writer.Write(x); + } + } + + public sealed class SingleHpcSerializer : HpcSerializer + { + public override float Read(HpcBinaryReader reader) + { + return reader.ReadSingle(); + } + + public override void Write(HpcBinaryWriter writer, float x) + { + writer.Write(x); + } + } + + public sealed class DoubleHpcSerializer : HpcSerializer + { + public override double Read(HpcBinaryReader reader) + { + return reader.ReadDouble(); + } + + public override void Write(HpcBinaryWriter writer, double x) + { + writer.Write(x); + } + } + + public sealed class DecimalHpcSerializer : HpcSerializer + { + public override decimal Read(HpcBinaryReader reader) + { + return reader.ReadDecimal(); + } + + public override void Write(HpcBinaryWriter writer, decimal x) + { + writer.Write(x); + } + } + + public sealed class DateTimeHpcSerializer : HpcSerializer + { + public override DateTime Read(HpcBinaryReader reader) + { + return reader.ReadDateTime(); + } + + 
public override void Write(HpcBinaryWriter writer, DateTime x) + { + writer.Write(x); + } + } + + public sealed class StringHpcSerializer : HpcSerializer + { + public override string Read(HpcBinaryReader reader) + { + return reader.ReadString(); + } + + public override void Write(HpcBinaryWriter writer, string x) + { + writer.Write(x); + } + } + + public sealed class GuidHpcSerializer : HpcSerializer + { + public override Guid Read(HpcBinaryReader reader) + { + return reader.ReadGuid(); + } + + public override void Write(HpcBinaryWriter writer, Guid x) + { + writer.Write(x); + } + } + + + public sealed class SqlDateTimeHpcSerializer : HpcSerializer + { + public override SqlDateTime Read(HpcBinaryReader reader) + { + return reader.ReadSqlDateTime(); + } + + public override void Write(HpcBinaryWriter writer, SqlDateTime x) + { + writer.Write(x); + } + } +} diff --git a/LinqToDryad/DryadLinqStream.cs b/LinqToDryad/DryadLinqStream.cs new file mode 100644 index 0000000..6408c7a --- /dev/null +++ b/LinqToDryad/DryadLinqStream.cs @@ -0,0 +1,151 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. 
/// <summary>
/// A read-only, forward-only stream that presents several HpcBinaryReader inputs
/// as one continuous byte stream, consuming them in array order.
/// Seeking, writing and Position are not supported.
/// </summary>
internal class HpcLinqMultiInputStream : Stream
{
    private HpcBinaryReader[] m_inputStreamArray;
    private HpcBinaryReader m_curStream;
    private Int32 m_nextStreamIdx;
    private byte[] m_oneByte = new byte[1];   // scratch buffer for ReadByte

    /// <param name="streamArray">The inputs to concatenate; must contain at least one reader.</param>
    public HpcLinqMultiInputStream(HpcBinaryReader[] streamArray)
    {
        if (streamArray == null)
        {
            throw new ArgumentNullException("streamArray");
        }
        if (streamArray.Length == 0)
        {
            throw new ArgumentException("At least one input stream is required.", "streamArray");
        }
        this.m_inputStreamArray = streamArray;
        this.m_curStream = streamArray[0];
        this.m_nextStreamIdx = 1;
    }

    ~HpcLinqMultiInputStream()
    {
        this.Close();
    }

    public override bool CanRead
    {
        get { return true; }
    }

    public override bool CanWrite
    {
        get { return false; }
    }

    public override bool CanSeek
    {
        get { return false; }
    }

    // The combined length of all inputs.
    public override long Length
    {
        get {
            long len = 0;
            for (int i = 0; i < this.m_inputStreamArray.Length; i++)
            {
                len += this.m_inputStreamArray[i].Length;
            }
            return len;
        }
    }

    public override long Position
    {
        get { throw new DryadLinqException(HpcLinqErrorCode.PositionNotSupported,
                                           SR.PositionNotSupported); }
        set { throw new DryadLinqException(HpcLinqErrorCode.PositionNotSupported,
                                           SR.PositionNotSupported); }
    }

    // Closes every underlying reader, then lets the base class finish disposal.
    protected override void Dispose(bool disposing)
    {
        try
        {
            foreach (HpcBinaryReader s in this.m_inputStreamArray)
            {
                s.Close();
            }
        }
        finally
        {
            base.Dispose(disposing);
        }
    }

    public override void Flush()
    {
    }

    /// <summary>
    /// Reads up to <paramref name="count"/> bytes from the current input, moving on to
    /// the next input whenever the current one returns no data.
    /// </summary>
    /// <returns>The number of bytes read, or 0 once every input is exhausted.</returns>
    public override int Read(byte[] buffer, int offset, int count)
    {
        // A zero-length request also yields 0 from ReadBytes; without this guard it
        // would be mistaken for end-of-input and skip past all remaining streams.
        if (count == 0) return 0;

        while (true)
        {
            int n = this.m_curStream.ReadBytes(buffer, offset, count);
            if (n != 0) return n;
            if (this.m_nextStreamIdx == this.m_inputStreamArray.Length) return 0;
            this.m_curStream = this.m_inputStreamArray[this.m_nextStreamIdx++];
        }
    }

    /// <summary>
    /// Reads one byte, crossing input boundaries as needed.
    /// </summary>
    /// <returns>The byte as a non-negative int, or -1 once every input is exhausted.</returns>
    public override int ReadByte()
    {
        // ReadUByte returns a byte, so its result can never equal -1 and cannot signal
        // the end of an input. Go through Read, which detects exhaustion by byte count.
        if (this.Read(this.m_oneByte, 0, 1) == 0) return -1;
        return this.m_oneByte[0];
    }

    public override long Seek(long offset, SeekOrigin origin)
    {
        throw new DryadLinqException(HpcLinqErrorCode.SeekNotSupported,
                                     SR.SeekNotSupported);
    }

    public override void SetLength(long value)
    {
        throw new DryadLinqException(HpcLinqErrorCode.SetLengthNotSupported,
                                     SR.SetLengthNotSupported);
    }

    public override void Write(byte[] buffer, int offset, int count)
    {
        throw new DryadLinqException(HpcLinqErrorCode.WriteNotSupported,
                                     SR.WriteNotSupported);
    }

    public override void WriteByte(byte value)
    {
        throw new DryadLinqException(HpcLinqErrorCode.WriteByteNotSupported,
                                     SR.WriteByteNotSupported);
    }
}
/// <summary>
/// Checks whether the array is ordered (ties are allowed).
/// </summary>
/// <param name="array">The elements to check.</param>
/// <param name="comparer">The comparer; passed through TypeSystem.GetComparer before use.</param>
/// <param name="isDescending">True to check for descending order, false for ascending.</param>
/// <returns>True if no adjacent pair is out of order.</returns>
internal static bool IsOrdered<T>(T[] array, IComparer<T> comparer, bool isDescending)
{
    comparer = TypeSystem.GetComparer<T>(comparer);
    if (array.Length < 2) return true;
    T elem = array[0];
    for (int i = 1; i < array.Length; i++)
    {
        int cmp = comparer.Compare(elem, array[i]);
        // Test the sign directly instead of negating: -int.MinValue is still negative,
        // so a comparer returning int.MinValue would be misread in descending mode.
        bool outOfOrder = isDescending ? (cmp < 0) : (cmp > 0);
        if (outOfOrder) return false;
        elem = array[i];
    }
    return true;
}
/// <summary>
/// Binary search. The array must be ordered. Assume that comparer != null
/// The input is an array {a_i} of n items and a value to search for.
/// The output is a value in the range [0..n].
/// The output can be directly used as the port-number to emit the item to for range-distribution.
/// </summary>
/// <typeparam name="T">The type of the elements in the array.</typeparam>
/// <param name="array">The elements.</param>
/// <param name="value">The value to be found.</param>
/// <param name="comparer">A comparer to compare the elements of type T.</param>
/// <param name="isDescending">Whether the array is ordered descending (true) or ascending (false).</param>
/// <returns>The index of the left-most element that is greater-than-or-equal to value
/// (less-than-or-equal for descending), or n if there is no such element.</returns>

// retVal is the idx of the item which satisfies
//
// Ascending:  arr[idx-1] < x <= arr[idx]     arr[-1] = -inf   arr[n] = +inf
//             idx is the index of the left-most element which is greater-or-equal than x
//             The smallest retVal is 0
//             The largest retVal is n
//
// Descending: arr[idx-1] > x >= arr[idx]     arr[-1] = +inf   arr[n] = -inf
//             idx is the index of the left-most element which is less-or-equal than x.
//             The smallest retVal is 0
//             The largest retVal is n
//
// Although this isn't a strictly 'symmetrical rule', the "always left-most" is simple and consistent.

internal static int BinarySearch<T>(T[] array, T value, IComparer<T> comparer, bool isDescending)
{
    int lo = 0;
    int hi = array.Length - 1;
    while (lo <= hi)
    {
        int i = lo + ((hi - lo) >> 1);
        int cmp = comparer.Compare(array[i], value);

        if (cmp == 0)
        {
            // ensure we are on the left-most matching item.
            // Note: linear-search isn't ideal if there are many equal values, but that should not be common.
            while (i > 0 && comparer.Compare(array[i-1], value) == 0)
                i--;
        }

        // array[i] sorts strictly before value, so the answer lies to the right.
        // The sign is tested directly rather than negating cmp, because -int.MinValue
        // is still negative and would send a descending search the wrong way.
        bool searchRight = isDescending ? (cmp > 0) : (cmp < 0);
        if (searchRight)
        {
            lo = i + 1;
        }
        else
        {
            hi = i - 1;
        }
    }

    return lo;
}
/// <summary>
/// Mergesort a list of sorted datasets.
/// </summary>
/// <typeparam name="TSource">The type of the records of the input datasets</typeparam>
/// <typeparam name="TKey">The type of the keys to sort</typeparam>
/// <param name="sources">An array of input datasets</param>
/// <param name="keySelector">The key extraction function</param>
/// <param name="comparer">A Comparer on TKey to compare keys</param>
/// <param name="isDescending">True if the mergesort is descending</param>
/// <returns>The mergesort result; an empty sequence when there are no inputs</returns>
internal static IEnumerable<TSource>
    MergeSort<TSource, TKey>(IEnumerable<TSource>[] sources,
                             Func<TSource, TKey> keySelector,
                             IComparer<TKey> comparer,
                             bool isDescending)
{
    // With zero inputs the layer count starts at 0 and never reaches 1,
    // so the pairing loop below would spin forever.
    if (sources.Length == 0)
    {
        return Enumerable.Empty<TSource>();
    }

    comparer = TypeSystem.GetComparer<TKey>(comparer);
    IEnumerable<TSource>[] currentLayer = sources;
    int currentLayerCount = sources.Length;
    IEnumerable<TSource>[] nextLayer = new IEnumerable<TSource>[currentLayerCount / 2 + 1];

    // Merge neighbours pairwise, halving the number of sequences per round.
    // After the first round currentLayer and nextLayer are the same array; that is
    // safe because slot i is written only after slots 2i and 2i+1 have been read.
    while (currentLayerCount != 1)
    {
        int nextLayerCount = currentLayerCount / 2;
        int idx = 0;
        for (int i = 0; i < nextLayerCount; i++)
        {
            nextLayer[i] = BinaryMergeSort(currentLayer[idx],
                                           currentLayer[idx + 1],
                                           keySelector,
                                           comparer,
                                           isDescending);
            idx += 2;
        }
        if (idx < currentLayerCount)
        {
            // odd one out: carry it over to the next round unmerged
            nextLayer[nextLayerCount] = currentLayer[idx];
            nextLayerCount++;
        }
        currentLayer = nextLayer;
        currentLayerCount = nextLayerCount;
    }
    return currentLayer[0];
}
///
        /// <typeparam name="TSource">The type of the records of the datasets</typeparam>
        /// <typeparam name="TKey">The type of the keys to sort</typeparam>
        /// <param name="source1">The first input dataset</param>
        /// <param name="source2">The second input dataset</param>
        /// <param name="keySelector">The key extraction function</param>
        /// <param name="comparer">A Comparer on TKey to compare keys</param>
        /// <param name="isDescending">True if the sort is descending</param>
        /// <returns>The mergesort result</returns>
        internal static IEnumerable<TSource>
            BinaryMergeSort<TSource, TKey>(IEnumerable<TSource> source1,
                                           IEnumerable<TSource> source2,
                                           Func<TSource, TKey> keySelector,
                                           IComparer<TKey> comparer,
                                           bool isDescending)
        {
            comparer = TypeSystem.GetComparer(comparer);

            // Both enumerators are released when the merge completes or when the
            // consumer abandons the iteration early.
            using (IEnumerator<TSource> leftElems = source1.GetEnumerator())
            using (IEnumerator<TSource> rightElems = source2.GetEnumerator())
            {
                // hasLeft / hasRight mean "Current holds an element not yet yielded".
                bool hasLeft = leftElems.MoveNext();
                bool hasRight = rightElems.MoveNext();

                if (hasLeft && hasRight)
                {
                    TKey leftKey = keySelector(leftElems.Current);
                    TKey rightKey = keySelector(rightElems.Current);
                    while (true)
                    {
                        int cmp = comparer.Compare(leftKey, rightKey);
                        int cmpRes = (isDescending) ? -cmp : cmp;
                        if (cmpRes > 0)
                        {
                            yield return rightElems.Current;
                            hasRight = rightElems.MoveNext();
                            if (!hasRight) break;
                            rightKey = keySelector(rightElems.Current);
                        }
                        else
                        {
                            // Ties are resolved in favour of the left (first) dataset.
                            yield return leftElems.Current;
                            hasLeft = leftElems.MoveNext();
                            if (!hasLeft) break;
                            leftKey = keySelector(leftElems.Current);
                        }
                    }
                }

                // At most one side still has elements. Its Current has not been yielded
                // yet, so emit before advancing. This also covers the case where one
                // input was empty from the start, which previously lost the first
                // element of source1 when source2 was empty.
                while (hasLeft)
                {
                    yield return leftElems.Current;
                    hasLeft = leftElems.MoveNext();
                }
                while (hasRight)
                {
                    yield return rightElems.Current;
                    hasRight = rightElems.MoveNext();
                }
            }
        }

        // Swap the bytes: reverses the byte order of a 64-bit value
        // (byte 0 <-> byte 7, 1 <-> 6, 2 <-> 5, 3 <-> 4).
        internal static UInt64 ByteSwap(UInt64 x)
        {
            return (x << 56)
                   | (x >> 56)
                   | ((x & 0x0000ff00UL) << 40)
                   | ((x >> 40) & 0x0000ff00UL)
                   | ((x & 0x00ff0000UL) << 24)
                   | ((x >> 24) & 0x00ff0000UL)
                   | ((x & 0xff000000UL) << 8)
                   | ((x >> 8) & 0xff000000UL);
        }

        // Serializes an object graph to a byte array with BinaryFormatter.
        // A SerializationException is logged to the client-side log and rethrown.
        private static byte[] ObjectToByteArray(Object objectToSerialize)
        {
            using (MemoryStream fs = new MemoryStream())
            {
                BinaryFormatter formatter = new BinaryFormatter();
                try
                {
                    formatter.Serialize(fs, objectToSerialize);
                    return fs.ToArray();
                }
                catch (SerializationException se)
                {
                    HpcClientSideLog.Add("Error occured during serialization. Message: " + se.Message);
                    throw;
                }
            }
        }

        // Rebuilds an object graph from a BinaryFormatter byte array.
        // A SerializationException is wrapped in a DryadLinqException (FailedToDeserialize).
        // NOTE(review): BinaryFormatter.Deserialize is unsafe on untrusted input;
        // confirm that 'rep' only ever comes from data produced by this system.
        private static Object ByteArrayToObject(byte[] rep)
        {
            using (MemoryStream fs = new MemoryStream(rep))
            {
                BinaryFormatter bfm = new BinaryFormatter();
                try
                {
                    Object o = bfm.Deserialize(fs);
                    return o;
                }
                catch (SerializationException e)
                {
                    throw new DryadLinqException(HpcLinqErrorCode.FailedToDeserialize, SR.FailedToDeserialize, e);
                }
            }
        }

        // Returns the MD5 hash of the object's BinaryFormatter serialization,
        // formatted as an upper-case hexadecimal string (two digits per byte).
        internal static string MD5(Object ob)
        {
            byte[] payload = ObjectToByteArray(ob);
            byte[] result;
            using (MD5 md5 = new MD5CryptoServiceProvider())
            {
                result = md5.ComputeHash(payload);
            }

            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < result.Length; i++)
            {
                sb.Append(result[i].ToString("X2"));
            }
            return sb.ToString();
        }

        // Regex match evaluator: replaces the whole match with its first capture group.
        internal static string ReplaceWithLast(Match m)
        {
            return m.Groups[1].Value;
        }

        // matches each simple type name in a nested template type.
+ private static Regex dotted = new Regex(@"[^<>,]*\.([^\.<>,]*)", RegexOptions.Compiled); + internal static string SimpleName(string name) + { + // the typename may contain nested templates; simplify every component to keep the text after the last dot + string result = dotted.Replace(name, ReplaceWithLast); + return result; + } + + internal static string MapToString(Dictionary map) + { + StringBuilder sb = new StringBuilder(); + sb.Append('['); + bool isFirst = true; + foreach (KeyValuePair kv in map) + { + if (isFirst) + { + isFirst = false; + } + else + { + sb.AppendLine(","); + sb.Append(' '); + } + sb.Append(kv.Key); + sb.Append(" -> <"); + bool isFirst1 = true; + foreach (T v in kv.Value) + { + if (isFirst1) + { + isFirst1 = false; + } + sb.Append(v); + } + sb.Append('>'); + } + sb.Append(']'); + return sb.ToString(); + } + + // Unsafe memcpy from System.Buffer + internal unsafe static void memcpy(byte* src, byte* dest, int len) + { + if (len < 0) + { + throw new ArgumentException("len < 0", "len"); + } + +#if FEATURE_PAL + // Portable naive implementation + while (len-- > 0) + *dest++ = *src++; +#else + +#if IA64 + // IA64 implementation + long dstAlign = 8 - (((long)dest) & 7); // number of bytes to copy before dest is 8-byte aligned + + while ((dstAlign > 0) && (len > 0)) + { + *dest++ = *src++; + + len--; + dstAlign--; + } + + long srcAlign = 8 - (((long)src) & 7); + + if (len > 0) + { + if (srcAlign != 8) + { + if (4 == srcAlign) + { + while (len >= 4) + { + ((int*)dest)[0] = ((int*)src)[0]; + dest += 4; + src += 4; + len -= 4; + } + + srcAlign = 2; // fall through to 2-byte copies + } + + if ((2 == srcAlign) || (6 == srcAlign)) + { + while (len >= 2) + { + ((short*)dest)[0] = ((short*)src)[0]; + dest += 2; + src += 2; + len -= 2; + } + } + + while (len-- > 0) + { + *dest++ = *src++; + } + } + else + { + if (len >= 16) + { + do + { + ((long*)dest)[0] = ((long*)src)[0]; + ((long*)dest)[1] = ((long*)src)[1]; + dest += 16; + src += 16; + } while ((len -= 
16) >= 16); + } + if (len > 0) // protection against negative len and optimization for len==16*N + { + if ((len & 8) != 0) + { + ((long*)dest)[0] = ((long*)src)[0]; + dest += 8; + src += 8; + } + if ((len & 4) != 0) + { + ((int*)dest)[0] = ((int*)src)[0]; + dest += 4; + src += 4; + } + if ((len & 2) != 0) + { + ((short*)dest)[0] = ((short*)src)[0]; + dest += 2; + src += 2; + } + if ((len & 1) != 0) + { + *dest++ = *src++; + } + } + } + } + +#else + // AMD64 implementation uses longs instead of ints where possible + if (len >= 16) + { + do + { +#if AMD64 + ((long*)dest)[0] = ((long*)src)[0]; + ((long*)dest)[1] = ((long*)src)[1]; +#else + ((int*)dest)[0] = ((int*)src)[0]; + ((int*)dest)[1] = ((int*)src)[1]; + ((int*)dest)[2] = ((int*)src)[2]; + ((int*)dest)[3] = ((int*)src)[3]; +#endif + dest += 16; + src += 16; + } while ((len -= 16) >= 16); + } + if (len > 0) // protection against negative len and optimization for len==16*N + { + if ((len & 8) != 0) + { +#if AMD64 + ((long*)dest)[0] = ((long*)src)[0]; +#else + ((int*)dest)[0] = ((int*)src)[0]; + ((int*)dest)[1] = ((int*)src)[1]; +#endif + dest += 8; + src += 8; + } + if ((len & 4) != 0) + { + ((int*)dest)[0] = ((int*)src)[0]; + dest += 4; + src += 4; + } + if ((len & 2) != 0) + { + ((short*)dest)[0] = ((short*)src)[0]; + dest += 2; + src += 2; + } + if ((len & 1) != 0) + *dest++ = *src++; + } + +#endif // IA64 +#endif // FEATURE_PAL + } + + + //Utility function to determine if a sequence of partition-keys is ascending or descending consistently + // if the sequence is neither ascending or descending, an exception is thrown. + // if the sequence is single element or a series of equal values, the return is null == inconclusive/both. + // if the return value == false, the value of isDescending is meaningless. + internal static bool ComputeIsDescending(TKey[] partitionKeys, IComparer comparer, out bool? 
isDescending) + { + if (partitionKeys.Length == 0 || partitionKeys.Length == 1) + { + isDescending = null; // neither specifically ascending nor descending. + return true; // everything is OK. + } + + // Determine if the keys are ascending/descending and whether they are consistent + isDescending = null; // initially we don't know (and equal keys may delay identification) + + TKey curr = partitionKeys[0]; + for (int i = 1; i < partitionKeys.Length; i++) + { + int cmp = comparer.Compare(curr, partitionKeys[i]); + + if (cmp == 0) + { + // do nothing + } + if (cmp < 0) + { + if (isDescending == null) + { + isDescending = false; // the sequence appears to be ascending + } + + if (isDescending == true) + { + return false; // there was an inconsistency. + } + } + if (cmp > 0) + { + if (isDescending == null) + { + isDescending = true; // the sequence appears to be descending + } + + if (isDescending == false) + { + return false; // there was an inconsistency. + } + + } + curr = partitionKeys[i]; + } + + return true; + } + } + + internal class FList + { + public T elem; + public FList next; + + internal static FList Empty = new FList(default(T), null); + + internal FList(T elem, FList next) + { + this.elem = elem; + this.next = next; + } + + internal FList Cons(T elem, FList next) + { + return new FList(elem, next); + } + + internal bool Find(T x) + { + FList curr = this; + while (curr != Empty) + { + if (curr.elem.Equals(x)) return true; + curr = curr.next; + } + return false; + } + + public override string ToString() + { + StringBuilder sb = new StringBuilder(); + sb.Append("< "); + if (this != Empty) + { + sb.Append(this.elem); + FList curr = this.next; + while (curr != Empty) + { + sb.Append(", " + curr.elem.ToString()); + curr = curr.next; + } + sb.Append(" >"); + } + return sb.ToString(); + } + } + + internal class Wrapper + { + internal T item; + + internal Wrapper(T item) + { + this.item = item; + } + } +} diff --git a/LinqToDryad/DryadLinqVertex.cs 
b/LinqToDryad/DryadLinqVertex.cs new file mode 100644 index 0000000..59101c6 --- /dev/null +++ b/LinqToDryad/DryadLinqVertex.cs @@ -0,0 +1,10607 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. +// +using System; +using System.Collections; +using System.Collections.Generic; +using System.Text; +using System.IO; +using System.Reflection; +using System.Linq; +using System.Linq.Expressions; +using System.Data.Linq; +using System.Xml; +using System.Data.Linq.Mapping; +using System.Diagnostics; +using System.Threading; +using System.Data; +using System.Data.SqlClient; +using System.Collections.ObjectModel; +using System.Collections.Concurrent; +using System.Threading.Tasks; + +using Microsoft.Research.DryadLinq; + +namespace Microsoft.Research.DryadLinq.Internal +{ + // This class contains the generic vertex code for each query operation + // supported by HpcLinq. We hope to support most of the LINQ operators. + // Applications can add their own vertex method using extension methods. + public static class HpcLinqVertex + { + public static bool s_multiThreading = true; //vertex code will set this at runtime. 
+ + internal static IParallelPipeline + ExtendParallelPipeline(this IEnumerable source, + Func, IEnumerable> func, + bool orderPreserving) + { + IParallelPipeline pipe = source as IParallelPipeline; + IParallelPipeline result; + if (pipe == null) + { + result = new ParallelApply(source, func, orderPreserving); + } + else + { + result = pipe.Extend(func, orderPreserving); + } + return result; + } + + // Operator: Where + public static IEnumerable Where(IEnumerable source, + Func predicate, + bool orderPreserving) + { + if(s_multiThreading) + { + return source.ExtendParallelPipeline(s => s.Where(predicate), orderPreserving); + } + else + { + return System.Linq.Enumerable.Where(source, predicate); + } + } + + public static IEnumerable Where(IEnumerable source, + Func predicate, + bool orderPreserving) + { + return System.Linq.Enumerable.Where(source, predicate); + } + + public static IEnumerable LongWhere(IEnumerable source, + Func predicate, + bool orderPreserving) + { + return HpcLinqEnumerable.LongWhere(source, predicate); + } + + // Operator: Select + public static IEnumerable + Select(IEnumerable source, + Func selector, + bool orderPreserving) + { + if (s_multiThreading) + { + return source.ExtendParallelPipeline(s => s.Select(selector), orderPreserving); + } + else + { + return System.Linq.Enumerable.Select(source, selector); + } + } + + public static IEnumerable + Select(IEnumerable source, + Func selector, + bool orderPreserving) + { + return System.Linq.Enumerable.Select(source, selector); + } + + public static IEnumerable + LongSelect(IEnumerable source, + Func selector, + bool orderPreserving) + { + return HpcLinqEnumerable.LongSelect(source, selector); + } + + // Operator: SelectMany + public static IEnumerable + SelectMany(IEnumerable source, + Func> selector, + bool orderPreserving) + { + if (s_multiThreading) + { + return source.ExtendParallelPipeline(s => s.SelectMany(selector), orderPreserving); + } + else + { + return 
System.Linq.Enumerable.SelectMany(source, selector); + } + } + + public static IEnumerable + SelectMany(IEnumerable source, + Func> selector, + bool orderPreserving) + { + return System.Linq.Enumerable.SelectMany(source, selector); + } + + public static IEnumerable + SelectMany(IEnumerable source, + Func> collectionSelector, + Func resultSelector, + bool orderPreserving) + { + if (s_multiThreading) + { + return source.ExtendParallelPipeline(s => s.SelectMany(collectionSelector, resultSelector), orderPreserving); + } + else + { + return System.Linq.Enumerable.SelectMany(source, collectionSelector, resultSelector); + } + } + + public static IEnumerable + SelectMany(IEnumerable source, + Func> collectionSelector, + Func resultSelector, + bool orderPreserving) + { + return System.Linq.Enumerable.SelectMany(source, collectionSelector, resultSelector); + } + + public static IEnumerable + LongSelectMany(IEnumerable source, + Func> collectionSelector, + bool orderPreserving) + { + return HpcLinqEnumerable.LongSelectMany(source, collectionSelector); + } + + public static IEnumerable + LongSelectMany(IEnumerable source, + Func> collectionSelector, + Func resultSelector, + bool orderPreserving) + { + return HpcLinqEnumerable.LongSelectMany(source, collectionSelector, resultSelector); + } + + // Operator: Zip + private static IEnumerable> ZipToPairs(IEnumerable s1, + IEnumerable s2) + { + IEnumerator elems1 = s1.GetEnumerator(); + IEnumerator elems2 = s2.GetEnumerator(); + while (elems1.MoveNext() && elems2.MoveNext()) + { + yield return new Pair(elems1.Current, elems2.Current); + } + } + + public static IEnumerable Zip(IEnumerable s1, + IEnumerable s2, + Func zipper, + bool orderPreserving) + { + var pairs = ZipToPairs(s1, s2); + return Select(pairs, x => zipper(x.Key, x.Value), orderPreserving); + } + + // Operator: Take + public static IEnumerable + Take(IEnumerable source, int count) + { + return source.Take(count); + } + + // Operator: TakeWhile + public static 
IEnumerable + TakeWhile(IEnumerable, bool>> source) + { + foreach (Pair, bool> group in source) + { + foreach (TSource elem in group.Key) + { + yield return elem; + } + if (!group.Value) yield break; + } + } + + // Operator: Skip + public static IEnumerable + Skip(IEnumerable source, int count) + { + return source.Skip(count); + } + + // Operator: SkipWhile + public static IEnumerable + SkipWhile(IEnumerable source, Func predicate) + { + return source.SkipWhile(predicate); + } + + public static IEnumerable + SkipWhile(IEnumerable source, Func predicate) + { + return source.SkipWhile(predicate); + } + + public static IEnumerable + LongSkipWhile(IEnumerable source, + Func predicate) + { + long index = -1; + bool yielding = false; + using (IEnumerator sourceEnum = source.GetEnumerator()) + { + while (sourceEnum.MoveNext()) + { + checked { index++; } + if (!predicate(sourceEnum.Current, index)) + { + yielding = true; + break; + } + } + + if (yielding) + { + do + { + yield return sourceEnum.Current; + } + while (sourceEnum.MoveNext()); + } + } + } + + // Operator: OrderBy + public static IEnumerable + Sort(IEnumerable source, + Func keySelector, + IComparer comparer, + bool isDescending, + bool isIdKeySelector, + HpcLinqFactory factory) + { + if (s_multiThreading) + { + return new ParallelSort(source, keySelector, comparer, isDescending, isIdKeySelector, factory); + } + else + { + if (isDescending) + { + return Enumerable.OrderByDescending(source, keySelector, comparer); + } + else + { + return Enumerable.OrderBy(source, keySelector, comparer); + } + } + } + + public static IEnumerable + MergeSort(this IEnumerable source, + Func keySelector, + IComparer comparer, + bool isDescending) + { + IMultiEnumerable msource = source as IMultiEnumerable; + if (msource == null) + { + throw new DryadLinqException(HpcLinqErrorCode.SourceOfMergesortMustBeMultiEnumerable, + SR.SourceOfMergesortMustBeMultiEnumerable); + } + if (msource.NumberOfInputs == 1) + { + return source; + } + + 
if (s_multiThreading) + { + return new ParallelMergeSort(msource, keySelector, comparer, isDescending); + } + else + { + return SequentialMergeSort(msource, keySelector, comparer, isDescending); + } + } + + private static IEnumerable + SequentialMergeSort(IMultiEnumerable source, + Func keySelector, + IComparer comparer, + bool isDescending) + { + comparer = TypeSystem.GetComparer(comparer); + + // Initialize + IEnumerator[] readers = new IEnumerator[source.NumberOfInputs]; + for (int i = 0; i < readers.Length; i++) + { + readers[i] = source[i].GetEnumerator(); + } + + DryadLinqLog.Add("Sequential MergeSort started reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + TSource[] elems = new TSource[readers.Length]; + TKey[] keys = new TKey[readers.Length]; + int lastIdx = readers.Length - 1; + int readerCnt = 0; + while (readerCnt <= lastIdx) + { + elems[readerCnt] = default(TSource); + if (readers[readerCnt].MoveNext()) + { + elems[readerCnt] = readers[readerCnt].Current; + keys[readerCnt] = keySelector(elems[readerCnt]); + readerCnt++; + } + else + { + readers[readerCnt].Dispose(); + if (readerCnt == lastIdx) break; + readers[readerCnt] = readers[lastIdx]; + lastIdx--; + } + } + + // Merge sort + while (readerCnt > 0) + { + TKey key = keys[0]; + int idx = 0; + for (int i = 1; i < readerCnt; i++) + { + int cmp = comparer.Compare(key, keys[i]); + int cmpRes = (isDescending) ? 
-cmp : cmp; + if (cmpRes > 0) + { + key = keys[i]; + idx = i; + } + } + + yield return elems[idx]; + + if (readers[idx].MoveNext()) + { + elems[idx] = readers[idx].Current; + keys[idx] = keySelector(elems[idx]); + } + else + { + readers[idx].Dispose(); + readerCnt--; + if (idx < readerCnt) + { + readers[idx] = readers[readerCnt]; + elems[idx] = elems[readerCnt]; + keys[idx] = keys[readerCnt]; + } + } + } + + DryadLinqLog.Add("Sequential MergeSort ended reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + } + + // Operator: ThenBy + public static IEnumerable + ThenBy(IEnumerable source, + Func keySelector, + IComparer comparer, + bool isDescending) + { + throw new DryadLinqException(HpcLinqErrorCode.ThenByNotSupported, SR.ThenByNotSupported); + } + + // Operator: GroupBy + public static IEnumerable> + GroupBy( + IEnumerable source, + Func keySelector, + Func seed, + Func accumulator, + IEqualityComparer comparer, + bool isPartial) + { + return GroupBy(source, keySelector, x => x, seed, accumulator, comparer, isPartial); + } + + public static IEnumerable> + GroupBy( + IEnumerable source, + Func keySelector, + Func elementSelector, + Func seed, + Func accumulator, + IEqualityComparer comparer, + bool isPartial) + { + if (s_multiThreading) + { + if (isPartial) + { + if (source is IParallelApply) + { + IParallelApply parSource = (IParallelApply)source; + return parSource.ExtendGroupBy(keySelector, elementSelector, seed, accumulator, comparer); + } + else + { + return new ParallelHashGroupByPartialAccumulate( + source, null, keySelector, elementSelector, seed, accumulator, comparer); + } + } + else + { + return new ParallelHashGroupByFullAccumulate>( + source, keySelector, elementSelector, seed, accumulator, comparer, null); + } + } + else + { + return SequentialHashGroupBy(source, keySelector, elementSelector, seed, accumulator, comparer); + } + } + + private static IEnumerable> + SequentialHashGroupBy( + IEnumerable source, + Func keySelector, + Func 
elementSelector, + Func seed, + Func accumulator, + IEqualityComparer comparer) + { + DryadLinqLog.Add("Sequential HashGroupBy (Acc) started reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + AccumulateDictionary + groups = new AccumulateDictionary(comparer, 16411, seed, accumulator); + foreach (TSource item in source) + { + groups.Add(keySelector(item), elementSelector(item)); + } + + DryadLinqLog.Add("Sequential HashGroupBy (Acc) ended reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + return groups; + } + + public static IEnumerable> + GroupBy(IEnumerable source, + Func keySelector, + IEqualityComparer comparer) + { + if (comparer == null) + { + comparer = EqualityComparer.Default; + } + + GroupingHashSet groupings = new GroupingHashSet(comparer, 16411); + foreach (TSource item in source) + { + groupings.AddItem(keySelector(item), item); + } + return groupings; + } + + public static IEnumerable + GroupBy(IEnumerable source, + Func keySelector, + Func, TResult> resultSelector, + IEqualityComparer comparer) + { + var groupings = GroupBy(source, keySelector, comparer); + if (s_multiThreading) + { + return new ParallelApply, TResult>( + groupings, s => s.Select(g => resultSelector(g.Key, g)), false); + } + else + { + return Apply(groupings, s => s.Select(g => resultSelector(g.Key, g))); + } + } + + public static IEnumerable> + GroupBy(IEnumerable source, + Func keySelector, + Func elementSelector, + IEqualityComparer comparer) + { + if (comparer == null) + { + comparer = EqualityComparer.Default; + } + + GroupingHashSet groupings = new GroupingHashSet(comparer, 16411); + foreach (TSource item in source) + { + groupings.AddItem(keySelector(item), elementSelector(item)); + } + return groupings; + } + + public static IEnumerable + GroupBy(IEnumerable source, + Func keySelector, + Func elementSelector, + Func, TResult> resultSelector, + IEqualityComparer comparer) + { + var groupings = GroupBy(source, keySelector, 
elementSelector, comparer); + if (s_multiThreading) + { + return new ParallelApply, TResult>( + groupings, s => s.Select(g => resultSelector(g.Key, g)), false); + } + else + { + return Apply(groupings, s => s.Select(g => resultSelector(g.Key, g))); + } + } + + public static IEnumerable> + OrderedGroupBy(IEnumerable source, + Func keySelector, + Func seed, + Func accumulator, + IEqualityComparer comparer, + bool isPartial) + { + return OrderedGroupBy(source, keySelector, x => x, seed, accumulator, comparer, isPartial); + } + + public static IEnumerable> + OrderedGroupBy(IEnumerable source, + Func keySelector, + Func elementSelector, + Func seed, + Func accumulator, + IEqualityComparer comparer, + bool isPartial) + { + if (s_multiThreading) + { + return new ParallelOrderedGroupByAccumulate>( + source, keySelector, elementSelector, seed, accumulator, comparer, null); + } + else + { + return SequentialOrderedGroupBy(source, keySelector, elementSelector, seed, accumulator, comparer); + } + } + + private static IEnumerable> + SequentialOrderedGroupBy( + IEnumerable source, + Func keySelector, + Func elementSelector, + Func seed, + Func accumulator, + IEqualityComparer comparer) + { + DryadLinqLog.Add("Sequential OrderedGroupBy (Acc) started reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + using (IEnumerator elems = source.GetEnumerator()) + { + if (elems.MoveNext()) + { + TKey curKey = keySelector(elems.Current); + TResult curValue = seed(elementSelector(elems.Current)); + + while (elems.MoveNext()) + { + if (comparer.Equals(curKey, keySelector(elems.Current))) + { + curValue = accumulator(curValue, elementSelector(elems.Current)); + } + else + { + yield return new Pair(curKey, curValue); + curKey = keySelector(elems.Current); + curValue = seed(elementSelector(elems.Current)); + } + } + + yield return new Pair(curKey, curValue); + } + } + + DryadLinqLog.Add("Sequential OrderedGroupBy (Acc) ended reading at {0}", + 
DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + } + + public static IEnumerable> + OrderedGroupBy(IEnumerable source, + Func keySelector, + IEqualityComparer comparer) + { + if (comparer == null) + { + comparer = EqualityComparer.Default; + } + + DryadLinqLog.Add("Sequential OrderedGroupBy started reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + using (IEnumerator elems = source.GetEnumerator()) + { + Grouping curGroup; + if (elems.MoveNext()) + { + curGroup = new Grouping(keySelector(elems.Current)); + curGroup.AddItem(elems.Current); + + while (elems.MoveNext()) + { + if (!comparer.Equals(curGroup.Key, keySelector(elems.Current))) + { + yield return curGroup; + curGroup = new Grouping(keySelector(elems.Current)); + } + curGroup.AddItem(elems.Current); + } + yield return curGroup; + } + } + + DryadLinqLog.Add("Sequential OrderedGroupBy ended reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + } + + public static IEnumerable> + OrderedGroupBy(IEnumerable source, + Func keySelector, + Func elementSelector, + IEqualityComparer comparer) + { + if (comparer == null) + { + comparer = EqualityComparer.Default; + } + + DryadLinqLog.Add("Sequential OrderedGroupBy started reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + using (IEnumerator elems = source.GetEnumerator()) + { + Grouping curGroup; + if (elems.MoveNext()) + { + curGroup = new Grouping(keySelector(elems.Current)); + curGroup.AddItem(elementSelector(elems.Current)); + + while (elems.MoveNext()) + { + if (!comparer.Equals(curGroup.Key, keySelector(elems.Current))) + { + yield return curGroup; + curGroup = new Grouping(keySelector(elems.Current)); + } + curGroup.AddItem(elementSelector(elems.Current)); + } + yield return curGroup; + } + } + + DryadLinqLog.Add("Sequential OrderedGroupBy ended reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + } + + public static IEnumerable + OrderedGroupBy(IEnumerable source, + Func 
keySelector, + Func, TResult> resultSelector, + IEqualityComparer comparer) + { + var groupings = OrderedGroupBy(source, keySelector, comparer); + if (s_multiThreading) + { + return new ParallelApply, TResult>( + groupings, s => s.Select(g => resultSelector(g.Key, g)), true); + } + else + { + return Apply(groupings, s => s.Select(g => resultSelector(g.Key, g))); + } + + } + + public static IEnumerable + OrderedGroupBy( + IEnumerable source, + Func keySelector, + Func elementSelector, + Func, TResult> resultSelector, + IEqualityComparer comparer) + { + var groupings = OrderedGroupBy(source, keySelector, elementSelector, comparer); + if (s_multiThreading) + { + return new ParallelApply, TResult>( + groupings, s => s.Select(g => resultSelector(g.Key, g)), true); + } + else + { + return Apply(groupings, s => s.Select(g => resultSelector(g.Key, g))); + } + } + + // Operator: Join + internal static IEnumerable + SequentialHashJoin( + IEnumerable outer, + IEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + Func resultSelector, + IEqualityComparer comparer) + { + if (comparer == null) + { + comparer = EqualityComparer.Default; + } + + DryadLinqLog.Add("Sequential HashJoin started reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + bool hashInner = true; + if ((outer is HpcVertexReader) && (inner is HpcVertexReader)) + { + Int64 outerLen = ((HpcVertexReader)outer).GetTotalLength(); + Int64 innerLen = ((HpcVertexReader)inner).GetTotalLength(); + if (innerLen >= 0 && outerLen >= 0) + { + hashInner = innerLen <= outerLen; + } + } + + if (hashInner) + { + // Create a hash lookup table using inner + GroupingHashSet innerGroupings = new GroupingHashSet(comparer); + foreach (TInner innerItem in inner) + { + innerGroupings.AddItem(innerKeySelector(innerItem), innerItem); + } + foreach (TOuter outerItem in outer) + { + Grouping innerGroup = innerGroupings.GetGroup(outerKeySelector(outerItem)); + if (innerGroup != null) + { + TInner[] items = 
innerGroup.Elements; + for (int i = 0; i < innerGroup.Count(); i++) + { + yield return resultSelector(outerItem, items[i]); + } + } + } + } + else + { + // Create a hash lookup table using outer + GroupingHashSet outerGroupings = new GroupingHashSet(comparer); + foreach (TOuter outerItem in outer) + { + outerGroupings.AddItem(outerKeySelector(outerItem), outerItem); + } + + foreach (TInner innerItem in inner) + { + Grouping outerGroup = outerGroupings.GetGroup(innerKeySelector(innerItem)); + if (outerGroup != null) + { + TOuter[] items = outerGroup.Elements; + for (int i = 0; i < outerGroup.Count(); i++) + { + yield return resultSelector(items[i], innerItem); + } + } + } + } + + DryadLinqLog.Add("Sequential HashJoin ended reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + } + + // Perform a hash join. + public static IEnumerable + HashJoin(IEnumerable outer, + IEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + Func resultSelector, + IEqualityComparer comparer) + { + if (s_multiThreading) + { + return new ParallelHashJoin( + outer, inner, outerKeySelector, innerKeySelector, + resultSelector, comparer, null); + } + else + { + return SequentialHashJoin(outer, inner, outerKeySelector, innerKeySelector, + resultSelector, comparer); + } + } + + public static IEnumerable + HashJoin(IEnumerable outer, + IEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + Func resultSelector, + IEqualityComparer comparer, + Func, IEnumerable> applyFunc) + { + if (s_multiThreading) + { + return new ParallelHashJoin( + outer, inner, outerKeySelector, innerKeySelector, + resultSelector, comparer, applyFunc); + } + else + { + var results = SequentialHashJoin(outer, inner, outerKeySelector, innerKeySelector, + resultSelector, comparer); + return applyFunc(results); + } + } + + // Perform a merge join + // Precondition: both outer and inner inputs are sorted based on TKey + public static IEnumerable + MergeJoin(IEnumerable outer, + 
IEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + Func resultSelector, + IComparer comparer, + bool isDescending) + { + var joinPairs = MergeJoin(outer, inner, outerKeySelector, innerKeySelector, comparer, isDescending); + + if (s_multiThreading) + { + return new ParallelApply, TResult>( + joinPairs, s => s.Select(x => resultSelector(x.Key, x.Value)), true); + } + else + { + return Apply(joinPairs, s => s.Select(x => resultSelector(x.Key, x.Value))); + } + } + + public static IEnumerable + MergeJoin( + IEnumerable outer, + IEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + Func resultSelector, + IComparer comparer, + bool isDescending, + Func, IEnumerable> applyFunc) + { + var joinPairs = MergeJoin(outer, inner, outerKeySelector, innerKeySelector, comparer, isDescending); + if (s_multiThreading) + { + return new ParallelApply, TFinal>( + joinPairs, s => applyFunc(s.Select(x => resultSelector(x.Key, x.Value))), true); + } + else + { + return Apply(joinPairs, s => applyFunc(s.Select(x => resultSelector(x.Key, x.Value)))); + } + } + + public static IEnumerable> + MergeJoin(IEnumerable outer, + IEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + IComparer comparer, + bool isDescending) + { + comparer = TypeSystem.GetComparer(comparer); + + DryadLinqLog.Add("Sequential MergeJoin started reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + IEnumerator outerEnum = outer.GetEnumerator(); + IEnumerator innerEnum = inner.GetEnumerator(); + TOuter[] outerItemArray = new TOuter[4]; + TInner[] innerItemArray = new TInner[4]; + bool outerHasMoreWork = outerEnum.MoveNext(); + bool innerHasMoreWork = innerEnum.MoveNext(); + while (outerHasMoreWork && innerHasMoreWork) + { + TOuter outerItem = outerEnum.Current; + TInner innerItem = innerEnum.Current; + TKey outerKey = outerKeySelector(outerItem); + TKey innerKey = innerKeySelector(innerItem); + + int cmpResult = comparer.Compare(outerKey, innerKey); + 
cmpResult = (isDescending) ? -cmpResult : cmpResult; + if (cmpResult < 0) + { + outerHasMoreWork = outerEnum.MoveNext(); + } + else if (cmpResult > 0) + { + innerHasMoreWork = innerEnum.MoveNext(); + } + else + { + // Get all the outer items with the same key: + outerItemArray[0] = outerItem; + int outerCnt = 1; + while (true) + { + outerHasMoreWork = outerEnum.MoveNext(); + if (!outerHasMoreWork || + comparer.Compare(outerKey, outerKeySelector(outerEnum.Current)) != 0) + { + break; + } + if (outerCnt == outerItemArray.Length) + { + TOuter[] newOuterItemArray = new TOuter[outerItemArray.Length * 2]; + Array.Copy(outerItemArray, 0, newOuterItemArray, 0, outerItemArray.Length); + outerItemArray = newOuterItemArray; + } + outerItemArray[outerCnt++] = outerEnum.Current; + } + + // Get all the inner items with the same key: + innerItemArray[0] = innerItem; + int innerCnt = 1; + while (true) + { + innerHasMoreWork = innerEnum.MoveNext(); + if (!innerHasMoreWork || + comparer.Compare(innerKey, innerKeySelector(innerEnum.Current)) != 0) + { + break; + } + if (innerCnt == innerItemArray.Length) + { + TInner[] newInnerItemArray = new TInner[innerItemArray.Length * 2]; + Array.Copy(innerItemArray, 0, newInnerItemArray, 0, innerItemArray.Length); + innerItemArray = newInnerItemArray; + } + innerItemArray[innerCnt++] = innerEnum.Current; + } + + // Yield items: + for (int i = 0; i < outerCnt; i++) + { + for (int j = 0; j < innerCnt; j++) + { + yield return new Pair(outerItemArray[i], innerItemArray[j]); + } + } + } + } + + DryadLinqLog.Add("Sequential MergeJoin ended reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + } + + internal static IEnumerable + SequentialHashGroupJoin( + IEnumerable outer, + IEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + Func, TResult> resultSelector, + IEqualityComparer comparer) + { + if (comparer == null) + { + comparer = EqualityComparer.Default; + } + + // Create a hash lookup table using inner. 
It is hard to do the same + // optimization as Join, because resultSelector is not symemtric. + GroupingHashSet innerGroupings = new GroupingHashSet(comparer); + foreach (TInner innerItem in inner) + { + innerGroupings.AddItem(innerKeySelector(innerItem), innerItem); + } + + TInner[] emptyGroup = new TInner[0]; + foreach (TOuter outerItem in outer) + { + IEnumerable innerGroup = innerGroupings.GetGroup(outerKeySelector(outerItem)); + if (innerGroup == null) + { + innerGroup = emptyGroup; + } + yield return resultSelector(outerItem, innerGroup); + } + } + + + public static IEnumerable + HashGroupJoin(IEnumerable outer, + IEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + Func, TResult> resultSelector, + IEqualityComparer comparer) + { + if (s_multiThreading) + { + return new ParallelHashGroupJoin( + outer, inner, outerKeySelector, innerKeySelector, + resultSelector, comparer, null); + } + else + { + return SequentialHashGroupJoin(outer, inner, outerKeySelector, innerKeySelector, + resultSelector, comparer); + } + } + + public static IEnumerable + HashGroupJoin( + IEnumerable outer, + IEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + Func, TResult> resultSelector, + IEqualityComparer comparer, + Func, IEnumerable> applyFunc) + { + if (s_multiThreading) + { + return new ParallelHashGroupJoin( + outer, inner, outerKeySelector, innerKeySelector, + resultSelector, comparer, applyFunc); + } + else + { + var results = SequentialHashGroupJoin(outer, inner, outerKeySelector, innerKeySelector, + resultSelector, comparer); + return applyFunc(results); + } + } + + // Precondition: both outer and inner inputs are sorted based on TKey + public static IEnumerable + MergeGroupJoin(IEnumerable outer, + IEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + Func, TResult> resultSelector, + IComparer comparer, + bool isDescending) + { + var joinPairs = MergeGroupJoin(outer, inner, outerKeySelector, innerKeySelector, comparer, 
isDescending); + + if (s_multiThreading) + { + return new ParallelApply>, TResult>( + joinPairs, s => s.Select(x => resultSelector(x.Key, x.Value)), true); + } + else + { + return Apply(joinPairs, s => s.Select(x => resultSelector(x.Key, x.Value))); + } + } + + public static IEnumerable + MergeGroupJoin(IEnumerable outer, + IEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + Func, TResult> resultSelector, + IComparer comparer, + bool isDescending, + Func, IEnumerable> applyFunc) + { + var joinPairs = MergeGroupJoin(outer, inner, outerKeySelector, innerKeySelector, comparer, isDescending); + + if (s_multiThreading) + { + return new ParallelApply>, TFinal>( + joinPairs, s => applyFunc(s.Select(x => resultSelector(x.Key, x.Value))), true); + } + else + { + return Apply(joinPairs, s => applyFunc(s.Select(x => resultSelector(x.Key, x.Value)))); + } + } + + public static IEnumerable>> + MergeGroupJoin(IEnumerable outer, + IEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + IComparer comparer, + bool isDescending) + { + comparer = TypeSystem.GetComparer(comparer); + + DryadLinqLog.Add("Sequential MergeGroupJoin started reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + IEnumerator outerEnum = outer.GetEnumerator(); + IEnumerator innerEnum = inner.GetEnumerator(); + List innerItemList = new List(8); + bool hasMoreWork = outerEnum.MoveNext() && innerEnum.MoveNext(); + while (hasMoreWork) + { + TOuter outerItem = outerEnum.Current; + TInner innerItem = innerEnum.Current; + TKey outerKey = outerKeySelector(outerItem); + TKey innerKey = innerKeySelector(innerItem); + int cmpResult = comparer.Compare(outerKey, innerKey); + cmpResult = (isDescending) ? 
-cmpResult : cmpResult; + if (cmpResult < 0) + { + hasMoreWork = outerEnum.MoveNext(); + } + else if (cmpResult > 0) + { + hasMoreWork = innerEnum.MoveNext(); + } + else + { + // Get all the inner items with the same key: + innerItemList.Add(innerItem); + while (true) + { + hasMoreWork = innerEnum.MoveNext(); + if (!hasMoreWork || + comparer.Compare(innerKey, innerKeySelector(innerEnum.Current)) != 0) + { + break; + } + innerItemList.Add(innerEnum.Current); + } + + // Yield items: + while (true) + { + yield return new Pair>(outerItem, innerItemList); + hasMoreWork = outerEnum.MoveNext(); + if (!hasMoreWork || + comparer.Compare(outerKey, outerKeySelector(outerEnum.Current)) != 0) + { + break; + } + outerItem = outerEnum.Current; + } + + innerItemList.Clear(); + } + } + + DryadLinqLog.Add("Sequential MergeGroupJoin ended reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + } + + // Operator: Concat + public static IEnumerable + Concat(IEnumerable source1, IEnumerable source2) + { + return System.Linq.Enumerable.Concat(source1, source2); + } + + // Operator: Distinct + public static IEnumerable + Distinct(IEnumerable source, + IEqualityComparer comparer, + Func, IEnumerable> applyFunc, + bool isPartial) + { + if (s_multiThreading) + { + return new ParallelSetOperation( + "Distinct", source, null, comparer, applyFunc, isPartial); + } + else + { + var results = Enumerable.Distinct(source, comparer); + return applyFunc(results); + } + } + + public static IEnumerable + Distinct(IEnumerable source, + IEqualityComparer comparer, + bool isPartial) + { + if (s_multiThreading) + { + return new ParallelSetOperation( + "Distinct", source, null, comparer, null, isPartial); + } + else + { + return Enumerable.Distinct(source, comparer); + } + } + + // Operator: Union + public static IEnumerable + Union(IEnumerable source1, + IEnumerable source2) + { + if (s_multiThreading) + { + return new ParallelSetOperation( + "Union", source1, source2, null, null, false); + 
} + else + { + return System.Linq.Enumerable.Union(source1, source2); + } + } + + public static IEnumerable + Union(IEnumerable source1, + IEnumerable source2, + IEqualityComparer comparer) + { + if (s_multiThreading) + { + return new ParallelSetOperation( + "Union", source1, source2, comparer, null, false); + } + else + { + return System.Linq.Enumerable.Union(source1, source2, comparer); + } + } + + /// + /// Performs the union of two ordered sources. It is like mergesort, but removes duplicates. + /// + /// Type of elements to union. + /// Left sorted stream to union. + /// Right sorted stream to union. + /// The union of all elements, in sorted order. + public static IEnumerable + OrderedUnion(IEnumerable source1, + IEnumerable source2, + bool isDescending) + { + return OrderedUnion(source1, source2, null, isDescending); + } + + /// + /// Performs the union of two ordered sources. It is like mergesort, but removes duplicates. + /// + /// Type of elements to union. + /// Left sorted stream to union. + /// Right sorted stream to union. + /// Comparison function to use for TSource. + /// The union of all elements, in sorted order. + public static IEnumerable + OrderedUnion(IEnumerable source1, + IEnumerable source2, + IComparer comparer, + bool isDescending) + { + comparer = TypeSystem.GetComparer(comparer); + IEnumerator enum1 = source1.GetEnumerator(); + IEnumerator enum2 = source2.GetEnumerator(); + bool hasMoreWork1 = enum1.MoveNext(); + bool hasMoreWork2 = enum2.MoveNext(); + TSource item1 = default(TSource); + TSource item2 = default(TSource); + + if (hasMoreWork1) + { + item1 = enum1.Current; + } + + if (hasMoreWork2) + { + item2 = enum2.Current; + } + + while (hasMoreWork1 && hasMoreWork2) + { + int cmpResult = comparer.Compare(item1, item2); + cmpResult = (isDescending) ? 
-cmpResult : cmpResult; + if (cmpResult <= 0) + { + yield return item1; + + // skip duplicates: + TSource item3 = item1; + while (hasMoreWork1 = enum1.MoveNext()) + { + item1 = enum1.Current; + if (comparer.Compare(item1, item3) != 0) break; + } + if (cmpResult == 0) + { + while (hasMoreWork2 = enum2.MoveNext()) + { + item2 = enum2.Current; + if (comparer.Compare(item2, item3) != 0) break; + } + } + } + else + { + yield return item2; + + // skip duplicates: + TSource item3 = item2; + while (hasMoreWork2 = enum2.MoveNext()) + { + item2 = enum2.Current; + if (comparer.Compare(item2, item3) != 0) break; + } + } + } + + // yield the remaining items: + if (hasMoreWork2) + { + //switch enum2 over to enum1 to simplify the next block. + hasMoreWork1 = true; + enum1 = enum2; + item1 = item2; + } + + if (hasMoreWork1) + { + while (true) + { + yield return item1; + + // skip duplicates: + item2 = item1; + while (true) + { + if (!enum1.MoveNext()) yield break; + item1 = enum1.Current; + if (comparer.Compare(item1, item2) != 0) break; + } + } + } + } + + // Operator: Intersect + public static IEnumerable + Intersect(IEnumerable source1, + IEnumerable source2) + { + if (s_multiThreading) + { + return new ParallelSetOperation( + "Intersect", source1, source2, null, null, false); + } + else + { + return System.Linq.Enumerable.Intersect(source1, source2); + } + } + + public static IEnumerable + Intersect(IEnumerable source1, + IEnumerable source2, + IEqualityComparer comparer) + { + if (s_multiThreading) + { + return new ParallelSetOperation( + "Intersect", source1, source2, comparer, null, false); + } + else + { + return System.Linq.Enumerable.Intersect(source1, source2, comparer); + } + } + + /// + /// Compute the intersection of two ordered sources. Like mergesort, but only keeps common values. + /// + /// Type of elements to intersect. + /// Left sorted stream of values. + /// Right sorted stream of values. 
+        ///
+        public static IEnumerable
+        OrderedIntersect(IEnumerable source1,
+                         IEnumerable source2,
+                         bool isDescending)
+        {
+            // Forward to the comparer-taking overload; the null comparer is resolved there
+            // by TypeSystem.GetComparer.
+            return OrderedIntersect(source1, source2, null, isDescending);
+        }
+
+        ///
+        /// Compute the intersection between two ordered sets. Like mergesort, but only keeps common values.
+        /// Every common value is yielded once: after a match, repeats of that value are skipped
+        /// on both inputs before the next comparison.
+        ///
+        /// TSource: Type of elements to intersect.
+        /// source1: Left sorted stream of values.
+        /// source2: Right sorted stream of values.
+        /// comparer: Comparison function to use (passed through TypeSystem.GetComparer first).
+        /// isDescending: When true, the comparison result is negated before deciding which
+        /// input to advance.
+        ///
+        public static IEnumerable
+        OrderedIntersect(IEnumerable source1,
+                         IEnumerable source2,
+                         IComparer comparer,
+                         bool isDescending)
+        {
+            comparer = TypeSystem.GetComparer(comparer);
+
+            // Release both enumerators when iteration completes or the caller stops early.
+            using (IEnumerator enum1 = source1.GetEnumerator())
+            using (IEnumerator enum2 = source2.GetEnumerator())
+            {
+                bool hasMoreWork = enum1.MoveNext() && enum2.MoveNext();
+                if (hasMoreWork)
+                {
+                    TSource item1 = enum1.Current;
+                    TSource item2 = enum2.Current;
+                    do
+                    {
+                        int cmpResult = comparer.Compare(item1, item2);
+                        if (cmpResult == 0)
+                        {
+                            yield return item1;
+
+                            // skip duplicates:
+                            TSource item3 = item1;
+                            while (true)
+                            {
+                                hasMoreWork = enum1.MoveNext();
+                                if (!hasMoreWork) yield break;
+                                item1 = enum1.Current;
+                                if (comparer.Compare(item1, item3) != 0) break;
+                            }
+                            while (true)
+                            {
+                                hasMoreWork = enum2.MoveNext();
+                                if (!hasMoreWork) yield break;
+                                item2 = enum2.Current;
+                                if (comparer.Compare(item3, item2) != 0) break;
+                            }
+                        }
+                        else
+                        {
+                            cmpResult = (isDescending) ? -cmpResult : cmpResult;
+                            if (cmpResult < 0)
+                            {
+                                hasMoreWork = enum1.MoveNext();
+                                // Current is undefined once MoveNext has returned false
+                                // (some enumerators throw), so only read it on success.
+                                if (hasMoreWork) item1 = enum1.Current;
+                            }
+                            else
+                            {
+                                hasMoreWork = enum2.MoveNext();
+                                if (hasMoreWork) item2 = enum2.Current;
+                            }
+                        }
+                    }
+                    while (hasMoreWork);
+                }
+            }
+        }
+
+        // Operator: Except
+        // With s_multiThreading set this builds a ParallelSetOperation tagged "Except";
+        // otherwise it falls back to System.Linq.Enumerable.Except.
+        public static IEnumerable
+        Except(IEnumerable source1,
+               IEnumerable source2)
+        {
+            if (s_multiThreading)
+            {
+                return new ParallelSetOperation(
+                               "Except", source1, source2, null, null, false);
+            }
+            else
+            {
+                return System.Linq.Enumerable.Except(source1, source2);
+            }
+        }
+
+        // Same as above, but the supplied equality comparer is handed to whichever
+        // implementation is selected.
+        public static IEnumerable
+        Except(IEnumerable source1,
+               IEnumerable source2,
+               IEqualityComparer comparer)
+        {
+            if (s_multiThreading)
+            {
+                return new ParallelSetOperation(
+                               "Except", source1, source2, comparer, null, false);
+            }
+            else
+            {
+                return System.Linq.Enumerable.Except(source1, source2, comparer);
+            }
+        }
+
+        ///
+        /// Perform a set difference between two ordered sources.
+        ///
+        /// TSource: Type of elements to compare.
+        /// source1: Sorted stream from which subtraction occurs.
+        /// source2: Subtracted sorted stream.
+        /// Returns: Elements in left stream not occurring in right stream.
+        public static IEnumerable
+        OrderedExcept(IEnumerable source1,
+                      IEnumerable source2,
+                      bool isDescending)
+        {
+            // Forward to the comparer-taking overload with a null comparer.
+            return OrderedExcept(source1, source2, null, isDescending);
+        }
+
+        ///
+        /// Perform a set difference between two ordered sources.
+        ///
+        /// TSource: Type of elements to compare.
+        /// source1: Sorted stream from which subtraction occurs.
+        /// source2: Subtracted sorted stream.
+        /// comparer: Function to use for comparison testing.
+        /// Returns: Elements in left stream not occurring in right stream.
+        public static IEnumerable
+        OrderedExcept(IEnumerable source1,
+                      IEnumerable source2,
+                      IComparer comparer,
+                      bool isDescending)
+        {
+            comparer = TypeSystem.GetComparer(comparer);
+
+            // Both enumerators are obtained here, so release them here: 'using' runs when
+            // the sequence is exhausted, on 'yield break', or when the caller disposes the
+            // iterator before reaching the end.
+            using (IEnumerator enum1 = source1.GetEnumerator())
+            using (IEnumerator enum2 = source2.GetEnumerator())
+            {
+                bool hasMoreWork1 = enum1.MoveNext();
+                bool hasMoreWork2 = enum2.MoveNext();
+                if (hasMoreWork1)
+                {
+                    TSource item1 = enum1.Current;
+
+                    // Merge phase: runs while the subtracted stream still has items.
+                    while (hasMoreWork2)
+                    {
+                        TSource item2 = enum2.Current;
+                        int cmpResult = comparer.Compare(item1, item2);
+                        if (cmpResult == 0)
+                        {
+                            // Value present in both inputs: drop it, together with all of
+                            // its repeats on either side.
+                            // skip duplicates:
+                            TSource item3 = item1;
+                            while (true)
+                            {
+                                if (!enum1.MoveNext()) yield break;
+                                item1 = enum1.Current;
+                                if (comparer.Compare(item1, item3) != 0) break;
+                            }
+                            while (hasMoreWork2 = enum2.MoveNext())
+                            {
+                                item2 = enum2.Current;
+                                if (comparer.Compare(item2, item3) != 0) break;
+                            }
+                        }
+                        else
+                        {
+                            cmpResult = (isDescending) ? -cmpResult : cmpResult;
+                            if (cmpResult < 0)
+                            {
+                                // item1 sorts before item2, so it cannot occur in source2.
+                                yield return item1;
+
+                                // skip duplicates:
+                                TSource item3 = item1;
+                                while (true)
+                                {
+                                    if (!enum1.MoveNext()) yield break;
+                                    item1 = enum1.Current;
+                                    if (comparer.Compare(item1, item3) != 0) break;
+                                }
+                            }
+                            else
+                            {
+                                // item2 sorts before item1: advance the subtracted stream.
+                                hasMoreWork2 = enum2.MoveNext();
+                            }
+                        }
+                    }
+
+                    // yield the remaining items:
+                    while (true)
+                    {
+                        yield return item1;
+
+                        // skip duplicates:
+                        TSource item2 = item1;
+                        while (true)
+                        {
+                            if (!enum1.MoveNext()) yield break;
+                            item1 = enum1.Current;
+                            if (comparer.Compare(item1, item2) != 0) break;
+                        }
+                    }
+                }
+            }
+        }
+
+        // Operator: Count
+        // This one could be implemented in native for better performance.
+        public static int Count(IEnumerable source)
+        {
+            // Sequential mode: hand the whole job to LINQ.
+            if (!s_multiThreading)
+            {
+                return Enumerable.Count(source);
+            }
+
+            // Multi-threaded mode only helps when the source is a parallel pipeline.
+            IParallelPipeline pipe = source as IParallelPipeline;
+            if (pipe == null)
+            {
+                return System.Linq.Enumerable.Count(source);
+            }
+
+            // Add up the partial counts produced by the extended pipeline; the checked
+            // addition raises on overflow instead of wrapping.
+            int total = 0;
+            foreach (int partial in pipe.Extend(x => AsEnumerable(x.Count()), false))
+            {
+                total = checked(total + partial);
+            }
+            return total;
+        }
+
+        // Counts the elements accepted by predicate; same dispatch as the overload above.
+        public static int Count(IEnumerable source,
+                                Func predicate)
+        {
+            if (!s_multiThreading)
+            {
+                return Enumerable.Count(source, predicate);
+            }
+
+            IParallelPipeline pipe = source as IParallelPipeline;
+            if (pipe == null)
+            {
+                return System.Linq.Enumerable.Count(source, predicate);
+            }
+
+            int total = 0;
+            foreach (int partial in pipe.Extend(x => AsEnumerable(x.Count(predicate)), false))
+            {
+                total = checked(total + partial);
+            }
+            return total;
+        }
+
+        // Operator: LongCount
+        // This one could be implemented in native for better performance.
+        public static long LongCount(IEnumerable source)
+        {
+            // Sequential mode: hand the whole job to LINQ.
+            if (!s_multiThreading)
+            {
+                return Enumerable.LongCount(source);
+            }
+
+            IParallelPipeline pipe = source as IParallelPipeline;
+            if (pipe == null)
+            {
+                return System.Linq.Enumerable.LongCount(source);
+            }
+
+            // Add up the partial counts produced by the extended pipeline; the checked
+            // addition raises on overflow instead of wrapping.
+            long total = 0;
+            foreach (long partial in pipe.Extend(x => AsEnumerable(x.LongCount()), false))
+            {
+                total = checked(total + partial);
+            }
+            return total;
+        }
+
+        // Counts the elements accepted by predicate; same dispatch as the overload above.
+        public static long LongCount(IEnumerable source,
+                                     Func predicate)
+        {
+            if (!s_multiThreading)
+            {
+                return Enumerable.LongCount(source, predicate);
+            }
+
+            IParallelPipeline pipe = source as IParallelPipeline;
+            if (pipe == null)
+            {
+                return System.Linq.Enumerable.LongCount(source, predicate);
+            }
+
+            long total = 0;
+            foreach (long partial in pipe.Extend(x => AsEnumerable(x.LongCount(predicate)), false))
+            {
+                total = checked(total + partial);
+            }
+            return total;
+        }
+
+        // Operator: Contains
+        public static bool Contains(IEnumerable source,
+                                    TSource value,
+                                    IEqualityComparer comparer)
+        {
+            if (!s_multiThreading)
+            {
+                return Enumerable.Contains(source, value, comparer);
+            }
+
+            IParallelPipeline pipe = source as IParallelPipeline;
+            if (pipe == null)
+            {
+                return System.Linq.Enumerable.Contains(source, value, comparer);
+            }
+
+            // Stop at the first partial result that reports a hit.
+            foreach (bool found in pipe.Extend(x => AsEnumerable(x.Contains(value, comparer)), false))
+            {
+                if (found)
+                {
+                    return true;
+                }
+            }
+            return false;
+        }
+
+        // Operator: Aggregate
+        // The three Aggregate overloads forward unchanged to System.Linq.Enumerable.Aggregate.
+        public static TSource
+        Aggregate(IEnumerable source,
+                  Func aggregator)
+        {
+            return System.Linq.Enumerable.Aggregate(source, aggregator);
+        }
+
+        public static TAccumulate
+        Aggregate(IEnumerable source,
+                  TAccumulate seed,
+                  Func aggregator)
+        {
+            return System.Linq.Enumerable.Aggregate(source, seed, aggregator);
+        }
+
+        public static TResult
+        Aggregate(IEnumerable source,
+                  TAccumulate seed,
+                  Func aggregator,
+                  Func resultSelector)
+        {
+            return System.Linq.Enumerable.Aggregate(source, seed, aggregator, resultSelector);
+        }
+
+        // If the aggregate function is associative...
+        // Folds source with aggregator, starting from its first element. The returned
+        // AggregateValue carries a count of 1 when source had any element and 0 when it
+        // was empty (the value is then default(TSource)).
+        public static AggregateValue
+        AssocAggregate(IEnumerable source,
+                       Func aggregator)
+        {
+            TSource folded = default(TSource);
+            long seen = 0;
+            using (IEnumerator elems = source.GetEnumerator())
+            {
+                bool hasFirst = elems.MoveNext();
+                if (hasFirst)
+                {
+                    folded = elems.Current;
+                    seen = 1;
+                    while (elems.MoveNext())
+                    {
+                        folded = aggregator(folded, elems.Current);
+                    }
+                }
+            }
+            return new AggregateValue(folded, seen);
+        }
+
+        // Seeded variant: folds every element into seed. The count is always 1, so an
+        // empty source still reports the seed as a present value.
+        public static AggregateValue
+        AssocAggregate(IEnumerable source,
+                       TAccumulate seed,
+                       Func aggregator)
+        {
+            long seen = 1;
+            TAccumulate folded = seed;
+            foreach (TSource current in source)
+            {
+                folded = aggregator(folded, current);
+            }
+            return new AggregateValue(folded, seen);
+        }
+
+        ///
+        /// This is an aggregation function used for intermediate layers in the aggregation tree.
+        /// This function is just like AssocAggregate, but it returns a different type.
+        ///
+        /// TSource: Type of elements to aggregate.
+        /// source: A collection of AggregateValue objects to aggregate.
+        /// combiner: A combiner function which performs aggregation on TSource.
+        /// Returns: Another AggregateValue.
+        public static AggregateValue
+        AssocTreeAggregate(IEnumerable> source,
+                           Func combiner)
+        {
+            // Combine the values of all non-empty partial results (Count > 0) and add up
+            // their counts; empty partial results are ignored entirely.
+            TSource result = default(TSource);
+            bool hasElem = false;
+            long count = 0;
+            foreach (var elem in source)
+            {
+                if (elem.Count > 0)
+                {
+                    if (hasElem)
+                    {
+                        result = combiner(result, elem.Value);
+                    }
+                    else
+                    {
+                        result = elem.Value;
+                        hasElem = true;
+                    }
+                    count += elem.Count;
+                }
+            }
+            return new AggregateValue(result, count);
+        }
+
+        // Final aggregation step: combines the non-empty partial results into one value.
+        // Throws DryadLinqException (AggregateNoElements) when every partial result is empty.
+        public static TSource
+        AssocAggregate(IEnumerable> source,
+                       Func combiner)
+        {
+            TSource result = default(TSource);
+            bool hasElem = false;
+            foreach (var elem in source)
+            {
+                if (elem.Count > 0)
+                {
+                    if (hasElem)
+                    {
+                        result = combiner(result, elem.Value);
+                    }
+                    else
+                    {
+                        result = elem.Value;
+                        hasElem = true;
+                    }
+                }
+            }
+            if (hasElem) return result;
+            throw new DryadLinqException(HpcLinqErrorCode.AggregateNoElements, SR.AggregateNoElements);
+        }
+
+        // Same as above, with resultSelector applied to the combined value.
+        public static TResult
+        AssocAggregate(IEnumerable> source,
+                       Func combiner,
+                       Func resultSelector)
+        {
+            return resultSelector(AssocAggregate(source, combiner));
+        }
+
+        // Operator: First
+        // Partial step: wraps the first element with count 1, or default(TSource) with
+        // count 0 when source is empty.
+        public static AggregateValue First(IEnumerable source)
+        {
+            using (IEnumerator e = source.GetEnumerator())
+            {
+                if (e.MoveNext())
+                {
+                    return new AggregateValue(e.Current, 1);
+                }
+            }
+            return new AggregateValue(default(TSource), 0);
+        }
+
+        // Partial step: first element accepted by predicate, or an empty AggregateValue.
+        public static AggregateValue
+        First(IEnumerable source, Func predicate)
+        {
+            foreach (TSource elem in source)
+            {
+                if (predicate(elem))
+                {
+                    return new AggregateValue(elem, 1);
+                }
+            }
+            return new AggregateValue(default(TSource), 0);
+        }
+
+        // Final step: value of the first non-empty partial result.
+        // Throws DryadLinqException (FirstNoElementsFirst) when there is none.
+        public static TSource First(IEnumerable> source)
+        {
+            foreach (var elem in source)
+            {
+                if (elem.Count > 0) return elem.Value;
+            }
+            throw new DryadLinqException(HpcLinqErrorCode.FirstNoElementsFirst, SR.FirstNoElementsFirst);
+        }
+
+        // Operator: FirstOrDefault
+        // The partial steps are identical to First; only the final step differs.
+        public static AggregateValue FirstOrDefault(IEnumerable source)
+        {
+            return First(source);
+        }
+
+        public static AggregateValue
+        FirstOrDefault(IEnumerable source,
+                       Func predicate)
+        {
+            return First(source, predicate);
+        }
+
+        // Final step: the first non-empty partial result, or an empty AggregateValue
+        // (default(TSource), count 0) instead of an exception.
+        public static AggregateValue
+        FirstOrDefault(IEnumerable> source)
+        {
+            foreach (var elem in source)
+            {
+                if (elem.Count > 0) return elem;
+            }
+            return new AggregateValue(default(TSource), 0);
+        }
+
+        // Operator: Single
+        // Partial step: empty AggregateValue for an empty source, the element with count 1
+        // for exactly one element, DryadLinqException (SingleMoreThanOneElement) otherwise.
+        public static AggregateValue Single(IEnumerable source)
+        {
+            using (IEnumerator e = source.GetEnumerator())
+            {
+                if (!e.MoveNext())
+                {
+                    return new AggregateValue(default(TSource), 0);
+                }
+                TSource val = e.Current;
+                if (!e.MoveNext())
+                {
+                    return new AggregateValue(val, 1);
+                }
+            }
+            throw new DryadLinqException(HpcLinqErrorCode.SingleMoreThanOneElement,
+                                         SR.SingleMoreThanOneElement);
+        }
+
+        // Partial step with a predicate: runs SingleInner over the source (through PApply,
+        // or through Extend when the source is already a parallel pipeline) and merges
+        // the partial results. Throws DryadLinqException (SingleMoreThanOneElement) when
+        // more than one partial result holds a match.
+        public static AggregateValue
+        Single(IEnumerable source,
+               Func predicate)
+        {
+            IParallelPipeline pipe = source as IParallelPipeline;
+            IEnumerable> partialResults;
+            if (pipe == null)
+            {
+                partialResults = source.PApply(s => AsEnumerable(SingleInner(s, predicate)), false);
+            }
+            else
+            {
+                partialResults = pipe.Extend(s => AsEnumerable(SingleInner(s, predicate)), false);
+            }
+
+            TSource theValue = default(TSource);
+            long count = 0;
+            foreach (var elem in partialResults)
+            {
+                if (elem.Count > 0)
+                {
+                    if (count > 0)
+                    {
+                        // The exception used to be constructed and discarded, which let a
+                        // second match silently replace the first; it must be thrown.
+                        throw new DryadLinqException(HpcLinqErrorCode.SingleMoreThanOneElement,
+                                                     SR.SingleMoreThanOneElement);
+                    }
+                    count = 1;
+                    theValue = elem.Value;
+                }
+            }
+            return new AggregateValue(theValue, count);
+        }
+
+        // Scans one sequence for the element accepted by predicate. Returns it with
+        // count 1, or an empty AggregateValue when nothing matches. Throws
+        // DryadLinqException (SingleMoreThanOneElement) on a second match.
+        private static AggregateValue
+        SingleInner(IEnumerable source, Func predicate)
+        {
+            long count = 0;
+            TSource theValue = default(TSource);
+            foreach (TSource elem in source)
+            {
+                if (predicate(elem))
+                {
+                    if (count > 0)
+                    {
+                        // Previously constructed without 'throw', so duplicates went unreported.
+                        throw new DryadLinqException(HpcLinqErrorCode.SingleMoreThanOneElement,
+                                                     SR.SingleMoreThanOneElement);
+                    }
+                    count = 1;
+                    theValue = elem;
+                }
+            }
+            return new AggregateValue(theValue, count);
+        }
+
+        public static TSource Single(IEnumerable> source)
+        {
+            AggregateValue result = new
AggregateValue(default(TSource), 0); + foreach (var elem in source) + { + if (elem.Count > 0) + { + if (result.Count > 0) + { + throw new DryadLinqException(HpcLinqErrorCode.SingleMoreThanOneElement, + SR.SingleMoreThanOneElement); + } + result = elem; + } + } + if (result.Count == 0) + { + throw new DryadLinqException(HpcLinqErrorCode.SingleNoElements, SR.SingleNoElements); + } + return result.Value; + } + + // Operator: SingleOrDefault + public static AggregateValue SingleOrDefault(IEnumerable source) + { + return Single(source); + } + + public static AggregateValue + SingleOrDefault(IEnumerable source, + Func predicate) + { + return Single(source, predicate); + } + + public static AggregateValue + SingleOrDefault(IEnumerable> source) + { + AggregateValue result = new AggregateValue(default(TSource), 0); + foreach (var elem in source) + { + if (elem.Count > 0) + { + if (result.Count > 0) + { + throw new DryadLinqException(HpcLinqErrorCode.SingleMoreThanOneElement, + SR.SingleMoreThanOneElement); + } + result = elem; + } + } + return result; + } + + // Operator: Last + public static AggregateValue Last(IEnumerable source) + { + using (IEnumerator e = source.GetEnumerator()) + { + if (e.MoveNext()) + { + TSource result = e.Current; + while (e.MoveNext()) + { + result = e.Current; + } + return new AggregateValue(result, 1); + } + } + return new AggregateValue(default(TSource), 0); + } + + public static AggregateValue + Last(IEnumerable source, Func predicate) + { + IParallelPipeline pipe = source as IParallelPipeline; + IEnumerable> partialResults; + if (pipe == null) + { + partialResults = source.PApply(s => AsEnumerable(LastInner(s, predicate)), true); + } + else + { + partialResults = pipe.Extend(s => AsEnumerable(LastInner(s, predicate)), true); + } + + TSource lastValue = default(TSource); + long count = 0; + foreach (var elem in partialResults) + { + if (elem.Count > 0) + { + count = 1; + lastValue = elem.Value; + } + } + return new AggregateValue(lastValue, 
count); + } + + private static AggregateValue + LastInner(IEnumerable source, Func predicate) + { + long count = 0; + TSource lastValue = default(TSource); + foreach (TSource elem in source) + { + if (predicate(elem)) + { + count = 1; + lastValue = elem; + } + } + return new AggregateValue(lastValue, count); + } + + public static TSource Last(IEnumerable> source) + { + AggregateValue result = new AggregateValue(default(TSource), 0); + foreach (var elem in source) + { + if (elem.Count > 0) + { + result = elem; + } + } + if (result.Count == 0) + { + throw new DryadLinqException(HpcLinqErrorCode.LastNoElements, SR.LastNoElements); + } + return result.Value; + } + + // Operator: LastOrDefault + public static AggregateValue LastOrDefault(IEnumerable source) + { + return Last(source); + } + + public static AggregateValue + LastOrDefault(IEnumerable source, + Func predicate) + { + return Last(source, predicate); + } + + public static AggregateValue + LastOrDefault(IEnumerable> source) + { + AggregateValue result = new AggregateValue(default(TSource), 0); + foreach (var elem in source) + { + if (elem.Count > 0) + { + result = elem; + } + } + return result; + } + + // Operator: Sum + public static int Sum(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return System.Linq.Enumerable.Sum(source); + } + int sum = 0; + var partialResults = pipe.Extend(x => AsEnumerable(x.Sum()), false); + foreach (int item in partialResults) + { + checked { sum += item; } + } + return sum; + } + else + { + return System.Linq.Enumerable.Sum(source); + } + } + + public static int? Sum(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return System.Linq.Enumerable.Sum(source); + } + int sum = 0; + var partialResults = pipe.Extend(x => AsEnumerable(x.Sum()), false); + foreach (int? 
item in partialResults) + { + if (item != null) + { + checked { sum += item.GetValueOrDefault(); } + } + } + return sum; + } + else + { + return System.Linq.Enumerable.Sum(source); + } + } + + public static long Sum(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return System.Linq.Enumerable.Sum(source); + } + long sum = 0; + var partialResults = pipe.Extend(x => AsEnumerable(x.Sum()), false); + foreach (long item in partialResults) + { + checked { sum += item; } + } + return sum; + } + else + { + return Enumerable.Sum(source); + } + } + + public static long? Sum(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return System.Linq.Enumerable.Sum(source); + } + long sum = 0; + var partialResults = pipe.Extend(x => AsEnumerable(x.Sum()), false); + foreach (long? item in partialResults) + { + if (item != null) + { + checked { sum += item.GetValueOrDefault(); } + } + } + return sum; + } + else + { + return System.Linq.Enumerable.Sum(source); + } + } + + public static float Sum(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return System.Linq.Enumerable.Sum(source); + } + float sum = 0; + var partialResults = pipe.Extend(x => AsEnumerable(x.Sum()), false); + foreach (float item in partialResults) + { + sum += item; + } + return sum; + } + else + { + return Enumerable.Sum(source); + } + } + + public static float? Sum(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return System.Linq.Enumerable.Sum(source); + } + float sum = 0; + var partialResults = pipe.Extend(x => AsEnumerable(x.Sum()), false); + foreach (float? 
item in partialResults) + { + if (item != null) + { + sum += item.GetValueOrDefault(); + } + } + return sum; + } + else + { + return System.Linq.Enumerable.Sum(source); + } + } + + public static double Sum(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return System.Linq.Enumerable.Sum(source); + } + double sum = 0; + var partialResults = pipe.Extend(x => AsEnumerable(x.Sum()), false); + foreach (double item in partialResults) + { + sum += item; + } + return sum; + } + else + { + return Enumerable.Sum(source); + } + } + + public static double? Sum(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return System.Linq.Enumerable.Sum(source); + } + double sum = 0; + var partialResults = pipe.Extend(x => AsEnumerable(x.Sum()), false); + foreach (double? item in partialResults) + { + if (item != null) + { + sum += item.GetValueOrDefault(); + } + } + return sum; + } + else + { + return System.Linq.Enumerable.Sum(source); + } + } + + public static decimal Sum(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return System.Linq.Enumerable.Sum(source); + } + decimal sum = 0; + var partialResults = pipe.Extend(x => AsEnumerable(x.Sum()), false); + foreach (decimal item in partialResults) + { + sum += item; + } + return sum; + } + else + { + return Enumerable.Sum(source); + } + } + + public static decimal? Sum(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return System.Linq.Enumerable.Sum(source); + } + decimal sum = 0; + var partialResults = pipe.Extend(x => AsEnumerable(x.Sum()), false); + foreach (decimal? 
item in partialResults) + { + if (item != null) + { + sum += item.GetValueOrDefault(); + } + } + return sum; + } + else + { + return System.Linq.Enumerable.Sum(source); + } + } + + public static int Sum(IEnumerable source, + Func selector) + { + if (s_multiThreading) + { + var partialResults = source.ExtendParallelPipeline(x => AsEnumerable(x.Sum(selector)), false); + int sum = 0; + foreach (int item in partialResults) + { + checked { sum += item; } + } + return sum; + } + else + { + return Enumerable.Sum(source, selector); + } + } + + public static int? Sum(IEnumerable source, + Func selector) + { + if (s_multiThreading) + { + var partialResults = source.ExtendParallelPipeline(x => AsEnumerable(x.Sum(selector)), false); + int sum = 0; + foreach (int? item in partialResults) + { + if (item != null) + { + checked { sum += item.GetValueOrDefault(); } + } + } + return sum; + } + else + { + return Enumerable.Sum(source, selector); + } + } + + public static long Sum(IEnumerable source, + Func selector) + { + if (s_multiThreading) + { + var partialResults = source.ExtendParallelPipeline(x => AsEnumerable(x.Sum(selector)), false); + long sum = 0; + foreach (long item in partialResults) + { + checked { sum += item; } + } + return sum; + } + else + { + return Enumerable.Sum(source, selector); + } + } + + public static long? Sum(IEnumerable source, + Func selector) + { + if (s_multiThreading) + { + var partialResults = source.ExtendParallelPipeline(x => AsEnumerable(x.Sum(selector)), false); + long sum = 0; + foreach (long? 
item in partialResults) + { + if (item != null) + { + checked { sum += item.GetValueOrDefault(); } + } + } + return sum; + } + else + { + return Enumerable.Sum(source, selector); + } + } + + public static float Sum(IEnumerable source, + Func selector) + { + if (s_multiThreading) + { + var partialResults = source.ExtendParallelPipeline(x => AsEnumerable(x.Sum(selector)), false); + float sum = 0; + foreach (float item in partialResults) + { + sum += item; + } + return sum; + } + else + { + return Enumerable.Sum(source, selector); + } + } + + public static float? Sum(IEnumerable source, + Func selector) + { + if (s_multiThreading) + { + var partialResults = source.ExtendParallelPipeline(x => AsEnumerable(x.Sum(selector)), false); + float sum = 0; + foreach (float? item in partialResults) + { + if (item != null) + { + sum += item.GetValueOrDefault(); + } + } + return sum; + } + else + { + return Enumerable.Sum(source, selector); + } + } + + public static double Sum(IEnumerable source, + Func selector) + { + if (s_multiThreading) + { + var partialResults = source.ExtendParallelPipeline(x => AsEnumerable(x.Sum(selector)), false); + double sum = 0; + foreach (double item in partialResults) + { + sum += item; + } + return sum; + } + else + { + return Enumerable.Sum(source, selector); + } + } + + public static double? Sum(IEnumerable source, + Func selector) + { + if (s_multiThreading) + { + var partialResults = source.ExtendParallelPipeline(x => AsEnumerable(x.Sum(selector)), false); + double sum = 0; + foreach (double? 
item in partialResults) + { + if (item != null) + { + sum += item.GetValueOrDefault(); + } + } + return sum; + } + else + { + return Enumerable.Sum(source, selector); + } + } + + public static decimal Sum(IEnumerable source, + Func selector) + { + if (s_multiThreading) + { + var partialResults = source.ExtendParallelPipeline(x => AsEnumerable(x.Sum(selector)), false); + decimal sum = 0; + foreach (decimal item in partialResults) + { + sum += item; + } + return sum; + } + else + { + return Enumerable.Sum(source, selector); + } + } + + public static decimal? Sum(IEnumerable source, + Func selector) + { + if (s_multiThreading) + { + var partialResults = source.ExtendParallelPipeline(x => AsEnumerable(x.Sum(selector)), false); + decimal sum = 0; + foreach (decimal? item in partialResults) + { + if (item != null) + { + sum += item.GetValueOrDefault(); + } + } + return sum; + } + else + { + return Enumerable.Sum(source, selector); + } + } + + public static int SumAccumulate(int a, int? x) + { + return (x == null) ? a : checked(a + x.GetValueOrDefault()); + } + + public static long SumAccumulate(long a, long? x) + { + return (x == null) ? a : checked(a + x.GetValueOrDefault()); + } + + public static float SumAccumulate(float a, float? x) + { + return (x == null) ? a : (a + x.GetValueOrDefault()); + } + + public static double SumAccumulate(double a, double? x) + { + return (x == null) ? a : (a + x.GetValueOrDefault()); + } + + public static decimal SumAccumulate(decimal a, decimal? x) + { + return (x == null) ? 
a : (a + x.GetValueOrDefault()); + } + + // Operator: Min + public static AggregateValue Min(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return MinInner(source); + } + + IEnumerable> partialResults = pipe.Extend(s => AsEnumerable(MinInner(s)), false); + int value = Int32.MaxValue; + long count = 0; + foreach (var elem in partialResults) + { + if (elem.Value < value) value = elem.Value; + count += elem.Count; + } + return new AggregateValue(value, count); + } + else + { + return MinInner(source); + } + } + + private static AggregateValue MinInner(IEnumerable source) + { + int value = Int32.MaxValue; + long count = 0; + + foreach (int elem in source) + { + if (elem < value) value = elem; + count = 1; + } + return new AggregateValue(value, count); + } + + public static int Min(IEnumerable> source) + { + int value = Int32.MaxValue; + bool hasValue = false; + foreach (var elem in source) + { + if (elem.Count > 0) + { + if (elem.Value < value) + { + value = elem.Value; + } + hasValue = true; + } + } + if (!hasValue) + { + throw new DryadLinqException(HpcLinqErrorCode.MinNoElements, SR.MinNoElements); + } + return value; + } + + public static int? Min(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return System.Linq.Enumerable.Min(source); + } + + IEnumerable partialResults = pipe.Extend(s => AsEnumerable(s.Min()), false); + int? 
value = null; + foreach (var elem in partialResults) + { + if (value == null || elem < value) + { + value = elem; + } + } + return value; + } + else + { + return System.Linq.Enumerable.Min(source); + } + } + + public static AggregateValue Min(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return MinInner(source); + } + + IEnumerable> partialResults = pipe.Extend(s => AsEnumerable(MinInner(s)), false); + long value = Int64.MaxValue; + long count = 0; + foreach (var elem in partialResults) + { + if (elem.Value < value) + { + value = elem.Value; + } + count = 1; + } + return new AggregateValue(value, count); + } + else + { + return MinInner(source); + } + } + + private static AggregateValue MinInner(IEnumerable source) + { + long value = Int64.MaxValue; + long count = 0; + + foreach (long elem in source) + { + if (elem < value) value = elem; + count = 1; + } + return new AggregateValue(value, count); + } + + public static long Min(IEnumerable> source) + { + long value = Int64.MaxValue; + bool hasElem = false; + foreach (var elem in source) + { + if (elem.Count > 0) + { + if (elem.Value < value) + { + value = elem.Value; + } + hasElem = true; + } + } + if (!hasElem) + { + throw new DryadLinqException(HpcLinqErrorCode.MinNoElements, SR.MinNoElements); + } + return value; + } + + public static long? Min(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return System.Linq.Enumerable.Min(source); + } + + IEnumerable partialResults = pipe.Extend(s => AsEnumerable(s.Min()), false); + long? 
value = null; + foreach (var elem in partialResults) + { + if (value == null || elem < value) + { + value = elem; + } + } + return value; + } + else + { + return System.Linq.Enumerable.Min(source); + } + } + + public static AggregateValue Min(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return MinInner(source); + } + + IEnumerable> partialResults = pipe.Extend(s => AsEnumerable(MinInner(s)), false); + float value = System.Single.MaxValue; + long count = 0; + foreach (var elem in partialResults) + { + if (elem.Value < value) + { + value = elem.Value; + } + count = 1; + } + return new AggregateValue(value, count); + } + else + { + return MinInner(source); + } + } + + private static AggregateValue MinInner(IEnumerable source) + { + float value = System.Single.MaxValue; + long count = 0; + + foreach (float elem in source) + { + if (elem < value) value = elem; + count = 1; + } + return new AggregateValue(value, count); + } + + public static float Min(IEnumerable> source) + { + float value = System.Single.MaxValue; + bool hasElem = false; + foreach (var elem in source) + { + if (elem.Count > 0) + { + if (elem.Value < value) + { + value = elem.Value; + } + hasElem = true; + } + } + if (!hasElem) + { + throw new DryadLinqException(HpcLinqErrorCode.MinNoElements, SR.MinNoElements); + } + return value; + } + + public static float? Min(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return System.Linq.Enumerable.Min(source); + } + + IEnumerable partialResults = pipe.Extend(s => AsEnumerable(s.Min()), false); + float? 
value = null; + foreach (var elem in partialResults) + { + if (value == null || elem < value) + { + value = elem; + } + } + return value; + } + else + { + return System.Linq.Enumerable.Min(source); + } + } + + public static AggregateValue Min(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return MinInner(source); + } + + IEnumerable> partialResults = pipe.Extend(s => AsEnumerable(MinInner(s)), false); + double value = Double.MaxValue; + long count = 0; + foreach (var elem in partialResults) + { + if (elem.Value < value) + { + value = elem.Value; + } + count = 1; + } + return new AggregateValue(value, count); + } + else + { + return MinInner(source); + } + } + + private static AggregateValue MinInner(IEnumerable source) + { + double value = Double.MaxValue; + long count = 0; + + foreach (double elem in source) + { + if (elem < value) value = elem; + count = 1; + } + return new AggregateValue(value, count); + } + + public static double Min(IEnumerable> source) + { + double value = Double.MaxValue; + bool hasElem = false; + foreach (var elem in source) + { + if (elem.Count > 0) + { + if (elem.Value < value) + { + value = elem.Value; + } + hasElem = true; + } + } + if (!hasElem) + { + throw new DryadLinqException(HpcLinqErrorCode.MinNoElements, SR.MinNoElements); + } + return value; + } + + public static double? Min(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return System.Linq.Enumerable.Min(source); + } + + IEnumerable partialResults = pipe.Extend(s => AsEnumerable(s.Min()), false); + double? 
value = null; + foreach (var elem in partialResults) + { + if (value == null || elem < value) + { + value = elem; + } + } + return value; + } + else + { + return System.Linq.Enumerable.Min(source); + } + } + + public static AggregateValue Min(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return MinInner(source); + } + + IEnumerable> partialResults = pipe.Extend(s => AsEnumerable(MinInner(s)), false); + decimal value = Decimal.MaxValue; + long count = 0; + foreach (var elem in partialResults) + { + if (elem.Value < value) + { + value = elem.Value; + } + count = 1; + } + return new AggregateValue(value, count); + } + else + { + return MinInner(source); + } + } + + private static AggregateValue MinInner(IEnumerable source) + { + decimal value = Decimal.MaxValue; + long count = 0; + + foreach (decimal x in source) + { + if (x < value) value = x; + count = 1; + } + return new AggregateValue(value, count); + } + + public static decimal Min(IEnumerable> source) + { + decimal value = Decimal.MaxValue; + bool hasElem = false; + foreach (var elem in source) + { + if (elem.Count > 0) + { + if (elem.Value < value) + { + value = elem.Value; + } + hasElem = true; + } + } + if (!hasElem) + { + throw new DryadLinqException(HpcLinqErrorCode.MinNoElements, SR.MinNoElements); + } + return value; + } + + public static decimal? Min(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return System.Linq.Enumerable.Min(source); + } + + IEnumerable partialResults = pipe.Extend(s => AsEnumerable(s.Min()), false); + decimal? 
value = null; + foreach (var elem in partialResults) + { + if (value == null || elem < value) + { + value = elem; + } + } + return value; + } + else + { + return System.Linq.Enumerable.Min(source); + } + } + + public static AggregateValue Min(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + IEnumerable> partialResults; + if (pipe == null) + { + partialResults = source.PApply(s => AsEnumerable(MinInner(s)), false); + } + else + { + partialResults = pipe.Extend(s => AsEnumerable(MinInner(s)), false); + } + + IComparer comparer = TypeSystem.GetComparer(null); + TSource value = default(TSource); + long count = 0; + foreach (var elem in partialResults) + { + if (elem.Count > 0) + { + if (count == 0 || comparer.Compare(elem.Value, value) < 0) + { + value = elem.Value; + } + count = 1; + } + } + return new AggregateValue(value, count); + } + else + { + return MinInner(source); + } + } + + private static AggregateValue MinInner(IEnumerable source) + { + IComparer comparer = TypeSystem.GetComparer(null); + TSource value = default(TSource); + long count = 0; + + foreach (TSource x in source) + { + if (x != null) + { + if (count == 0 || comparer.Compare(x, value) < 0) + { + value = x; + count = 1; + } + } + } + return new AggregateValue(value, count); + } + + public static TSource Min(IEnumerable> source) + { + IComparer comparer = TypeSystem.GetComparer(null); + TSource value = default(TSource); + bool hasElem = false; + + foreach (var elem in source) + { + if (elem.Count > 0) + { + if (!hasElem || comparer.Compare(elem.Value, value) < 0) + { + value = elem.Value; + hasElem = true; + } + } + } + if (!hasElem) + { + throw new DryadLinqException(HpcLinqErrorCode.MinNoElements, SR.MinNoElements); + } + return value; + } + + public static AggregateValue + Min(IEnumerable source, Func selector) + { + return Min(Select(source, selector, false)); + } + + public static int? 
Min(IEnumerable source, + Func selector) + { + return Min(Select(source, selector, false)); + } + + public static AggregateValue Min(IEnumerable source, + Func selector) + { + return Min(Select(source, selector, false)); + } + + public static long? Min(IEnumerable source, + Func selector) + { + return Min(Select(source, selector, false)); + } + + public static AggregateValue Min(IEnumerable source, + Func selector) + { + return Min(Select(source, selector, false)); + } + + public static float? Min(IEnumerable source, + Func selector) + { + return Min(Select(source, selector, false)); + } + + public static AggregateValue Min(IEnumerable source, + Func selector) + { + return Min(Select(source, selector, false)); + } + + public static double? Min(IEnumerable source, + Func selector) + { + return Min(Select(source, selector, false)); + } + + public static AggregateValue Min(IEnumerable source, + Func selector) + { + return Min(Select(source, selector, false)); + } + + public static decimal? Min(IEnumerable source, + Func selector) + { + return Min(Select(source, selector, false)); + } + + public static AggregateValue Min(IEnumerable source, + Func selector) + { + return Min(Select(source, selector, false)); + } + + public static int MinAccumulate(int a, int x) + { + return (x < a) ? x : a; + } + + public static int? MinAccumulate(int? a, int? x) + { + return (a == null || x < a) ? x : a; + } + + public static long MinAccumulate(long a, long x) + { + return (x < a) ? x : a; + } + + public static long? MinAccumulate(long? a, long? x) + { + return (a == null || x < a) ? x : a; + } + + public static float MinAccumulate(float a, float x) + { + return (x < a) ? x : a; + } + + public static float? MinAccumulate(float? a, float? x) + { + return (a == null || x < a) ? x : a; + } + + public static double MinAccumulate(double a, double x) + { + return (x < a) ? x : a; + } + + public static double? MinAccumulate(double? a, double? x) + { + return (a == null || x < a) ? 
x : a; + } + + public static decimal MinAccumulate(decimal a, decimal x) + { + return (x < a) ? x : a; + } + + public static decimal? MinAccumulate(decimal? a, decimal? x) + { + return (a == null || x < a) ? x : a; + } + + public static TSource MinAccumulateGeneric(TSource a, TSource x) + { + return (x != null && (a == null || Comparer.Default.Compare(x, a) < 0)) ? x : a; + } + + // Operator: Max + public static AggregateValue Max(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return MaxInner(source); + } + + IEnumerable> partialResults = pipe.Extend(s => AsEnumerable(MaxInner(s)), false); + int value = Int32.MinValue; + long count = 0; + foreach (var elem in partialResults) + { + if (elem.Value > value) value = elem.Value; + count = 1; + } + return new AggregateValue(value, count); + } + else + { + return MaxInner(source); + } + } + + private static AggregateValue MaxInner(IEnumerable source) + { + int value = Int32.MinValue; + long count = 0; + + foreach (int elem in source) + { + if (elem > value) value = elem; + count = 1; + } + return new AggregateValue(value, count); + } + + public static int Max(IEnumerable> source) + { + int value = Int32.MinValue; + bool hasElem = false; + foreach (var elem in source) + { + if (elem.Count > 0) + { + if (elem.Value > value) value = elem.Value; + hasElem = true; + } + } + if (!hasElem) + { + throw new DryadLinqException(HpcLinqErrorCode.MaxNoElements, SR.MaxNoElements); + } + return value; + } + + public static int? Max(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return System.Linq.Enumerable.Max(source); + } + + IEnumerable partialResults = pipe.Extend(s => AsEnumerable(s.Max()), false); + int? 
value = null; + foreach (var elem in partialResults) + { + if (value == null || elem > value) + { + value = elem; + } + } + return value; + } + else + { + return System.Linq.Enumerable.Max(source); + } + } + + public static AggregateValue Max(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return MaxInner(source); + } + + IEnumerable> partialResults = pipe.Extend(s => AsEnumerable(MaxInner(s)), false); + long value = Int64.MinValue; + long count = 0; + foreach (var elem in partialResults) + { + if (elem.Value > value) value = elem.Value; + count = 1; + } + return new AggregateValue(value, count); + } + else + { + return MaxInner(source); + } + } + + private static AggregateValue MaxInner(IEnumerable source) + { + long value = Int64.MinValue; + long count = 0; + + foreach (long elem in source) + { + if (elem > value) value = elem; + count = 1; + } + return new AggregateValue(value, count); + } + + public static long Max(IEnumerable> source) + { + long value = Int64.MinValue; + bool hasElem = false; + foreach (var elem in source) + { + if (elem.Count > 0) + { + if (elem.Value > value) value = elem.Value; + hasElem = true; + } + } + if (!hasElem) + { + throw new DryadLinqException(HpcLinqErrorCode.MaxNoElements, SR.MaxNoElements); + } + return value; + } + + public static long? Max(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return System.Linq.Enumerable.Max(source); + } + + IEnumerable partialResults = pipe.Extend(s => AsEnumerable(s.Max()), false); + long? 
value = null; + foreach (var elem in partialResults) + { + if (value == null || elem > value) + { + value = elem; + } + } + return value; + } + else + { + return System.Linq.Enumerable.Max(source); + } + } + + public static AggregateValue Max(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return MaxInner(source); + } + + IEnumerable> partialResults = pipe.Extend(s => AsEnumerable(MaxInner(s)), false); + double value = Double.MinValue; + long count = 0; + foreach (var elem in partialResults) + { + if (elem.Value > value) value = elem.Value; + count = 1; + } + return new AggregateValue(value, count); + } + else + { + return MaxInner(source); + } + } + + private static AggregateValue MaxInner(IEnumerable source) + { + double value = Double.MinValue; + long count = 0; + + foreach (double elem in source) + { + if (elem > value) value = elem; + count = 1; + } + return new AggregateValue(value, count); + } + + public static double Max(IEnumerable> source) + { + double value = Double.MinValue; + bool hasElem = false; + foreach (var elem in source) + { + if (elem.Count > 0) + { + if (elem.Value > value) + { + value = elem.Value; + } + hasElem = true; + } + } + if (!hasElem) + { + throw new DryadLinqException(HpcLinqErrorCode.MaxNoElements, SR.MaxNoElements); + } + return value; + } + + public static double? Max(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return System.Linq.Enumerable.Max(source); + } + + IEnumerable partialResults = pipe.Extend(s => AsEnumerable(s.Max()), false); + double? 
value = null; + foreach (var elem in partialResults) + { + if (value == null || elem > value) + { + value = elem; + } + } + return value; + } + else + { + return System.Linq.Enumerable.Max(source); + } + } + + public static AggregateValue Max(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return MaxInner(source); + } + + IEnumerable> partialResults = pipe.Extend(s => AsEnumerable(MaxInner(s)), false); + float value = System.Single.MinValue; + long count = 0; + foreach (var elem in partialResults) + { + if (elem.Value > value) value = elem.Value; + count = 1; + } + return new AggregateValue(value, count); + } + else + { + return MaxInner(source); + } + } + + private static AggregateValue MaxInner(IEnumerable source) + { + float value = System.Single.MinValue; + long count = 0; + + foreach (float x in source) + { + if (x > value) value = x; + count = 1; + } + return new AggregateValue(value, count); + } + + public static float Max(IEnumerable> source) + { + float value = System.Single.MinValue; + bool hasElem = false; + foreach (var elem in source) + { + if (elem.Count > 0) + { + if (elem.Value > value) + { + value = elem.Value; + } + hasElem = true; + } + } + if (!hasElem) + { + throw new DryadLinqException(HpcLinqErrorCode.MaxNoElements, SR.MaxNoElements); + } + return value; + } + + public static float? Max(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return System.Linq.Enumerable.Max(source); + } + + IEnumerable partialResults = pipe.Extend(s => AsEnumerable(s.Max()), false); + float? 
value = null; + foreach (var elem in partialResults) + { + if (value == null || elem > value) + { + value = elem; + } + } + return value; + } + else + { + return System.Linq.Enumerable.Max(source); + } + } + + public static AggregateValue Max(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return MaxInner(source); + } + + IEnumerable> partialResults = pipe.Extend(s => AsEnumerable(MaxInner(s)), false); + decimal value = Decimal.MinValue; + long count = 0; + foreach (var elem in partialResults) + { + if (elem.Value > value) value = elem.Value; + count = 1; + } + return new AggregateValue(value, count); + } + else + { + return MaxInner(source); + } + } + + private static AggregateValue MaxInner(IEnumerable source) + { + decimal value = Decimal.MinValue; + long count = 0; + + foreach (decimal x in source) + { + if (x > value) value = x; + count = 1; + } + return new AggregateValue(value, count); + } + + public static decimal Max(IEnumerable> source) + { + decimal value = Decimal.MinValue; + bool hasElem = false; + foreach (var elem in source) + { + if (elem.Count > 0) + { + if (elem.Value > value) value = elem.Value; + hasElem = true; + } + } + if (!hasElem) + { + throw new DryadLinqException(HpcLinqErrorCode.MaxNoElements, SR.MaxNoElements); + } + return value; + } + + public static decimal? Max(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return System.Linq.Enumerable.Max(source); + } + + IEnumerable partialResults = pipe.Extend(s => AsEnumerable(s.Max()), false); + decimal? 
value = null; + foreach (var elem in partialResults) + { + if (value == null || elem > value) + { + value = elem; + } + } + return value; + } + else + { + return System.Linq.Enumerable.Max(source); + } + } + + public static AggregateValue Max(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + IEnumerable> partialResults; + if (pipe == null) + { + partialResults = source.PApply(s => AsEnumerable(MaxInner(s)), false); + } + else + { + partialResults = pipe.Extend(s => AsEnumerable(MaxInner(s)), false); + } + + IComparer comparer = TypeSystem.GetComparer(null); + TSource value = default(TSource); + long count = 0; + foreach (var elem in partialResults) + { + if (elem.Count > 0) + { + if (count == 0 || comparer.Compare(elem.Value, value) > 0) + { + value = elem.Value; + } + count = 1; + } + } + return new AggregateValue(value, count); + } + else + { + return MaxInner(source); + } + } + + private static AggregateValue MaxInner(IEnumerable source) + { + IComparer comparer = TypeSystem.GetComparer(null); + TSource value = default(TSource); + long count = 0; + + foreach (TSource x in source) + { + if (x != null) + { + if (count == 0 || comparer.Compare(x, value) > 0) + { + value = x; + count = 1; + } + } + } + return new AggregateValue(value, count); + } + + public static TSource Max(IEnumerable> source) + { + IComparer comparer = TypeSystem.GetComparer(null); + TSource value = default(TSource); + bool hasElem = false; + + foreach (var elem in source) + { + if (elem.Count > 0) + { + if (!hasElem || comparer.Compare(elem.Value, value) > 0) + { + value = elem.Value; + hasElem = true; + } + } + } + if (!hasElem) + { + throw new DryadLinqException(HpcLinqErrorCode.MaxNoElements, SR.MaxNoElements); + } + return value; + } + + public static AggregateValue Max(IEnumerable source, + Func selector) + { + return Max(Select(source, selector, false)); + } + + public static int? 
Max(IEnumerable source, + Func selector) + { + return Max(Select(source, selector, false)); + } + + public static AggregateValue Max(IEnumerable source, + Func selector) + { + return Max(Select(source, selector, false)); + } + + public static long? Max(IEnumerable source, + Func selector) + { + return Max(Select(source, selector, false)); + } + + public static AggregateValue Max(IEnumerable source, + Func selector) + { + return Max(Select(source, selector, false)); + } + + public static float? Max(IEnumerable source, + Func selector) + { + return Max(Select(source, selector, false)); + } + + public static AggregateValue Max(IEnumerable source, + Func selector) + { + return Max(Select(source, selector, false)); + } + + public static double? Max(IEnumerable source, + Func selector) + { + return Max(Select(source, selector, false)); + } + + public static AggregateValue Max(IEnumerable source, + Func selector) + { + return Max(Select(source, selector, false)); + } + + public static decimal? Max(IEnumerable source, + Func selector) + { + return Max(Select(source, selector, false)); + } + + public static AggregateValue Max(IEnumerable source, + Func selector) + { + return Max(Select(source, selector, false)); + } + + public static int? MaxAccumulate(int? a, int? x) + { + return (a == null || x > a) ? x : a; + } + + public static int MaxAccumulate(int a, int x) + { + return (x > a) ? x : a; + } + + public static long MaxAccumulate(long a, long x) + { + return (x > a) ? x : a; + } + + public static long? MaxAccumulate(long? a, long? x) + { + return (a == null || x > a) ? x : a; + } + + public static float MaxAccumulate(float a, float x) + { + return (x > a) ? x : a; + } + + public static float? MaxAccumulate(float? a, float? x) + { + return (a == null || x > a) ? x : a; + } + + public static double MaxAccumulate(double a, double x) + { + return (x > a) ? x : a; + } + + public static double? MaxAccumulate(double? a, double? x) + { + return (a == null || x > a) ? 
x : a; + } + + public static decimal MaxAccumulate(decimal a, decimal x) + { + return (x > a) ? x : a; + } + + public static decimal? MaxAccumulate(decimal? a, decimal? x) + { + return (a == null || x > a) ? x : a; + } + + public static TSource MaxAccumulateGeneric(TSource a, TSource x) + { + return (x != null && (a == null || Comparer.Default.Compare(x, a) > 0)) ? x : a; + } + + // Operator: Average + public static AggregateValue Average(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return AverageInner(source); + } + + IEnumerable> partialResults = pipe.Extend(s => AsEnumerable(AverageInner(s)), false); + long sum = 0; + long count = 0; + foreach (var elem in partialResults) + { + checked { + sum += elem.Value; + count += elem.Count; + } + } + return new AggregateValue(sum, count); + } + else + { + return AverageInner(source); + } + } + + private static AggregateValue AverageInner(IEnumerable source) + { + long sum = 0; + long count = 0; + foreach (int x in source) + { + checked { + sum += x; + count++; + } + } + return new AggregateValue(sum, count); + } + + public static double Average(IEnumerable> source) + { + long sum = 0; + long count = 0; + foreach (var x in source) + { + checked { + sum += x.Value; + count += x.Count; + } + } + + if (count == 0) + { + throw new DryadLinqException(HpcLinqErrorCode.AverageNoElements, SR.AverageNoElements); + } + return (double)sum / count; + } + + // Nullable-element Average: returns the (sum, count) aggregate for source. + // When source is an IParallelPipeline the partial aggregates obtained via pipe.Extend are merged; + // otherwise AverageInner runs directly. The null test used to be inverted (pipe != null), which sent + // real pipelines down the sequential path and dereferenced a null pipe for every other source. + public static AggregateValue Average(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + // Not a pipeline: there is nothing to extend, so aggregate sequentially. + return AverageInner(source); + } + + IEnumerable> partialResults = pipe.Extend(s => AsEnumerable(AverageInner(s)), false); + long sum = 0; + long count = 0; + foreach (var elem in partialResults) + { + checked { + sum += elem.Value.GetValueOrDefault(); + count += elem.Count; + } + } + return new AggregateValue(sum, count); + } 
+ else + { + return AverageInner(source); + } + } + + private static AggregateValue AverageInner(IEnumerable source) + { + long sum = 0; + long count = 0; + foreach (int? x in source) + { + if (x != null) + { + checked { + sum += x.GetValueOrDefault(); + count++; + } + } + } + return new AggregateValue(sum, count); + } + + public static double? Average(IEnumerable> source) + { + long sum = 0; + long count = 0; + foreach (var x in source) + { + checked { + sum += x.Value.GetValueOrDefault(); + count += x.Count; + } + } + + if (count == 0) return null; + return (double)sum / count; + } + + public static AggregateValue Average(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return AverageInner(source); + } + + IEnumerable> partialResults = pipe.Extend(s => AsEnumerable(AverageInner(s)), false); + long sum = 0; + long count = 0; + foreach (var elem in partialResults) + { + checked { + sum += elem.Value; + count += elem.Count; + } + } + return new AggregateValue(sum, count); + } + else + { + return AverageInner(source); + } + } + + private static AggregateValue AverageInner(IEnumerable source) + { + long sum = 0; + long count = 0; + foreach (long x in source) + { + checked { + sum += x; + count++; + } + } + return new AggregateValue(sum, count); + } + + // Nullable-element Average: returns the (sum, count) aggregate for source. + // When source is an IParallelPipeline the partial aggregates obtained via pipe.Extend are merged; + // otherwise AverageInner runs directly. The null test used to be inverted (pipe != null), which sent + // real pipelines down the sequential path and dereferenced a null pipe for every other source. + public static AggregateValue Average(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + // Not a pipeline: there is nothing to extend, so aggregate sequentially. + return AverageInner(source); + } + + IEnumerable> partialResults = pipe.Extend(s => AsEnumerable(AverageInner(s)), false); + long sum = 0; + long count = 0; + foreach (var elem in partialResults) + { + checked { + sum += elem.Value.GetValueOrDefault(); + count += elem.Count; + } + } + return new AggregateValue(sum, count); + } + else + { + return AverageInner(source); + } + } + + private static AggregateValue AverageInner(IEnumerable source) + { + long sum = 0; + long 
count = 0; + foreach (long? x in source) + { + if (x != null) + { + checked { + sum += x.GetValueOrDefault(); + count++; + } + } + } + return new AggregateValue(sum, count); + } + + public static AggregateValue Average(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return AverageInner(source); + } + + IEnumerable> partialResults = pipe.Extend(s => AsEnumerable(AverageInner(s)), false); + double sum = 0; + long count = 0; + foreach (var elem in partialResults) + { + sum += elem.Value; + checked { count += elem.Count; } + } + return new AggregateValue(sum, count); + } + else + { + return AverageInner(source); + } + } + + private static AggregateValue AverageInner(IEnumerable source) + { + double sum = 0; + long count = 0; + foreach (float v in source) + { + sum += v; + checked { count++; } + } + return new AggregateValue(sum, count); + } + + // Nullable-element Average: returns the (sum, count) aggregate for source. + // When source is an IParallelPipeline the partial aggregates obtained via pipe.Extend are merged; + // otherwise AverageInner runs directly. The null test used to be inverted (pipe != null), which sent + // real pipelines down the sequential path and dereferenced a null pipe for every other source. + public static AggregateValue Average(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + // Not a pipeline: there is nothing to extend, so aggregate sequentially. + return AverageInner(source); + } + + IEnumerable> partialResults = pipe.Extend(s => AsEnumerable(AverageInner(s)), false); + double sum = 0; + long count = 0; + foreach (var elem in partialResults) + { + // Partial results with a zero count contribute nothing to the sum. + if (elem.Count != 0) + { + sum += elem.Value.GetValueOrDefault(); + checked { count += elem.Count; } + } + } + return new AggregateValue(sum, count); + } + else + { + return AverageInner(source); + } + } + + private static AggregateValue AverageInner(IEnumerable source) + { + double sum = 0; + long count = 0; + foreach (float? 
x in source) + { + if (x != null) + { + sum += x.GetValueOrDefault(); + checked { count++; } + } + } + return new AggregateValue(sum, count); + } + + public static AggregateValue Average(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return AverageInner(source); + } + + IEnumerable> partialResults = pipe.Extend(s => AsEnumerable(AverageInner(s)), false); + double sum = 0; + long count = 0; + foreach (var elem in partialResults) + { + sum += elem.Value; + checked { count += elem.Count; } + } + return new AggregateValue(sum, count); + } + else + { + return AverageInner(source); + } + } + + private static AggregateValue AverageInner(IEnumerable source) + { + double sum = 0; + long count = 0; + foreach (double x in source) + { + sum += x; + checked { count++; } + } + return new AggregateValue(sum, count); + } + + public static double Average(IEnumerable> source) + { + double sum = 0; + long count = 0; + foreach (var x in source) + { + sum += x.Value; + checked { count += x.Count; } + } + + if (count == 0) + { + throw new DryadLinqException(HpcLinqErrorCode.AverageNoElements, SR.AverageNoElements); + } + return sum / count; + } + + // Nullable-element Average: returns the (sum, count) aggregate for source. + // When source is an IParallelPipeline the partial aggregates obtained via pipe.Extend are merged; + // otherwise AverageInner runs directly. The null test used to be inverted (pipe != null), which sent + // real pipelines down the sequential path and dereferenced a null pipe for every other source. + public static AggregateValue Average(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + // Not a pipeline: there is nothing to extend, so aggregate sequentially. + return AverageInner(source); + } + + IEnumerable> partialResults = pipe.Extend(s => AsEnumerable(AverageInner(s)), false); + double sum = 0; + long count = 0; + foreach (var elem in partialResults) + { + sum += elem.Value.GetValueOrDefault(); + checked { count += elem.Count; } + } + return new AggregateValue(sum, count); + } + else + { + return AverageInner(source); + } + } + + private static AggregateValue AverageInner(IEnumerable source) + { + double sum = 0; + long count = 0; + foreach (double? 
x in source) + { + if (x != null) + { + sum += x.GetValueOrDefault(); + checked { count++; } + } + } + return new AggregateValue(sum, count); + } + + public static double? Average(IEnumerable> source) + { + double sum = 0; + long count = 0; + foreach (var x in source) + { + sum += x.Value.GetValueOrDefault(); + checked { count += x.Count; } + } + + if (count == 0) return null; + return sum / count; + } + + public static AggregateValue Average(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) + { + return AverageInner(source); + } + + IEnumerable> partialResults = pipe.Extend(s => AsEnumerable(AverageInner(s)), false); + decimal sum = 0; + long count = 0; + foreach (var elem in partialResults) + { + sum += elem.Value; + checked { count += elem.Count; } + } + return new AggregateValue(sum, count); + } + else + { + return AverageInner(source); + } + } + + private static AggregateValue AverageInner(IEnumerable source) + { + decimal sum = 0; + long count = 0; + foreach (decimal x in source) + { + sum += x; + checked { count++; } + } + return new AggregateValue(sum, count); + } + + public static decimal Average(IEnumerable> source) + { + decimal sum = 0; + long count = 0; + foreach (var x in source) + { + sum += x.Value; + checked { count += x.Count; } + } + + if (count == 0) + { + throw new DryadLinqException(HpcLinqErrorCode.AverageNoElements, SR.AverageNoElements); + } + return sum / count; + } + + public static AggregateValue Average(IEnumerable source) + { + if (s_multiThreading) + { + IParallelPipeline pipe = source as IParallelPipeline; + if (pipe == null) // source is not a parallel pipeline: aggregate sequentially (was inverted, dereferencing a null pipe below) + { + return AverageInner(source); + } + + IEnumerable> partialResults = pipe.Extend(s => AsEnumerable(AverageInner(s)), false); + decimal sum = 0; + long count = 0; + foreach (var elem in partialResults) + { + sum += elem.Value.GetValueOrDefault(); + checked { count += elem.Count; } + } + return new AggregateValue(sum, count); + } +
else + { + return AverageInner(source); + } + } + + private static AggregateValue AverageInner(IEnumerable source) + { + decimal sum = 0; + long count = 0; + foreach (decimal? x in source) + { + if (x != null) + { + sum += x.GetValueOrDefault(); + checked { count++; } + } + } + return new AggregateValue(sum, count); + } + + public static decimal? Average(IEnumerable> source) + { + decimal sum = 0; + long count = 0; + foreach (var x in source) + { + sum += x.Value.GetValueOrDefault(); + checked { count += x.Count; } + } + + if (count == 0) return null; + return sum / count; + } + + public static AggregateValue Average(IEnumerable source, + Func selector) + { + return Average(Select(source, selector, false)); + } + + public static AggregateValue Average(IEnumerable source, + Func selector) + { + return Average(Select(source, selector, false)); + } + + public static AggregateValue Average(IEnumerable source, + Func selector) + { + return Average(Select(source, selector, false)); + } + + public static AggregateValue Average(IEnumerable source, + Func selector) + { + return Average(Select(source, selector, false)); + } + + public static AggregateValue Average(IEnumerable source, + Func selector) + { + return Average(Select(source, selector, false)); + } + + public static AggregateValue Average(IEnumerable source, + Func selector) + { + return Average(Select(source, selector, false)); + } + + public static AggregateValue Average(IEnumerable source, + Func selector) + { + return Average(Select(source, selector, false)); + } + + public static AggregateValue Average(IEnumerable source, + Func selector) + { + return Average(Select(source, selector, false)); + } + + public static AggregateValue Average(IEnumerable source, + Func selector) + { + return Average(Select(source, selector, false)); + } + + public static AggregateValue Average(IEnumerable source, + Func selector) + { + return Average(Select(source, selector, false)); + } + + public static AggregateValue 
AverageAccumulate(AggregateValue a, int x) + { + return new AggregateValue(checked(a.Value + x), checked(a.Count + 1)); + } + + public static AggregateValue AverageAccumulate(AggregateValue a, int? x) + { + if (x == null) return a; + return new AggregateValue(checked(a.Value + x.GetValueOrDefault()), + checked(a.Count + 1)); + } + + public static AggregateValue AverageAccumulate(AggregateValue a, long x) + { + return new AggregateValue(checked(a.Value + x), checked(a.Count + 1)); + } + + public static AggregateValue AverageAccumulate(AggregateValue a, long? x) + { + if (x == null) return a; + return new AggregateValue(checked(a.Value + x.GetValueOrDefault()), + checked(a.Count + 1)); + } + + public static AggregateValue AverageAccumulate(AggregateValue a, float x) + { + return new AggregateValue(a.Value + x, + checked(a.Count + 1)); + } + + public static AggregateValue AverageAccumulate(AggregateValue a, float? x) + { + if (x == null) return a; + return new AggregateValue(a.Value + x.GetValueOrDefault(), + checked(a.Count + 1)); + } + + public static AggregateValue AverageAccumulate(AggregateValue a, double x) + { + return new AggregateValue(a.Value + x, + checked(a.Count + 1)); + } + + public static AggregateValue AverageAccumulate(AggregateValue a, double? x) + { + if (x == null) return a; + return new AggregateValue(a.Value + x.GetValueOrDefault(), + checked(a.Count + 1)); + } + + public static AggregateValue AverageAccumulate(AggregateValue a, decimal x) + { + return new AggregateValue(a.Value + x, + checked(a.Count + 1)); + } + + public static AggregateValue AverageAccumulate(AggregateValue a, decimal? 
x) + { + if (x == null) return a; + return new AggregateValue(a.Value + x.GetValueOrDefault(), + checked(a.Count + 1)); + } + + // Operator: Any + public static bool Any(IEnumerable source, + Func predicate) + { + if (s_multiThreading) + { + var partialResults = source.ExtendParallelPipeline(x => AsEnumerable(x.Any(predicate)), false); + foreach (bool item in partialResults) + { + if (item) return true; + } + return false; + } + else + { + return Enumerable.Any(source, predicate); + } + } + + public static bool Any(IEnumerable source) + { + foreach (bool item in source) + { + if (item) return true; + } + return false; + } + + // Operator: All + public static bool All(IEnumerable source, + Func predicate) + { + if (s_multiThreading) + { + var partialResults = source.ExtendParallelPipeline(x => AsEnumerable(x.All(predicate)), false); + foreach (bool item in partialResults) + { + if (!item) return false; + } + return true; + } + else + { + return Enumerable.All(source, predicate); + } + } + + public static bool All(IEnumerable source) + { + foreach (bool item in source) + { + if (!item) return false; + } + return true; + } + + // Operator: Reverse + public static IEnumerable Reverse(IEnumerable source) + { + BigCollection buffer = new BigCollection(); + foreach (var elem in source) + { + buffer.Add(elem); + } + return buffer.Reverse(); + } + + // Operator: Merge + // When not pipelined, this should be implemented in native for better performance. 
+ public static IEnumerable Merge(IEnumerable source) + { + return source; + } + + // Operator: Apply + public static IEnumerable + Apply(IEnumerable source, + Func, IEnumerable> procFunc) + { + return procFunc(source); + } + + public static IEnumerable + Apply(IEnumerable source1, + IEnumerable source2, + Func, IEnumerable, IEnumerable> procFunc) + { + return procFunc(source1, source2); + } + + public static IEnumerable + Apply(IEnumerable[] sources, + Func[], IEnumerable> procFunc) + { + return procFunc(sources); + } + + public static IEnumerable + PApply(this IEnumerable source, + Func, IEnumerable> procFunc, + bool orderPreserving) + { + return source.ExtendParallelPipeline(procFunc, orderPreserving); + } + + // Operator: HashPartition + public static void + HashPartition(IEnumerable source, + IEqualityComparer comparer, + HpcVertexWriter sink) + { + if (comparer == null) + { + comparer = EqualityComparer.Default; + } + + Int32 numOfPorts = (Int32)sink.NumberOfOutputs; + foreach (TSource item in source) + { + Int32 hashCode = comparer.GetHashCode(item) & 0x7FFFFFFF; + Int32 portNum = hashCode % numOfPorts; + sink.WriteItem(item, portNum); + } + sink.CloseWriters(); + } + + public static void + HashPartition(IEnumerable source, + IEqualityComparer comparer, + Func resultSelector, + HpcVertexWriter sink) + { + if (comparer == null) + { + comparer = EqualityComparer.Default; + } + + Int32 numOfPorts = (Int32)sink.NumberOfOutputs; + foreach (TSource item in source) + { + Int32 hashCode = comparer.GetHashCode(item) & 0x7FFFFFFF; + Int32 portNum = hashCode % numOfPorts; + sink.WriteItem(resultSelector(item), portNum); + } + sink.CloseWriters(); + } + + public static void + HashPartition(IEnumerable source, + Func keySelector, + bool isExpensive, + IEqualityComparer comparer, + HpcVertexWriter sink) + { + if (s_multiThreading && isExpensive) + { + var source1 = source.ExtendParallelPipeline(s => s.Select(x => new Pair(keySelector(x), x)), true); + 
HashPartition(source1, x => x.Key, comparer, x => x.Value, sink); + } + else + { + HashPartition(source, keySelector, comparer, sink); + } + } + + private static void + HashPartition(IEnumerable source, + Func keySelector, + IEqualityComparer comparer, + HpcVertexWriter sink) + { + if (comparer == null) + { + comparer = EqualityComparer.Default; + } + + Int32 numOfPorts = (Int32)sink.NumberOfOutputs; + foreach (TSource item in source) + { + Int32 hashCode = comparer.GetHashCode(keySelector(item)) & 0x7FFFFFFF; + Int32 portNum = hashCode % numOfPorts; + sink.WriteItem(item, portNum); + } + sink.CloseWriters(); + } + + public static void + HashPartition(IEnumerable source, + Func keySelector, + bool isExpensive, + IEqualityComparer comparer, + Func resultSelector, + HpcVertexWriter sink) + { + if (s_multiThreading && isExpensive) + { + var source1 = source.ExtendParallelPipeline(s => s.Select(x => new Pair(keySelector(x), resultSelector(x))), true); + HashPartition(source1, x => x.Key, comparer, x => x.Value, sink); + } + else + { + HashPartition(source, keySelector, comparer, resultSelector, sink); + } + } + + private static void + HashPartition(IEnumerable source, + Func keySelector, + IEqualityComparer comparer, + Func resultSelector, + HpcVertexWriter sink) + { + if (comparer == null) + { + comparer = EqualityComparer.Default; + } + + Int32 numOfPorts = (Int32)sink.NumberOfOutputs; + foreach (TSource item in source) + { + Int32 hashCode = comparer.GetHashCode(keySelector(item)) & 0x7FFFFFFF; + Int32 portNum = hashCode % numOfPorts; + sink.WriteItem(resultSelector(item), portNum); + } + sink.CloseWriters(); + } + + // Operator: RangePartition + + // special case for keySelector=identityFunction + public static void + RangePartition(IEnumerable source, + IEnumerable partitionKeys, + IComparer comparer, + bool isDescending, + HpcVertexWriter sink) + { + if (partitionKeys == null) + { + throw new DryadLinqException(HpcLinqErrorCode.RangePartitionKeysMissing, + 
SR.RangePartitionKeysMissing); + } + comparer = TypeSystem.GetComparer(comparer); + Int32 numOfPorts = (Int32)sink.NumberOfOutputs; + TSource[] keys = new TSource[numOfPorts - 1]; + int idx = 0; + foreach (TSource key in partitionKeys) + { + if (idx < keys.Length) + { + keys[idx] = key; + } + idx++; + } + if (idx > keys.Length) + { + throw new DryadLinqException(HpcLinqErrorCode.RangePartitionInputOutputMismatch, + String.Format(SR.RangePartitionInputOutputMismatch, + idx, numOfPorts)); + } + if (idx < keys.Length) + { + TSource[] keys1 = new TSource[idx]; + Array.Copy(keys, keys1, idx); + keys = keys1; + } + foreach (TSource item in source) + { + int portNum = HpcLinqUtil.BinarySearch(keys, item, comparer, isDescending); + sink.WriteItem(item, portNum); + } + sink.CloseWriters(); + } + + // special case for keySelector=identityFunction and resultSelector + public static void + RangePartition(IEnumerable source, + IEnumerable partitionKeys, + IComparer comparer, + bool isDescending, + Func resultSelector, + HpcVertexWriter sink) + { + if (partitionKeys == null) + { + throw new DryadLinqException(HpcLinqErrorCode.RangePartitionKeysMissing, + SR.RangePartitionKeysMissing); + } + comparer = TypeSystem.GetComparer(comparer); + Int32 numOfPorts = (Int32)sink.NumberOfOutputs; + TSource[] keys = new TSource[numOfPorts - 1]; + int idx = 0; + foreach (TSource key in partitionKeys) + { + if (idx < keys.Length) + { + keys[idx] = key; + } + idx++; + } + if (idx > keys.Length) + { + throw new DryadLinqException(HpcLinqErrorCode.RangePartitionInputOutputMismatch, + String.Format(SR.RangePartitionInputOutputMismatch, + idx, numOfPorts)); + } + if (idx < keys.Length) + { + TSource[] keys1 = new TSource[idx]; + Array.Copy(keys, keys1, idx); + keys = keys1; + } + foreach (TSource item in source) + { + int portNum = HpcLinqUtil.BinarySearch(keys, item, comparer, isDescending); + sink.WriteItem(resultSelector(item), portNum); + } + sink.CloseWriters(); + } + + // general case + public 
static void + RangePartition(IEnumerable source, + Func keySelector, + bool isExpensive, + IEnumerable partitionKeys, + IComparer comparer, + bool isDescending, + HpcVertexWriter sink) + { + if (s_multiThreading && isExpensive) + { + var source1 = source.ExtendParallelPipeline(s => s.Select(x => new Pair(keySelector(x), x)), true); + RangePartition(source1, x => x.Key, partitionKeys, comparer, isDescending, x => x.Value, sink); + } + else + { + RangePartition(source, keySelector, partitionKeys, comparer, isDescending, sink); + } + } + + private static void + RangePartition(IEnumerable source, + Func keySelector, + IEnumerable partitionKeys, + IComparer comparer, + bool isDescending, + HpcVertexWriter sink) + { + if (partitionKeys == null) + { + throw new DryadLinqException(HpcLinqErrorCode.RangePartitionKeysMissing, + SR.RangePartitionKeysMissing); + } + comparer = TypeSystem.GetComparer(comparer); + + Int32 numOfPorts = (Int32)sink.NumberOfOutputs; + TKey[] keys = new TKey[numOfPorts - 1]; + int idx = 0; + foreach (TKey key in partitionKeys) + { + if (idx < keys.Length) + { + keys[idx] = key; + } + idx++; + } + if (idx > keys.Length) + { + throw new DryadLinqException(HpcLinqErrorCode.RangePartitionInputOutputMismatch, + String.Format(SR.RangePartitionInputOutputMismatch, + idx, numOfPorts)); + } + if (idx < keys.Length) + { + TKey[] keys1 = new TKey[idx]; + Array.Copy(keys, keys1, idx); + keys = keys1; + } + foreach (TSource item in source) + { + int portNum = HpcLinqUtil.BinarySearch(keys, keySelector(item), comparer, isDescending); + sink.WriteItem(item, portNum); + } + sink.CloseWriters(); + } + + // general case with result-selector + public static void + RangePartition(IEnumerable source, + Func keySelector, + bool isExpensive, + IEnumerable partitionKeys, + IComparer comparer, + bool isDescending, + Func resultSelector, + HpcVertexWriter sink) + { + if (s_multiThreading && isExpensive) + { + var source1 = source.ExtendParallelPipeline(s => s.Select(x => new 
Pair(keySelector(x), resultSelector(x))), true); + RangePartition(source1, x => x.Key, partitionKeys, comparer, isDescending, x => x.Value, sink); + } + else + { + RangePartition(source, keySelector, partitionKeys, comparer, isDescending, resultSelector, sink); + } + } + + private static void + RangePartition(IEnumerable source, + Func keySelector, + IEnumerable partitionKeys, + IComparer comparer, + bool isDescending, + Func resultSelector, + HpcVertexWriter sink) + { + if (partitionKeys == null) + { + throw new DryadLinqException(HpcLinqErrorCode.RangePartitionKeysMissing, + SR.RangePartitionKeysMissing); + } + comparer = TypeSystem.GetComparer(comparer); + + Int32 numOfPorts = (Int32)sink.NumberOfOutputs; + TKey[] keys = new TKey[numOfPorts - 1]; + int idx = 0; + foreach (TKey key in partitionKeys) + { + if (idx < keys.Length) + { + keys[idx] = key; + } + idx++; + } + if (idx > keys.Length) + { + throw new DryadLinqException(HpcLinqErrorCode.RangePartitionInputOutputMismatch, + String.Format(SR.RangePartitionInputOutputMismatch, + idx, numOfPorts)); + } + if (idx < keys.Length) + { + TKey[] keys1 = new TKey[idx]; + Array.Copy(keys, keys1, idx); + keys = keys1; + } + foreach (TSource item in source) + { + int portNum = HpcLinqUtil.BinarySearch(keys, keySelector(item), comparer, isDescending); + sink.WriteItem(resultSelector(item), portNum); + } + sink.CloseWriters(); + } + + // Operator: Fork + public static void Fork(IEnumerable source, + Func, IEnumerable>> mapper, + bool orderPreserving, + HpcVertexWriter sink1, + HpcVertexWriter sink2) + { + HpcRecordWriter writer1 = sink1.GetWriter(0); + HpcRecordWriter writer2 = sink2.GetWriter(0); + + IEnumerable> result = mapper(source); + foreach (ForkTuple val in result) + { + if (val.HasFirst) + { + writer1.WriteRecordAsync(val.First); + } + if (val.HasSecond) + { + writer2.WriteRecordAsync(val.Second); + } + } + sink1.CloseWriters(); + sink2.CloseWriters(); + } + + public static void Fork(IEnumerable source, + Func, 
IEnumerable>> mapper, + bool orderPreserving, + HpcVertexWriter sink1, + HpcVertexWriter sink2, + HpcVertexWriter sink3) + { + HpcRecordWriter writer1 = sink1.GetWriter(0); + HpcRecordWriter writer2 = sink2.GetWriter(0); + HpcRecordWriter writer3 = sink3.GetWriter(0); + + IEnumerable> result = mapper(source); + foreach (ForkTuple val in result) + { + if (val.HasFirst) + { + writer1.WriteRecordAsync(val.First); + } + if (val.HasSecond) + { + writer2.WriteRecordAsync(val.Second); + } + if (val.HasThird) + { + writer3.WriteRecordAsync(val.Third); + } + } + sink1.CloseWriters(); + sink2.CloseWriters(); + sink3.CloseWriters(); + } + + public static void Fork(IEnumerable source, + Func> mapper, + bool orderPreserving, + HpcVertexWriter sink1, + HpcVertexWriter sink2) + { + HpcRecordWriter writer1 = sink1.GetWriter(0); + HpcRecordWriter writer2 = sink2.GetWriter(0); + + IEnumerable> + result = source.ExtendParallelPipeline(s => s.Select(mapper), orderPreserving); + + foreach (ForkTuple val in result) + { + if (val.HasFirst) + { + writer1.WriteRecordAsync(val.First); + } + if (val.HasSecond) + { + writer2.WriteRecordAsync(val.Second); + } + } + sink1.CloseWriters(); + sink2.CloseWriters(); + } + + public static void Fork(IEnumerable source, + Func> mapper, + bool orderPreserving, + HpcVertexWriter sink1, + HpcVertexWriter sink2, + HpcVertexWriter sink3) + { + HpcRecordWriter writer1 = sink1.GetWriter(0); + HpcRecordWriter writer2 = sink2.GetWriter(0); + HpcRecordWriter writer3 = sink3.GetWriter(0); + + IEnumerable> + result = source.ExtendParallelPipeline(s => s.Select(mapper), orderPreserving); + + foreach (ForkTuple val in result) + { + if (val.HasFirst) + { + writer1.WriteRecordAsync(val.First); + } + if (val.HasSecond) + { + writer2.WriteRecordAsync(val.Second); + } + if (val.HasThird) + { + writer3.WriteRecordAsync(val.Third); + } + } + sink1.CloseWriters(); + sink2.CloseWriters(); + sink3.CloseWriters(); + } + + public static void Fork(IEnumerable source, + Func 
keySelector, + K[] keys, + bool orderPreserving, + params HpcVertexWriter[] sinks) + { + if (keys.Length != sinks.Length) + { + throw new DryadLinqException(SR.NumberOfKeysMustEqualNumOutputPorts); + } + Dictionary keyMap = new Dictionary(keys.Length); + for (int i = 0; i < keys.Length; i++) + { + keyMap.Add(keys[i], i); + } + HpcRecordWriter[] writers = new HpcRecordWriter[keys.Length]; + for (int i = 0; i < writers.Length; i++) + { + writers[i] = sinks[i].GetWriter(0); + } + foreach (T item in source) + { + int portNum; + if (keyMap.TryGetValue(keySelector(item), out portNum)) + { + writers[portNum].WriteRecordAsync(item); + } + } + foreach (var sink in sinks) + { + sink.CloseWriters(); + } + } + + public static IEnumerable AsEnumerable(T value) + { + yield return value; + } + } + + public struct AggregateValue + { + private T _value; + private long _count; + + public AggregateValue(T val, long count) + { + _value = val; + _count = count; + } + + public T Value + { + get { return _value; } + set { _value = value; } + } + + public long Count + { + get { return _count; } + set { _count = value; } + } + + public override string ToString() + { + return "<" + Value + ", " + Count + ">"; + } + } + + internal class ParallelHashGroupBy : IParallelPipeline + { + private IEnumerable m_source; + private Func m_keySelector; + private Func m_elementSelector; + private Func, TResult> m_resultSelector; + private IEqualityComparer m_comparer; + private Func, IEnumerable> m_applyFunc; + + public ParallelHashGroupBy(IEnumerable source, + Func keySelector, + Func elementSelector, + Func, TResult> resultSelector, + IEqualityComparer comparer, + Func, IEnumerable> applyFunc) + { + if (resultSelector == null || elementSelector == null) + { + throw new DryadLinqException("Internal error: The accumulator and element selector can't be null"); + } + this.m_source = source; + this.m_keySelector = keySelector; + this.m_elementSelector = elementSelector; + this.m_resultSelector = 
resultSelector; + this.m_applyFunc = applyFunc; + this.m_comparer = comparer; + if (this.m_comparer == null) + { + this.m_comparer = EqualityComparer.Default; + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public IEnumerator GetEnumerator() + { + return new InnerEnumerator(this); + } + + public IParallelPipeline + Extend(Func, IEnumerable> func, bool orderPreserving) + { + if (this.m_applyFunc == null) + { + var applyFunc = (Func, IEnumerable>)(object)func; + return new ParallelHashGroupBy( + this.m_source, this.m_keySelector, this.m_elementSelector, + this.m_resultSelector, this.m_comparer, applyFunc); + } + else + { + return new ParallelHashGroupBy( + this.m_source, this.m_keySelector, this.m_elementSelector, + this.m_resultSelector, this.m_comparer, s => func(this.m_applyFunc(s))); + } + } + + private class InnerEnumerator : IEnumerator + { + private const Int32 HashSetCapacity = 16411; + private const Int32 BufferSize = 2048; + + private IEnumerable m_source; + private Func m_keySelector; + private Func m_elementSelector; + private Func, TResult> m_resultSelector; + private IEqualityComparer m_comparer; + private Func, IEnumerable> m_applyFunc; + private Thread m_mainWorker; + private Task[] m_workers; + private BlockingCollection[] m_queues; + private BlockingCollection m_resultSet; + private BlockingCollection m_stealingQueue; + private int m_stealingWorkerCnt; + private bool m_isDone; + private Exception m_workerException; + + private IEnumerator m_resultEnum; + private TFinal[] m_currentItems; + private int m_index; + private bool m_disposed; + + public InnerEnumerator(ParallelHashGroupBy parent) + { + this.m_source = parent.m_source; + this.m_keySelector = parent.m_keySelector; + this.m_elementSelector = parent.m_elementSelector; + this.m_resultSelector = parent.m_resultSelector; + this.m_comparer = parent.m_comparer; + this.m_applyFunc = parent.m_applyFunc; + + this.m_isDone = false; + 
this.m_stealingWorkerCnt = 0; + this.m_stealingQueue = new BlockingCollection(); + this.m_workerException = null; + this.m_workers = new Task[Environment.ProcessorCount]; + this.m_queues = new BlockingCollection[this.m_workers.Length]; + for (int i = 0; i < this.m_queues.Length; i++) + { + this.m_queues[i] = new BlockingCollection(2); + } + this.m_resultSet = new BlockingCollection(4); + + for (int i = 0; i < this.m_workers.Length; i++) + { + this.m_workers[i] = this.CreateTask(i); + } + this.m_mainWorker = new Thread(this.ProcessAllItems); + this.m_mainWorker.Name = "HGB.ProcessAllItem"; + this.m_mainWorker.Start(); + + this.m_resultEnum = this.m_resultSet.GetConsumingEnumerable().GetEnumerator(); + this.m_currentItems = new TFinal[0]; + this.m_index = -1; + this.m_disposed = false; + } + + ~InnerEnumerator() + { + this.Dispose(false); + } + + public bool MoveNext() + { + this.m_index++; + if (this.m_index < this.m_currentItems.Length) + { + return true; + } + + while (this.m_resultEnum.MoveNext()) + { + if (this.m_workerException != null) break; + + this.m_currentItems = this.m_resultEnum.Current; + if (this.m_currentItems.Length > 0) + { + this.m_index = 0; + return true; + } + } + if (this.m_workerException != null) + { + throw new DryadLinqException("Failed while enumerating.", this.m_workerException); + } + return false; + } + + public TFinal Current + { + get { return this.m_currentItems[this.m_index]; } + } + + object IEnumerator.Current + { + get { return this.m_currentItems[this.m_index]; } + } + + public void Reset() + { + throw new DryadLinqException("Internal error: Cannot reset this IEnumerator."); + } + + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (!this.m_disposed) + { + this.m_disposed = true; + this.m_isDone = true; + + // Always drain the queues + foreach (var queue in this.m_queues) + { + foreach (var item in queue.GetConsumingEnumerable()) + { + } + } + + // 
Always drain the result queue + while (this.m_resultEnum.MoveNext()) + { + } + } + } + + private void ProcessAllItems() + { + try + { + DryadLinqLog.Add("Parallel GroupBy (Hash) started reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + // Read all the items + TSource[][] buffers = new TSource[this.m_workers.Length][]; + Int32[] counts = new Int32[this.m_workers.Length]; + for (int i = 0; i < buffers.Length; i++) + { + buffers[i] = new TSource[BufferSize]; + counts[i] = 0; + } + foreach (TSource item in this.m_source) + { + Int32 hashCode = this.m_comparer.GetHashCode(m_keySelector(item)); + Int32 idx = HpcLinqUtil.GetTaskIndex(hashCode, counts.Length); + if (counts[idx] == BufferSize) + { + if (this.m_isDone) break; + + this.m_queues[idx].Add(buffers[idx]); + buffers[idx] = new TSource[BufferSize]; + counts[idx] = 0; + } + buffers[idx][counts[idx]] = item; + counts[idx]++; + } + + // Add the final buffers to the queues and declare adding is complete + for (int i = 0; i < counts.Length; i++) + { + if (!this.m_isDone && counts[i] > 0) + { + TSource[] buffer = new TSource[counts[i]]; + Array.Copy(buffers[i], buffer, counts[i]); + this.m_queues[i].Add(buffer); + } + this.m_queues[i].CompleteAdding(); + } + + DryadLinqLog.Add("Parallel GroupBy (Hash) ended reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + // Wait for all the workers to complete + Task.WaitAll(this.m_workers); + } + catch (Exception e) + { + this.m_isDone = true; + this.m_workerException = e; + } + finally + { + for (int i = 0; i < this.m_queues.Length; i++) + { + this.m_queues[i].CompleteAdding(); + } + this.m_resultSet.CompleteAdding(); + } + } + + private Task CreateTask(Int32 idx) + { + return Task.Factory.StartNew(delegate { this.HashGroupBy(idx); }, TaskCreationOptions.LongRunning); + } + + private void HashGroupBy(Int32 idx) + { + try + { + DryadLinqLog.Add("Parallel GroupBy (Hash) worker {0} started at {1}", + idx, DateTime.Now.ToString("MM/dd/yyyy 
HH:mm:ss.fff")); + + BlockingCollection queue = this.m_queues[idx]; + Int32 wlen = this.m_workers.Length; + GroupingHashSet + groups = new GroupingHashSet(this.m_comparer, HashSetCapacity); + TResult[] resultBuffer = new TResult[BufferSize]; + Int32 count = 0; + Int32 myWorkCnt = 0; + Int32 otherWorkCnt = 0; + Int32 stealingWorkerCnt = this.m_stealingWorkerCnt; + + foreach (var buffer in queue.GetConsumingEnumerable()) + { + if (this.m_isDone) break; + + for (int i = 0; i < buffer.Length; i++) + { + TSource item = buffer[i]; + groups.AddItem(this.m_keySelector(item), this.m_elementSelector(item)); + } + } + + foreach (IGrouping g in groups) + { + if (count == BufferSize) + { + if (this.m_isDone) break; + + if (this.m_applyFunc == null) + { + this.m_resultSet.Add((TFinal[])(object)resultBuffer); + resultBuffer = new TResult[BufferSize]; + } + else + { + int newStealingWorkerCnt = this.m_stealingWorkerCnt; + if (newStealingWorkerCnt > stealingWorkerCnt) + { + myWorkCnt = 0; + otherWorkCnt = 0; + stealingWorkerCnt = newStealingWorkerCnt; + } + if ((stealingWorkerCnt * myWorkCnt) <= ((wlen - stealingWorkerCnt) * otherWorkCnt)) + { + this.m_resultSet.Add(this.m_applyFunc(resultBuffer).ToArray()); + myWorkCnt++; + } + else + { + this.m_stealingQueue.Add(resultBuffer); + resultBuffer = new TResult[BufferSize]; + otherWorkCnt++; + } + } + count = 0; + } + resultBuffer[count++] = this.m_resultSelector(g.Key, g); + } + + // Add the last buffer + if (!this.m_isDone && count > 0) + { + TResult[] lastResultBuffer = new TResult[count]; + Array.Copy(resultBuffer, lastResultBuffer, count); + resultBuffer = null; + if (this.m_applyFunc == null) + { + this.m_resultSet.Add((TFinal[])(object)lastResultBuffer); + } + else + { + this.m_resultSet.Add(this.m_applyFunc(lastResultBuffer).ToArray()); + } + } + + // Stealing work + if (this.m_applyFunc != null) + { + stealingWorkerCnt = Interlocked.Increment(ref this.m_stealingWorkerCnt); + if (stealingWorkerCnt == wlen) + { + 
this.m_stealingQueue.CompleteAdding(); + } + foreach (TResult[] buffer in this.m_stealingQueue.GetConsumingEnumerable()) + { + if (this.m_isDone) return; + this.m_resultSet.Add(this.m_applyFunc(buffer).ToArray()); + } + } + } + catch (Exception e) + { + this.m_isDone = true; + this.m_workerException = new DryadLinqException(HpcLinqErrorCode.FailureInHashGroupBy, + SR.FailureInHashGroupBy, e); + this.m_resultSet.Add(new TFinal[0]); + this.m_stealingQueue.CompleteAdding(); + } + finally + { + DryadLinqLog.Add("Parallel GroupBy (Hash) worker {0} ended at {1}", + idx, DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + } + } + } + } + + internal class ParallelHashGroupByPartialAccumulate + : IEnumerable> + { + private IEnumerable m_source; + private Func, IEnumerable> m_preApply; + private Func m_keySelector; + private Func m_elementSelector; + private Func m_seed; + private Func m_accumulator; + private IEqualityComparer m_comparer; + + public ParallelHashGroupByPartialAccumulate(IEnumerable source, + Func, IEnumerable> preApply, + Func keySelector, + Func elementSelector, + Func seed, + Func accumulator, + IEqualityComparer comparer) + { + if (seed == null || accumulator == null || elementSelector == null) + { + throw new DryadLinqException("Internal error: The accumulator and element selector can't be null"); + } + this.m_source = source; + this.m_preApply = preApply; + this.m_keySelector = keySelector; + this.m_elementSelector = elementSelector; + this.m_seed = seed; + this.m_accumulator = accumulator; + this.m_comparer = comparer; + if (this.m_comparer == null) + { + this.m_comparer = EqualityComparer.Default; + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public IEnumerator> GetEnumerator() + { + return new InnerEnumerator(this); + } + + private class InnerEnumerator : IEnumerator> + { + private const Int32 HashSetCapacity = 16411; + private const Int32 BufferSize = 2048; + + private IEnumerable m_source; + private 
Func, IEnumerable> m_preApply; + private Func m_keySelector; + private Func m_elementSelector; + private Func m_seed; + private Func m_accumulator; + private IEqualityComparer m_comparer; + private Thread m_mainWorker; + private Task[] m_workers; + private BlockingCollection[] m_queues; + private BlockingCollection[]> m_resultSet; + private bool m_isDone; + private Exception m_workerException; + + private IEnumerator[]> m_resultEnum; + private Pair[] m_currentItems; + private int m_index; + private bool m_disposed; + + public InnerEnumerator(ParallelHashGroupByPartialAccumulate parent) + { + this.m_source = parent.m_source; + this.m_preApply = parent.m_preApply; + this.m_keySelector = parent.m_keySelector; + this.m_elementSelector = parent.m_elementSelector; + this.m_seed = parent.m_seed; + this.m_accumulator = parent.m_accumulator; + this.m_comparer = parent.m_comparer; + + this.m_isDone = false; + this.m_workerException = null; + this.m_workers = new Task[Environment.ProcessorCount]; + this.m_queues = new BlockingCollection[this.m_workers.Length]; + for (int i = 0; i < this.m_queues.Length; i++) + { + this.m_queues[i] = new BlockingCollection(2); + } + this.m_resultSet = new BlockingCollection[]>(4); + + for (int i = 0; i < this.m_workers.Length; i++) + { + this.m_workers[i] = this.CreateTask(i); + } + this.m_mainWorker = new Thread(this.ProcessAllItems); + this.m_mainWorker.Name = "HGBPA.ProcessAllItem"; + this.m_mainWorker.Start(); + + this.m_resultEnum = this.m_resultSet.GetConsumingEnumerable().GetEnumerator(); + this.m_currentItems = new Pair[0]; + this.m_index = -1; + this.m_disposed = false; + } + + ~InnerEnumerator() + { + this.Dispose(false); + } + + public bool MoveNext() + { + this.m_index++; + if (this.m_index < this.m_currentItems.Length) + { + return true; + } + + while (this.m_resultEnum.MoveNext()) + { + if (this.m_workerException != null) break; + + this.m_currentItems = this.m_resultEnum.Current; + if (this.m_currentItems.Length > 0) + { + 
this.m_index = 0; + return true; + } + } + if (this.m_workerException != null) + { + throw new DryadLinqException("Failed while enumerating.", this.m_workerException); + } + return false; + } + + public Pair Current + { + get { return this.m_currentItems[this.m_index]; } + } + + object IEnumerator.Current + { + get { return this.m_currentItems[this.m_index]; } + } + + public void Reset() + { + throw new DryadLinqException("Internal error: Cannot reset this IEnumerator."); + } + + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (!this.m_disposed) + { + this.m_disposed = true; + this.m_isDone = true; + + // Always drain the queues + foreach (var queue in this.m_queues) + { + foreach (var item in queue.GetConsumingEnumerable()) + { + } + } + + // Always drain the result queue + while (this.m_resultEnum.MoveNext()) + { + } + } + } + + private void ProcessAllItems() + { + try + { + DryadLinqLog.Add("Parallel GroupBy (HashPartialAcc) started reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + // Read all the items + TSource[] buffer = new TSource[BufferSize]; + Int32 count = 0; + foreach (TSource item in this.m_source) + { + if (count == BufferSize) + { + if (this.m_isDone) break; + + BlockingCollection.AddToAny(this.m_queues, buffer); + buffer = new TSource[BufferSize]; + count = 0; + } + buffer[count++] = item; + } + + // Add the final buffers to the queues and declare adding is complete + if (!this.m_isDone && count > 0) + { + TSource[] lastBuffer = new TSource[count]; + Array.Copy(buffer, lastBuffer, count); + buffer = null; + BlockingCollection.AddToAny(this.m_queues, lastBuffer); + } + + for (int i = 0; i < this.m_queues.Length; i++) + { + this.m_queues[i].CompleteAdding(); + } + + DryadLinqLog.Add("Parallel GroupBy (HashPartialAcc) ended reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + // Wait for all the workers to complete + 
Task.WaitAll(this.m_workers); + } + catch (Exception e) + { + this.m_isDone = true; + this.m_workerException = e; + } + finally + { + for (int i = 0; i < this.m_queues.Length; i++) + { + this.m_queues[i].CompleteAdding(); + } + this.m_resultSet.CompleteAdding(); + } + } + + private Task CreateTask(Int32 idx) + { + return Task.Factory.StartNew(delegate { this.HashGroupBy(idx); }, TaskCreationOptions.LongRunning); + } + + private void HashGroupBy(Int32 idx) + { + try + { + DryadLinqLog.Add("Parallel GroupBy (HashPartialAcc) worker {0} started at {1}", + idx, DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + BlockingCollection queue = this.m_queues[idx]; + AccumulateDictionary + groups = new AccumulateDictionary( + this.m_comparer, HashSetCapacity, this.m_seed, this.m_accumulator); + + foreach (var buffer in queue.GetConsumingEnumerable()) + { + if (this.m_isDone) break; + + if (this.m_preApply == null) + { + TRecord[] recBuffer = (TRecord[])(object)buffer; + for (int i = 0; i < recBuffer.Length; i++) + { + TRecord item = recBuffer[i]; + groups.Add(this.m_keySelector(item), this.m_elementSelector(item)); + } + } + else + { + IEnumerable recBuffer = this.m_preApply(buffer); + foreach (TRecord item in recBuffer) + { + groups.Add(this.m_keySelector(item), this.m_elementSelector(item)); + } + } + } + + Int32 wlen = this.m_workers.Length; + Pair[] resultBuffer = new Pair[BufferSize]; + Int32 count = 0; + foreach (Pair g in groups) + { + if (count == BufferSize) + { + if (this.m_isDone) break; + + this.m_resultSet.Add(resultBuffer); + resultBuffer = new Pair[BufferSize]; + count = 0; + } + resultBuffer[count++] = g; + } + + if (!this.m_isDone && count > 0) + { + Pair[] lastResultBuffer = new Pair[count]; + Array.Copy(resultBuffer, lastResultBuffer, count); + resultBuffer = null; + this.m_resultSet.Add(lastResultBuffer); + } + } + catch (Exception e) + { + this.m_isDone = true; + this.m_workerException = new DryadLinqException(HpcLinqErrorCode.FailureInHashGroupBy, + 
SR.FailureInHashGroupBy, e); + this.m_resultSet.Add(new Pair[0]); + } + finally + { + DryadLinqLog.Add("Parallel GroupBy (HashPartialAcc) worker {0} ended at {1}", + idx, DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + } + } + } + } + + internal class ParallelHashGroupByFullAccumulate + : IParallelPipeline + { + private IEnumerable m_source; + private Func m_keySelector; + private Func m_elementSelector; + private Func m_seed; + private Func m_accumulator; + private IEqualityComparer m_comparer; + private Func>, IEnumerable> m_postApply; + + public ParallelHashGroupByFullAccumulate(IEnumerable source, + Func keySelector, + Func elementSelector, + Func seed, + Func accumulator, + IEqualityComparer comparer, + Func>, IEnumerable> postApply) + { + if (seed == null || accumulator == null || elementSelector == null) + { + throw new DryadLinqException("Internal error: The accumulator and element selector can't be null"); + } + this.m_source = source; + this.m_keySelector = keySelector; + this.m_elementSelector = elementSelector; + this.m_seed = seed; + this.m_accumulator = accumulator; + this.m_postApply = postApply; + this.m_comparer = comparer; + if (this.m_comparer == null) + { + this.m_comparer = EqualityComparer.Default; + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public IEnumerator GetEnumerator() + { + return new InnerEnumerator(this); + } + + public IParallelPipeline + Extend(Func, IEnumerable> func, bool orderPreserving) + { + if (this.m_postApply == null) + { + var applyFunc = (Func>, IEnumerable>)(object)func; + return new ParallelHashGroupByFullAccumulate( + this.m_source, this.m_keySelector, this.m_elementSelector, this.m_seed, + this.m_accumulator, this.m_comparer, applyFunc); + } + else + { + return new ParallelHashGroupByFullAccumulate( + this.m_source, this.m_keySelector, this.m_elementSelector, this.m_seed, + this.m_accumulator, this.m_comparer, s => func(this.m_postApply(s))); + } + } + + private 
class InnerEnumerator : IEnumerator
+        {
+            // Passed to the AccumulateDictionary constructor of every worker.
+            private const Int32 HashSetCapacity = 16411;
+            // Number of elements in each input buffer and each result buffer.
+            private const Int32 BufferSize = 2048;
+
+            private IEnumerable m_source;
+            private Func m_keySelector;
+            private Func m_elementSelector;
+            private Func m_seed;
+            private Func m_accumulator;
+            private IEqualityComparer m_comparer;
+            private Func>, IEnumerable> m_postApply;
+            private Thread m_mainWorker;
+            private Task[] m_workers;
+            private BlockingCollection[] m_queues;
+            private BlockingCollection m_resultSet;
+            // Result buffers handed over by busy workers for idle workers to run
+            // m_postApply on; m_stealingWorkerCnt counts the workers that have
+            // finished their own groups and are consuming this queue.
+            private BlockingCollection[]> m_stealingQueue;
+            private int m_stealingWorkerCnt;
+            // Cancellation flag. It is set by Dispose and by the catch handlers of the
+            // reader thread and the worker tasks, and polled by those threads between
+            // buffers. Declared volatile, as in the ParallelSortGroupBy and
+            // ParallelHashJoin enumerators, so a write made on one thread is observed
+            // by the others.
+            private volatile bool m_isDone;
+            private Exception m_workerException;
+
+            private IEnumerator m_resultEnum;
+            private TFinal[] m_currentItems;
+            private int m_index;
+            private bool m_disposed;
+
+            // Copies the query delegates from the parent, creates one bounded input
+            // queue (capacity 2) per worker, an unbounded stealing queue and a bounded
+            // result collection (capacity 4), then starts one worker task per
+            // processor and the reader thread. Enumeration state starts on an empty
+            // buffer at index -1.
+            public InnerEnumerator(ParallelHashGroupByFullAccumulate parent)
+            {
+                this.m_source = parent.m_source;
+                this.m_keySelector = parent.m_keySelector;
+                this.m_elementSelector = parent.m_elementSelector;
+                this.m_seed = parent.m_seed;
+                this.m_accumulator = parent.m_accumulator;
+                this.m_comparer = parent.m_comparer;
+                this.m_postApply = parent.m_postApply;
+
+                this.m_isDone = false;
+                this.m_stealingWorkerCnt = 0;
+                this.m_stealingQueue = new BlockingCollection[]>();
+                this.m_workerException = null;
+                this.m_workers = new Task[Environment.ProcessorCount];
+                this.m_queues = new BlockingCollection[this.m_workers.Length];
+                for (int i = 0; i < this.m_queues.Length; i++)
+                {
+                    this.m_queues[i] = new BlockingCollection(2);
+                }
+                this.m_resultSet = new BlockingCollection(4);
+
+                // Workers are created before the reader thread is started, so the
+                // queues already have consumers when the first buffer is added.
+                for (int i = 0; i < this.m_workers.Length; i++)
+                {
+                    this.m_workers[i] = this.CreateTask(i);
+                }
+                this.m_mainWorker = new Thread(this.ProcessAllItems);
+                this.m_mainWorker.Name = "HGBFA.ProcessAllItem";
+                this.m_mainWorker.Start();
+
+                this.m_resultEnum = this.m_resultSet.GetConsumingEnumerable().GetEnumerator();
+                this.m_currentItems = new TFinal[0];
+                this.m_index = -1;
+                this.m_disposed = false;
+            }
+
+            // Finalizer: runs the same cleanup as Dispose() for enumerators that were
+            // never disposed explicitly.
+            ~InnerEnumerator()
+            {
+
this.Dispose(false); + } + + public bool MoveNext() + { + this.m_index++; + if (this.m_index < this.m_currentItems.Length) + { + return true; + } + + while (this.m_resultEnum.MoveNext()) + { + if (this.m_workerException != null) break; + + this.m_currentItems = this.m_resultEnum.Current; + if (this.m_currentItems.Length > 0) + { + this.m_index = 0; + return true; + } + } + if (this.m_workerException != null) + { + throw new DryadLinqException("Failed while enumerating.", this.m_workerException); + } + return false; + } + + public TFinal Current + { + get { return this.m_currentItems[this.m_index]; } + } + + object IEnumerator.Current + { + get { return this.m_currentItems[this.m_index]; } + } + + public void Reset() + { + throw new DryadLinqException("Internal error: Cannot reset this IEnumerator."); + } + + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (!this.m_disposed) + { + this.m_disposed = true; + this.m_isDone = true; + + // Always drain the queues + foreach (var queue in this.m_queues) + { + foreach (var item in queue.GetConsumingEnumerable()) + { + } + } + + // Always drain the result queue + while (this.m_resultEnum.MoveNext()) + { + } + } + } + + private void ProcessAllItems() + { + try + { + DryadLinqLog.Add("Parallel GroupBy (HashFullAcc) started reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + // Read all the items + TSource[][] buffers = new TSource[this.m_workers.Length][]; + Int32[] counts = new Int32[this.m_workers.Length]; + for (int i = 0; i < buffers.Length; i++) + { + buffers[i] = new TSource[BufferSize]; + counts[i] = 0; + } + foreach (TSource item in this.m_source) + { + Int32 hashCode = this.m_comparer.GetHashCode(m_keySelector(item)); + Int32 idx = HpcLinqUtil.GetTaskIndex(hashCode, counts.Length); + if (counts[idx] == BufferSize) + { + if (this.m_isDone) break; + + this.m_queues[idx].Add(buffers[idx]); + buffers[idx] = new 
TSource[BufferSize]; + counts[idx] = 0; + } + buffers[idx][counts[idx]] = item; + counts[idx]++; + } + + // Add the final buffers to the queues and declare adding is complete + for (int i = 0; i < counts.Length; i++) + { + if (!this.m_isDone && counts[i] > 0) + { + TSource[] buffer = new TSource[counts[i]]; + Array.Copy(buffers[i], buffer, counts[i]); + this.m_queues[i].Add(buffer); + } + this.m_queues[i].CompleteAdding(); + } + + DryadLinqLog.Add("Parallel GroupBy (HashFullAcc) ended reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + // Wait for all the workers to complete + Task.WaitAll(this.m_workers); + } + catch (Exception e) + { + this.m_isDone = true; + this.m_workerException = e; + } + finally + { + for (int i = 0; i < this.m_queues.Length; i++) + { + this.m_queues[i].CompleteAdding(); + } + this.m_resultSet.CompleteAdding(); + } + } + + private Task CreateTask(Int32 idx) + { + return Task.Factory.StartNew(delegate { this.HashGroupBy(idx); }, TaskCreationOptions.LongRunning); + } + + private void HashGroupBy(Int32 idx) + { + try + { + DryadLinqLog.Add("Parallel GroupBy (HashFullAcc) worker {0} started at {1}", + idx, DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + BlockingCollection queue = this.m_queues[idx]; + AccumulateDictionary + groups = new AccumulateDictionary( + this.m_comparer, HashSetCapacity, this.m_seed, this.m_accumulator); + + foreach (var buffer in queue.GetConsumingEnumerable()) + { + if (this.m_isDone) break; + + for (int i = 0; i < buffer.Length; i++) + { + TSource item = buffer[i]; + groups.Add(this.m_keySelector(item), this.m_elementSelector(item)); + } + } + + Int32 wlen = this.m_workers.Length; + Pair[] resultBuffer = new Pair[BufferSize]; + Int32 count = 0; + Int32 myWorkCnt = 0; + Int32 otherWorkCnt = 0; + Int32 stealingWorkerCnt = this.m_stealingWorkerCnt; + + foreach (Pair g in groups) + { + if (count == BufferSize) + { + if (this.m_isDone) break; + + if (this.m_postApply == null) + { + 
this.m_resultSet.Add((TFinal[])(object)resultBuffer); + resultBuffer = new Pair[BufferSize]; + } + else + { + int newStealingWorkerCnt = this.m_stealingWorkerCnt; + if (newStealingWorkerCnt > stealingWorkerCnt) + { + myWorkCnt = 0; + otherWorkCnt = 0; + stealingWorkerCnt = newStealingWorkerCnt; + } + if ((stealingWorkerCnt * myWorkCnt) <= ((wlen - stealingWorkerCnt) * otherWorkCnt)) + { + this.m_resultSet.Add(this.m_postApply(resultBuffer).ToArray()); + myWorkCnt++; + } + else + { + this.m_stealingQueue.Add(resultBuffer); + resultBuffer = new Pair[BufferSize]; + otherWorkCnt++; + } + } + count = 0; + } + resultBuffer[count++] = g; + } + + if (!this.m_isDone && count > 0) + { + Pair[] lastResultBuffer = new Pair[count]; + Array.Copy(resultBuffer, lastResultBuffer, count); + resultBuffer = null; + if (this.m_postApply == null) + { + this.m_resultSet.Add((TFinal[])(object)lastResultBuffer); + } + else + { + this.m_resultSet.Add(this.m_postApply(lastResultBuffer).ToArray()); + } + } + + // Stealing work + if (this.m_postApply != null) + { + stealingWorkerCnt = Interlocked.Increment(ref this.m_stealingWorkerCnt); + if (stealingWorkerCnt == wlen) + { + this.m_stealingQueue.CompleteAdding(); + } + foreach (Pair[] buffer in this.m_stealingQueue.GetConsumingEnumerable()) + { + if (this.m_isDone) return; + this.m_resultSet.Add(this.m_postApply(buffer).ToArray()); + } + } + } + catch (Exception e) + { + this.m_isDone = true; + this.m_workerException = new DryadLinqException(HpcLinqErrorCode.FailureInHashGroupBy, + SR.FailureInHashGroupBy, e); + this.m_resultSet.Add(new TFinal[0]); + this.m_stealingQueue.CompleteAdding(); + } + finally + { + DryadLinqLog.Add("Parallel GroupBy (HashFullAcc) worker {0} ended at {1}", + idx, DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + } + } + } + } + + // This is really partial groupby. 
+ internal class ParallelSortGroupBy : IEnumerable + { + private IEnumerable m_source; + private Func m_keySelector; + private Func m_elementSelector; + private Func, TResult> m_resultSelector; + private IComparer m_comparer; + private Func>, IEnumerable> m_applyFunc; + + public ParallelSortGroupBy(IEnumerable source, + Func keySelector, + Func elementSelector, + Func, TResult> resultSelector, + IComparer comparer, + Func>, IEnumerable> applyFunc) + { + if (resultSelector == null || elementSelector == null) + { + throw new DryadLinqException("Internal error: The result and element selectors can't be null"); + } + this.m_source = source; + this.m_keySelector = keySelector; + this.m_elementSelector = elementSelector; + this.m_resultSelector = resultSelector; + this.m_comparer = comparer; + this.m_applyFunc = applyFunc; + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public IEnumerator GetEnumerator() + { + return new InnerEnumerator(this); + } + + private class InnerEnumerator : IEnumerator + { + private const Int32 ChunkSize = (1 << 22); + private const Int32 ResultChunkSize = 2048; + + private IEnumerable m_source; + private Func m_keySelector; + private Func m_elementSelector; + private Func, TResult> m_resultSelector; + private IComparer m_comparer; + private Func>, IEnumerable> m_applyFunc; + + private Thread m_mainWorker; + private List m_workers; + private BlockingCollection m_resultSet; + private volatile bool m_isDone; + private Exception m_workerException; + + private IEnumerator m_resultEnum; + private TFinal[] m_currentItems; + private int m_index; + private bool m_disposed; + + public InnerEnumerator(ParallelSortGroupBy parent) + { + this.m_source = parent.m_source; + this.m_keySelector = parent.m_keySelector; + this.m_elementSelector = parent.m_elementSelector; + this.m_resultSelector = parent.m_resultSelector; + this.m_comparer = parent.m_comparer; + this.m_applyFunc = parent.m_applyFunc; + + this.m_isDone = 
false; + this.m_workerException = null; + this.m_workers = new List(16); + this.m_resultSet = new BlockingCollection(); + this.m_mainWorker = new Thread(this.ProcessAllItems); + this.m_mainWorker.Name = "SGB.ProcessAllItem"; + this.m_mainWorker.Start(); + + this.m_resultEnum = this.m_resultSet.GetConsumingEnumerable().GetEnumerator(); + this.m_currentItems = new TFinal[0]; + this.m_index = -1; + this.m_disposed = false; + } + + ~InnerEnumerator() + { + this.Dispose(false); + } + + public bool MoveNext() + { + this.m_index++; + if (this.m_index < this.m_currentItems.Length) + { + return true; + } + + while (this.m_resultEnum.MoveNext()) + { + if (this.m_workerException != null) break; + + this.m_currentItems = this.m_resultEnum.Current; + if (this.m_currentItems.Length > 0) + { + this.m_index = 0; + return true; + } + } + if (this.m_workerException != null) + { + throw new DryadLinqException("Failed while enumerating.", this.m_workerException); + } + return false; + } + + public TFinal Current + { + get { return this.m_currentItems[this.m_index]; } + } + + object IEnumerator.Current + { + get { return this.m_currentItems[this.m_index]; } + } + + public void Reset() + { + throw new DryadLinqException("Internal error: Cannot reset this IEnumerator."); + } + + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (!this.m_disposed) + { + this.m_disposed = true; + this.m_isDone = true; + + // Always drain the result queue + while (this.m_resultEnum.MoveNext()) + { + } + } + } + + private void ProcessAllItems() + { + try + { + DryadLinqLog.Add("Parallel GroupBy (PartialSort) started reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + TSource[] itemArray = new TSource[ChunkSize]; + Int32 itemCnt = 0; + foreach (TSource item in this.m_source) + { + if (itemCnt == ChunkSize) + { + if (this.m_isDone) break; + + this.ProcessItemArray(itemArray, itemCnt); + itemArray = new 
TSource[ChunkSize]; + itemCnt = 0; + } + itemArray[itemCnt++] = item; + } + + if (!this.m_isDone && itemCnt > 0) + { + this.ProcessItemArray(itemArray, itemCnt); + } + + DryadLinqLog.Add("Parallel GroupBy (PartialSort) ended reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + // Wait for all the workers to complete + foreach (Task task in this.m_workers) + { + task.Wait(); + } + } + catch (Exception e) + { + this.m_isDone = true; + this.m_workerException = e; + } + finally + { + this.m_resultSet.CompleteAdding(); + } + } + + private void ProcessItemArray(TSource[] itemArray, Int32 itemCnt) + { + Wrapper wrappedItemArray = new Wrapper(itemArray); + + // NOT using the TCO.LongRunning option for this task, because it's spawned an + // arbitrary # of times for potentially shorter work which means it's best to + // leave this to the TP load balancing algorithm + Task task = Task.Factory.StartNew(delegate { this.SortGroupByItemArray(wrappedItemArray, itemCnt); }); + this.m_workers.Add(task); + } + + private void SortGroupByItemArray(Wrapper wrappedItemArray, Int32 itemCnt) + { + try + { + TSource[] itemArray = wrappedItemArray.item; + TKey[] keyArray = new TKey[itemCnt]; + for (int i = 0; i < itemCnt; i++) + { + keyArray[i] = this.m_keySelector(itemArray[i]); + } + Array.Sort(keyArray, itemArray, 0, itemCnt, this.m_comparer); + + Pair[] resultChunk = new Pair[ResultChunkSize]; + Int32 count = 0; + if (itemCnt > 0) + { + Grouping curGroup = new Grouping(keyArray[0]); + curGroup.AddItem(this.m_elementSelector(itemArray[0])); + for (int i = 1; i < itemCnt; i++) + { + if (this.m_comparer.Compare(curGroup.Key, keyArray[i]) != 0) + { + if (count == ResultChunkSize) + { + if (this.m_isDone) break; + + if (this.m_applyFunc == null) + { + this.m_resultSet.Add((TFinal[])(object)resultChunk); + resultChunk = new Pair[ResultChunkSize]; + } + else + { + this.m_resultSet.Add(this.m_applyFunc(resultChunk).ToArray()); + } + count = 0; + } + resultChunk[count++] = 
new Pair(curGroup.Key,
+                                                                 this.m_resultSelector(curGroup));
+                                curGroup = new Grouping(keyArray[i]);
+                            }
+                            curGroup.AddItem(this.m_elementSelector(itemArray[i]));
+                        }
+
+                        // Add the last group
+                        if (count == ResultChunkSize)
+                        {
+                            // The chunk is exactly full: grow it by one slot for the
+                            // group that is still open.
+                            Pair[] lastResultChunk = new Pair[count + 1];
+                            Array.Copy(resultChunk, lastResultChunk, count);
+                            resultChunk = lastResultChunk;
+                        }
+                        resultChunk[count++] = new Pair(curGroup.Key,
+                                                        this.m_resultSelector(curGroup));
+
+                        // Add the last chunk. Only the first 'count' slots hold results
+                        // of this chunk; the remaining slots are either default values
+                        // or, when the array was reused after m_applyFunc, entries that
+                        // were already published. Trim before publishing so the
+                        // consumer never sees them.
+                        if (count < resultChunk.Length)
+                        {
+                            Pair[] trimmedChunk = new Pair[count];
+                            Array.Copy(resultChunk, trimmedChunk, count);
+                            resultChunk = trimmedChunk;
+                        }
+                        if (this.m_applyFunc == null)
+                        {
+                            this.m_resultSet.Add((TFinal[])(object)resultChunk);
+                        }
+                        else
+                        {
+                            this.m_resultSet.Add(this.m_applyFunc(resultChunk).ToArray());
+                        }
+                    }
+                    // Drop the reference to the input chunk held by the wrapper.
+                    wrappedItemArray.item = null;
+                }
+                catch (Exception e)
+                {
+                    // Record the failure for MoveNext, stop the other threads, and
+                    // fault this task with the wrapped exception.
+                    this.m_isDone = true;
+                    this.m_workerException = new DryadLinqException(HpcLinqErrorCode.FailureInSortGroupBy,
+                                                                    SR.FailureInSortGroupBy, e);
+                    throw this.m_workerException;
+                }
+            }
+        }
+    }
+
+    // Parallel hash join of an outer and an inner sequence on a key, optionally
+    // followed by an apply function run on buffers of join results.
+    internal class ParallelHashJoin : IParallelPipeline
+    {
+        private IEnumerable m_outer;
+        private IEnumerable m_inner;
+        private Func m_outerKeySelector;
+        private Func m_innerKeySelector;
+        private Func m_resultSelector;
+        private IEqualityComparer m_comparer;
+        private Func, IEnumerable> m_applyFunc;
+
+        // outer / inner:        the two input sequences.
+        // outerKeySelector /
+        // innerKeySelector:     extract the join key from an outer / inner element.
+        // resultSelector:       builds a result from a matching (outer, inner) pair.
+        // comparer:             key equality comparer; when null the default equality
+        //                       comparer is used.
+        // applyFunc:            optional function applied to each buffer of results
+        //                       (null means results are returned as they are).
+        public ParallelHashJoin(IEnumerable outer,
+                                IEnumerable inner,
+                                Func outerKeySelector,
+                                Func innerKeySelector,
+                                Func resultSelector,
+                                IEqualityComparer comparer,
+                                Func, IEnumerable> applyFunc)
+        {
+            this.m_outer = outer;
+            this.m_inner = inner;
+            this.m_outerKeySelector = outerKeySelector;
+            this.m_innerKeySelector = innerKeySelector;
+            this.m_resultSelector = resultSelector;
+            this.m_applyFunc = applyFunc;
+            // Keep the caller's comparer. Without this assignment the field is always
+            // null here, so the fallback below replaced every custom comparer
+            // (including the one forwarded by Extend) with the default.
+            this.m_comparer = comparer;
+            if (this.m_comparer == null)
+            {
+                this.m_comparer = EqualityComparer.Default;
+            }
+        }
+
+        IEnumerator IEnumerable.GetEnumerator()
+        {
+            return this.GetEnumerator();
+        }
+
+        // Every call creates a new InnerEnumerator over this join.
+        public IEnumerator GetEnumerator()
+        {
+            return new InnerEnumerator(this);
+        }
+
+        public IParallelPipeline
+            Extend(Func, IEnumerable> func, bool orderPreserving)
+        {
+            if
(this.m_applyFunc == null) + { + var applyFunc = (Func, IEnumerable>)(object)func; + return new ParallelHashJoin( + this.m_outer, this.m_inner, this.m_outerKeySelector, this.m_innerKeySelector, + this.m_resultSelector, this.m_comparer, applyFunc); + } + else + { + return new ParallelHashJoin( + this.m_outer, this.m_inner, this.m_outerKeySelector, this.m_innerKeySelector, + this.m_resultSelector, this.m_comparer, s => func(this.m_applyFunc(s))); + } + } + + private class InnerEnumerator : IEnumerator + { + private const Int32 BufferSize = 1024; + + private IEnumerable m_outer; + private IEnumerable m_inner; + private Func m_outerKeySelector; + private Func m_innerKeySelector; + private Func m_resultSelector; + private Func, IEnumerable> m_applyFunc; + private IEqualityComparer m_comparer; + private bool m_hashInner; + + private Thread m_mainWorker; + private Task[] m_workers; + private GroupingHashSet[] m_innerGroupingArray; + private GroupingHashSet[] m_outerGroupingArray; + private BlockingCollection[] m_innerQueues; + private BlockingCollection[] m_outerQueues; + private BlockingCollection m_resultSet; + private BlockingCollection m_stealingQueue; + private Int32 m_stealingWorkerCnt; + private volatile bool m_isDone; + private Exception m_workerException; + + private IEnumerator m_resultEnum; + private TFinal[] m_currentItems; + private Int32 m_index; + private bool m_disposed; + + public InnerEnumerator(ParallelHashJoin parent) + { + this.m_outer = parent.m_outer; + this.m_inner = parent.m_inner; + this.m_outerKeySelector = parent.m_outerKeySelector; + this.m_innerKeySelector = parent.m_innerKeySelector; + this.m_resultSelector = parent.m_resultSelector; + this.m_applyFunc = parent.m_applyFunc; + this.m_comparer = parent.m_comparer; + + this.m_hashInner = true; + if ((this.m_outer is HpcVertexReader) && + (this.m_inner is HpcVertexReader)) + { + Int64 outerLen = ((HpcVertexReader)this.m_outer).GetTotalLength(); + Int64 innerLen = 
((HpcVertexReader)this.m_inner).GetTotalLength(); + if (innerLen >= 0 && outerLen >= 0) + { + this.m_hashInner = innerLen <= outerLen; + } + DryadLinqLog.Add("Parallel HashJoin: outerLen={0}, innerLen={1}", outerLen, innerLen); + } + + this.m_isDone = false; + this.m_stealingWorkerCnt = 0; + this.m_stealingQueue = new BlockingCollection(); + this.m_workerException = null; + Int32 wlen = Environment.ProcessorCount; + this.m_workers = new Task[wlen]; + if (this.m_hashInner) + { + this.m_innerGroupingArray = new GroupingHashSet[wlen]; + this.m_outerQueues = new BlockingCollection[wlen]; + for (int i = 0; i < wlen; i++) + { + this.m_innerGroupingArray[i] = new GroupingHashSet(this.m_comparer); + this.m_outerQueues[i] = new BlockingCollection(2); + } + } + else + { + this.m_outerGroupingArray = new GroupingHashSet[wlen]; + this.m_innerQueues = new BlockingCollection[wlen]; + for (int i = 0; i < wlen; i++) + { + this.m_outerGroupingArray[i] = new GroupingHashSet(this.m_comparer); + this.m_innerQueues[i] = new BlockingCollection(2); + } + } + this.m_resultSet = new BlockingCollection(4); + for (int i = 0; i < this.m_workers.Length; i++) + { + this.m_workers[i] = this.CreateTask(i); + } + this.m_mainWorker = new Thread(this.ProcessAllItems); + this.m_mainWorker.Name = "HJ.ProcessAllItem"; + this.m_mainWorker.Start(); + + this.m_resultEnum = this.m_resultSet.GetConsumingEnumerable().GetEnumerator(); + this.m_currentItems = new TFinal[0]; + this.m_index = -1; + this.m_disposed = false; + } + + ~InnerEnumerator() + { + this.Dispose(false); + } + + public bool MoveNext() + { + this.m_index++; + if (this.m_index < this.m_currentItems.Length) + { + return true; + } + + while (this.m_resultEnum.MoveNext()) + { + if (this.m_workerException != null) break; + + this.m_currentItems = this.m_resultEnum.Current; + if (this.m_currentItems.Length > 0) + { + this.m_index = 0; + return true; + } + } + if (this.m_workerException != null) + { + throw new DryadLinqException("Failed while 
enumerating.", this.m_workerException); + } + return false; + } + + public TFinal Current + { + get { return this.m_currentItems[this.m_index]; } + } + + object IEnumerator.Current + { + get { return this.m_currentItems[this.m_index]; } + } + + public void Reset() + { + throw new InvalidOperationException(); + } + + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (!this.m_disposed) + { + this.m_disposed = true; + this.m_isDone = true; + + if (this.m_hashInner) + { + // Always drain the outer queues + foreach (var outerQueue in this.m_outerQueues) + { + foreach (var item in outerQueue.GetConsumingEnumerable()) + { + } + } + } + else + { + // Always drain the inner queues + foreach (var innerQueue in this.m_innerQueues) + { + foreach (var item in innerQueue.GetConsumingEnumerable()) + { + } + } + } + + // Always drain the result queue + while (this.m_resultEnum.MoveNext()) + { + } + } + } + + private void ProcessAllItems() + { + try + { + DryadLinqLog.Add("Parallel HashJoin started reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + Int32 wlen = this.m_workers.Length; + if (this.m_hashInner) + { + TOuter[][] buffers = new TOuter[wlen][]; + Int32[] counts = new Int32[wlen]; + for (int i = 0; i < wlen; i++) + { + buffers[i] = new TOuter[BufferSize]; + counts[i] = 0; + } + + // Create a hash lookup table using inner + foreach (TInner innerItem in this.m_inner) + { + TKey innerKey = this.m_innerKeySelector(innerItem); + Int32 hashCode = this.m_comparer.GetHashCode(innerKey); + Int32 idx = HpcLinqUtil.GetTaskIndex(hashCode, counts.Length); + this.m_innerGroupingArray[idx].AddItem(innerKey, innerItem); + } + + DryadLinqLog.Add("Parallel HashJoin: In-memory hashtable created at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + foreach (TOuter outerItem in this.m_outer) + { + TKey outerKey = this.m_outerKeySelector(outerItem); + Int32 hashCode = 
this.m_comparer.GetHashCode(outerKey); + Int32 idx = HpcLinqUtil.GetTaskIndex(hashCode, counts.Length); + if (counts[idx] == BufferSize) + { + if (this.m_isDone) break; + + this.m_outerQueues[idx].Add(buffers[idx]); + buffers[idx] = new TOuter[BufferSize]; + counts[idx] = 0; + } + buffers[idx][counts[idx]] = outerItem; + counts[idx]++; + } + + // Add the final buffers to the queues and declare adding is complete + for (int i = 0; i < wlen; i++) + { + if (!this.m_isDone && counts[i] > 0) + { + TOuter[] lastBuffer = new TOuter[counts[i]]; + Array.Copy(buffers[i], lastBuffer, counts[i]); + buffers[i] = null; + this.m_outerQueues[i].Add(lastBuffer); + } + this.m_outerQueues[i].CompleteAdding(); + } + } + else + { + TInner[][] buffers = new TInner[wlen][]; + Int32[] counts = new Int32[wlen]; + for (int i = 0; i < wlen; i++) + { + buffers[i] = new TInner[BufferSize]; + counts[i] = 0; + } + + // Create a hash lookup table using outer + foreach (TOuter outerItem in this.m_outer) + { + TKey outerKey = this.m_outerKeySelector(outerItem); + Int32 hashCode = this.m_comparer.GetHashCode(outerKey); + Int32 idx = HpcLinqUtil.GetTaskIndex(hashCode, wlen); + this.m_outerGroupingArray[idx].AddItem(outerKey, outerItem); + } + + DryadLinqLog.Add("Parallel HashJoin: In-memory hashtable created at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + foreach (TInner innerItem in this.m_inner) + { + TKey innerKey = this.m_innerKeySelector(innerItem); + Int32 hashCode = this.m_comparer.GetHashCode(innerKey); + Int32 idx = HpcLinqUtil.GetTaskIndex(hashCode, wlen); + if (counts[idx] == BufferSize) + { + if (this.m_isDone) break; + + this.m_innerQueues[idx].Add(buffers[idx]); + buffers[idx] = new TInner[BufferSize]; + counts[idx] = 0; + } + buffers[idx][counts[idx]] = innerItem; + counts[idx]++; + } + + // Add the final buffers to the queues and declare adding is complete + for (int i = 0; i < wlen; i++) + { + if (!this.m_isDone && counts[i] > 0) + { + TInner[] lastBuffer = new 
TInner[counts[i]]; + Array.Copy(buffers[i], lastBuffer, counts[i]); + buffers[i] = null; + this.m_innerQueues[i].Add(lastBuffer); + } + this.m_innerQueues[i].CompleteAdding(); + } + } + + DryadLinqLog.Add("Parallel HashJoin ended reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + // Wait for all the workers to complete + Task.WaitAll(this.m_workers); + } + catch (Exception e) + { + this.m_isDone = true; + this.m_workerException = e; + } + finally + { + if (this.m_outerQueues != null) + { + for (int i = 0; i < this.m_outerQueues.Length; i++) + { + this.m_outerQueues[i].CompleteAdding(); + } + } + if (this.m_innerQueues != null) + { + for (int i = 0; i < this.m_innerQueues.Length; i++) + { + this.m_innerQueues[i].CompleteAdding(); + } + } + this.m_resultSet.CompleteAdding(); + } + } + + private Task CreateTask(Int32 idx) + { + // using the TCO.LongRunning option, because this method is called a fix number of times to spawn worker tasks + // which means it is safe to request a decicated thread for this task + return Task.Factory.StartNew(delegate { this.DoHashJoin(idx); }, + TaskCreationOptions.LongRunning); + } + + private void DoHashJoin(Int32 qidx) + { + try + { + DryadLinqLog.Add("Parallel HashJoin worker {0} started at {1}", + qidx, DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + Int32 wlen = this.m_workers.Length; + TResult[] resultBuffer = new TResult[BufferSize]; + Int32 count = 0; + Int32 myWorkCnt = 0; + Int32 otherWorkCnt = 0; + Int32 stealingWorkerCnt = this.m_stealingWorkerCnt; + + if (this.m_hashInner) + { + GroupingHashSet innerGroups = this.m_innerGroupingArray[qidx]; + foreach (var buffer in this.m_outerQueues[qidx].GetConsumingEnumerable()) + { + if (this.m_isDone) return; + + for (int bidx = 0; bidx < buffer.Length; bidx++) + { + TOuter outerItem = buffer[bidx]; + Grouping innerGroup = innerGroups.GetGroup(this.m_outerKeySelector(outerItem)); + if (innerGroup != null) + { + foreach (TInner item in innerGroup) + { + if 
(count == BufferSize) + { + if (this.m_applyFunc == null) + { + this.m_resultSet.Add((TFinal[])(object)resultBuffer); + resultBuffer = new TResult[BufferSize]; + } + else + { + int newStealingWorkerCnt = this.m_stealingWorkerCnt; + if (newStealingWorkerCnt > stealingWorkerCnt) + { + myWorkCnt = 0; + otherWorkCnt = 0; + stealingWorkerCnt = newStealingWorkerCnt; + } + if ((stealingWorkerCnt * myWorkCnt) <= ((wlen - stealingWorkerCnt) * otherWorkCnt)) + { + this.m_resultSet.Add(this.m_applyFunc(resultBuffer).ToArray()); + myWorkCnt++; + } + else + { + this.m_stealingQueue.Add(resultBuffer); + resultBuffer = new TResult[BufferSize]; + otherWorkCnt++; + } + } + count = 0; + } + resultBuffer[count++] = this.m_resultSelector(outerItem, item); + } + } + } + } + } + else + { + GroupingHashSet outerGroups = this.m_outerGroupingArray[qidx]; + foreach (var buffer in this.m_innerQueues[qidx].GetConsumingEnumerable()) + { + if (this.m_isDone) return; + + for (int bidx = 0; bidx < buffer.Length; bidx++) + { + TInner innerItem = buffer[bidx]; + Grouping outerGroup = outerGroups.GetGroup(this.m_innerKeySelector(innerItem)); + if (outerGroup != null) + { + foreach (TOuter item in outerGroup) + { + if (count == BufferSize) + { + if (this.m_applyFunc == null) + { + this.m_resultSet.Add((TFinal[])(object)resultBuffer); + resultBuffer = new TResult[BufferSize]; + } + else + { + int newStealingWorkerCnt = this.m_stealingWorkerCnt; + if (newStealingWorkerCnt > stealingWorkerCnt) + { + myWorkCnt = 0; + otherWorkCnt = 0; + stealingWorkerCnt = newStealingWorkerCnt; + } + if ((stealingWorkerCnt * myWorkCnt) <= ((wlen - stealingWorkerCnt) * otherWorkCnt)) + { + this.m_resultSet.Add(this.m_applyFunc(resultBuffer).ToArray()); + myWorkCnt++; + } + else + { + this.m_stealingQueue.Add(resultBuffer); + resultBuffer = new TResult[BufferSize]; + otherWorkCnt++; + } + } + count = 0; + } + resultBuffer[count++] = this.m_resultSelector(item, innerItem); + } + } + } + } + } + + // Add the final buffer: + 
if (!this.m_isDone && count > 0) + { + TResult[] lastResultBuffer = new TResult[count]; + Array.Copy(resultBuffer, lastResultBuffer, count); + resultBuffer = null; + if (this.m_applyFunc == null) + { + this.m_resultSet.Add((TFinal[])(object)lastResultBuffer); + } + else + { + this.m_resultSet.Add(this.m_applyFunc(lastResultBuffer).ToArray()); + } + } + + // Stealing work + if (this.m_applyFunc != null) + { + stealingWorkerCnt = Interlocked.Increment(ref this.m_stealingWorkerCnt); + if (stealingWorkerCnt == wlen) + { + this.m_stealingQueue.CompleteAdding(); + } + foreach (TResult[] buffer in this.m_stealingQueue.GetConsumingEnumerable()) + { + if (this.m_isDone) return; + this.m_resultSet.Add(this.m_applyFunc(buffer).ToArray()); + } + } + } + catch (Exception e) + { + this.m_isDone = true; + this.m_workerException = new DryadLinqException(HpcLinqErrorCode.FailureInHashJoin, + SR.FailureInHashJoin, e); + this.m_resultSet.Add(new TFinal[0]); + this.m_stealingQueue.CompleteAdding(); + } + finally + { + DryadLinqLog.Add("Parallel HashJoin worker {0} ended at {1}", + qidx, DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + } + } + } + } + + internal class ParallelHashGroupJoin : IParallelPipeline + { + private IEnumerable m_outer; + private IEnumerable m_inner; + private Func m_outerKeySelector; + private Func m_innerKeySelector; + private Func, TResult> m_resultSelector; + private Func, IEnumerable> m_applyFunc; + private IEqualityComparer m_comparer; + + // Captures the two inputs, the key/result selectors, the optional key comparer and the + // optional apply function; a null comparer falls back to the default equality comparer. + public ParallelHashGroupJoin(IEnumerable outer, + IEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + Func, TResult> resultSelector, + IEqualityComparer comparer, + Func, IEnumerable> applyFunc) + { + this.m_outer = outer; + this.m_inner = inner; + this.m_outerKeySelector = outerKeySelector; + this.m_innerKeySelector = innerKeySelector; + this.m_resultSelector = resultSelector; + this.m_applyFunc = applyFunc; + // Store the caller's comparer first; without this the null check below always + // fired and a custom comparer was silently replaced by the default one. + this.m_comparer = comparer; + if (this.m_comparer == null) + { + this.m_comparer = EqualityComparer.Default; + } + } + 
+ IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public IEnumerator GetEnumerator() + { + return new InnerEnumerator(this); + } + + public IParallelPipeline + Extend(Func, IEnumerable> func, bool orderPreserving) + { + if (this.m_applyFunc == null) + { + var applyFunc = (Func, IEnumerable>)(object)func; + return new ParallelHashGroupJoin( + this.m_outer, this.m_inner, this.m_outerKeySelector, this.m_innerKeySelector, + this.m_resultSelector, this.m_comparer, applyFunc); + } + else + { + return new ParallelHashGroupJoin( + this.m_outer, this.m_inner, this.m_outerKeySelector, this.m_innerKeySelector, + this.m_resultSelector, this.m_comparer, s => func(this.m_applyFunc(s))); + } + } + + private class InnerEnumerator : IEnumerator + { + private const Int32 BufferSize = 1024; + + private IEnumerable m_outer; + private IEnumerable m_inner; + private Func m_outerKeySelector; + private Func m_innerKeySelector; + private Func, TResult> m_resultSelector; + private Func, IEnumerable> m_applyFunc; + private IEqualityComparer m_comparer; + + private Thread m_mainWorker; + private Task[] m_workers; + private GroupingHashSet[] m_innerGroupingArray; + private BlockingCollection[] m_queues; + private BlockingCollection m_resultSet; + private BlockingCollection m_stealingQueue; + private int m_stealingWorkerCnt; + private volatile bool m_isDone; + private Exception m_workerException; + + private IEnumerator m_resultEnum; + private TFinal[] m_currentItems; + private int m_index; + private bool m_disposed; + + public InnerEnumerator(ParallelHashGroupJoin parent) + { + this.m_outer = parent.m_outer; + this.m_inner = parent.m_inner; + this.m_outerKeySelector = parent.m_outerKeySelector; + this.m_innerKeySelector = parent.m_innerKeySelector; + this.m_resultSelector = parent.m_resultSelector; + this.m_applyFunc = parent.m_applyFunc; + this.m_comparer = parent.m_comparer; + + this.m_isDone = false; + this.m_stealingWorkerCnt = 0; + this.m_stealingQueue = 
new BlockingCollection(); + this.m_workerException = null; + this.m_workers = new Task[Environment.ProcessorCount]; + this.m_innerGroupingArray = new GroupingHashSet[this.m_workers.Length]; + this.m_queues = new BlockingCollection[this.m_workers.Length]; + for (int i = 0; i < this.m_queues.Length; i++) + { + this.m_innerGroupingArray[i] = new GroupingHashSet(this.m_comparer); + this.m_queues[i] = new BlockingCollection(2); + } + this.m_resultSet = new BlockingCollection(4); + for (int i = 0; i < this.m_workers.Length; i++) + { + this.m_workers[i] = this.CreateTask(i); + } + this.m_mainWorker = new Thread(this.ProcessAllItems); + this.m_mainWorker.Name = "HGJ.ProcessAllItem"; + this.m_mainWorker.Start(); + + this.m_resultEnum = this.m_resultSet.GetConsumingEnumerable().GetEnumerator(); + this.m_currentItems = new TFinal[0]; + this.m_index = -1; + this.m_disposed = false; + } + + ~InnerEnumerator() + { + this.Dispose(false); + } + + public bool MoveNext() + { + this.m_index++; + if (this.m_index < this.m_currentItems.Length) + { + return true; + } + + while (this.m_resultEnum.MoveNext()) + { + if (this.m_workerException != null) break; + this.m_currentItems = this.m_resultEnum.Current; + if (this.m_currentItems.Length > 0) + { + this.m_index = 0; + return true; + } + } + if (this.m_workerException != null) + { + throw new DryadLinqException("Failed while enumerating.", this.m_workerException); + } + return false; + } + + public TFinal Current + { + get { return this.m_currentItems[this.m_index]; } + } + + object IEnumerator.Current + { + get { return this.m_currentItems[this.m_index]; } + } + + public void Reset() + { + throw new DryadLinqException("Internal error: Cannot reset this IEnumerator."); + } + + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (!this.m_disposed) + { + this.m_disposed = true; + this.m_isDone = true; + + // Always drain the queues + foreach (var queue in 
this.m_queues) + { + foreach (var item in queue.GetConsumingEnumerable()) + { + } + } + + // Always drain the result queue + while (this.m_resultEnum.MoveNext()) + { + } + } + } + + // Body of the reader thread: builds the per-partition lookup tables from m_inner, + // then hash-partitions m_outer into buffers that are handed to the worker queues. + private void ProcessAllItems() + { + try + { + DryadLinqLog.Add("Parallel HashGroupJoin started reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + Int32 wlen = this.m_workers.Length; + // Create a hash lookup table using inner. It is hard to do the same + // optimization as Join, because resultSelector is not symmetric. + foreach (TInner innerItem in this.m_inner) + { + TKey innerKey = this.m_innerKeySelector(innerItem); + Int32 hashCode = this.m_comparer.GetHashCode(innerKey); + Int32 idx = HpcLinqUtil.GetTaskIndex(hashCode, wlen); + this.m_innerGroupingArray[idx].AddItem(innerKey, innerItem); + } + + TOuter[][] buffers = new TOuter[wlen][]; + Int32[] counts = new Int32[wlen]; + for (int i = 0; i < buffers.Length; i++) + { + buffers[i] = new TOuter[BufferSize]; + counts[i] = 0; + } + foreach (TOuter outerItem in this.m_outer) + { + TKey outerKey = this.m_outerKeySelector(outerItem); + Int32 hashCode = this.m_comparer.GetHashCode(outerKey); + Int32 idx = HpcLinqUtil.GetTaskIndex(hashCode, wlen); + if (counts[idx] == BufferSize) + { + if (this.m_isDone) break; + + this.m_queues[idx].Add(buffers[idx]); + buffers[idx] = new TOuter[BufferSize]; + counts[idx] = 0; + } + buffers[idx][counts[idx]] = outerItem; + counts[idx]++; + } + // Add the final buffers to the queues and declare adding is complete + for (int i = 0; i < counts.Length; i++) + { + if (!this.m_isDone && counts[i] > 0) + { + TOuter[] lastBuffer = new TOuter[counts[i]]; + Array.Copy(buffers[i], lastBuffer, counts[i]); + buffers[i] = null; + this.m_queues[i].Add(lastBuffer); + } + this.m_queues[i].CompleteAdding(); + } + + DryadLinqLog.Add("Parallel HashGroupJoin ended reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + // Wait for all the workers to complete + 
Task.WaitAll(this.m_workers); + } + catch (Exception e) + { + this.m_isDone = true; + this.m_workerException = e; + } + finally + { + for (int i = 0; i < this.m_queues.Length; i++) + { + this.m_queues[i].CompleteAdding(); + } + this.m_resultSet.CompleteAdding(); + } + } + + private Task CreateTask(Int32 idx) + { + // using the TCO.LongRunning option, because this method is called a fixed number of times to spawn worker tasks + // which means it is safe to request a dedicated thread for this task + return Task.Factory.StartNew(delegate { this.DoGroupJoin(idx); }, + TaskCreationOptions.LongRunning); + } + + private void DoGroupJoin(Int32 qidx) + { + try + { + DryadLinqLog.Add("Parallel HashGroupJoin worker {0} started at {1}", + qidx, DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + GroupingHashSet innerGroups = this.m_innerGroupingArray[qidx]; + Int32 wlen = this.m_workers.Length; + BlockingCollection queue = this.m_queues[qidx]; + TResult[] resultBuffer = new TResult[BufferSize]; + Int32 count = 0; + Int32 myWorkCnt = 0; + Int32 otherWorkCnt = 0; + Int32 stealingWorkerCnt = this.m_stealingWorkerCnt; + + TInner[] emptyGroup = new TInner[0]; + foreach (var buffer in queue.GetConsumingEnumerable()) + { + if (this.m_isDone) return; + + for (int i = 0; i < buffer.Length; i++) + { + TOuter outerItem = buffer[i]; + IEnumerable innerGroup = innerGroups.GetGroup(this.m_outerKeySelector(outerItem)); + if (innerGroup == null) + { + innerGroup = emptyGroup; + } + if (count == BufferSize) + { + if (this.m_applyFunc == null) + { + this.m_resultSet.Add((TFinal[])(object)resultBuffer); + resultBuffer = new TResult[BufferSize]; + } + else + { + int newStealingWorkerCnt = this.m_stealingWorkerCnt; + if (newStealingWorkerCnt > stealingWorkerCnt) + { + myWorkCnt = 0; + otherWorkCnt = 0; + stealingWorkerCnt = newStealingWorkerCnt; + } + if ((stealingWorkerCnt * myWorkCnt) <= ((wlen - stealingWorkerCnt) * otherWorkCnt)) + { + 
this.m_resultSet.Add(this.m_applyFunc(resultBuffer).ToArray()); + myWorkCnt++; + } + else + { + this.m_stealingQueue.Add(resultBuffer); + resultBuffer = new TResult[BufferSize]; + otherWorkCnt++; + } + } + count = 0; + } + resultBuffer[count++] = this.m_resultSelector(outerItem, innerGroup); + } + } + + // Add the final buffer: + if (!this.m_isDone && count > 0) + { + TResult[] lastResultBuffer = new TResult[count]; + Array.Copy(resultBuffer, lastResultBuffer, count); + resultBuffer = null; + if (this.m_applyFunc == null) + { + this.m_resultSet.Add((TFinal[])(object)lastResultBuffer); + } + else + { + this.m_resultSet.Add(this.m_applyFunc(lastResultBuffer).ToArray()); + } + } + + // Stealing work + if (this.m_applyFunc != null) + { + stealingWorkerCnt = Interlocked.Increment(ref this.m_stealingWorkerCnt); + if (stealingWorkerCnt == wlen) + { + this.m_stealingQueue.CompleteAdding(); + } + foreach (TResult[] buffer in this.m_stealingQueue.GetConsumingEnumerable()) + { + if (this.m_isDone) return; + this.m_resultSet.Add(this.m_applyFunc(buffer).ToArray()); + } + } + } + catch (Exception e) + { + this.m_isDone = true; + this.m_workerException = new DryadLinqException(HpcLinqErrorCode.FailureInHashGroupJoin, + SR.FailureInHashGroupJoin, e); + this.m_resultSet.Add(new TFinal[0]); + this.m_stealingQueue.CompleteAdding(); + } + finally + { + DryadLinqLog.Add("Parallel HashGroupJoin worker {0} ended at {1}", + qidx, DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + } + } + } + } + + internal class ParallelSetOperation : IParallelPipeline + { + private string m_opName; + private IEnumerable m_source; + private IEnumerable m_otherSource; + private Func, IEnumerable> m_applyFunc; + private IEqualityComparer m_comparer; + private bool m_isPartial; + + public ParallelSetOperation(string opName, + IEnumerable source, + IEnumerable otherSource, + IEqualityComparer comparer, + Func, IEnumerable> applyFunc, + bool isPartial) + { + this.m_opName = opName; + this.m_source = source; 
+ this.m_otherSource = otherSource; + this.m_comparer = comparer; + this.m_applyFunc = applyFunc; + if (this.m_comparer == null) + { + this.m_comparer = EqualityComparer.Default; + } + this.m_isPartial = isPartial; + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public IEnumerator GetEnumerator() + { + return new InnerEnumerator(this); + } + + public IParallelPipeline + Extend(Func, IEnumerable> func, bool orderPreserving) + { + if (this.m_applyFunc == null) + { + var applyFunc = (Func, IEnumerable>)(object)func; + return new ParallelSetOperation( + this.m_opName, this.m_source, this.m_otherSource, this.m_comparer, + applyFunc, this.m_isPartial); + } + else + { + return new ParallelSetOperation( + this.m_opName, this.m_source, this.m_otherSource, this.m_comparer, + s => func(this.m_applyFunc(s)), this.m_isPartial); + } + } + + private class InnerEnumerator : IEnumerator + { + private const Int32 SetSize = (1 << 23); // 8388608 + private const Int32 BufferSize = 2048; + + private string m_opName; + private IEnumerable m_source; + private IEnumerable m_otherSource; + private IEqualityComparer m_comparer; + private Func, IEnumerable> m_applyFunc; + private bool m_isPartial; + + private Thread m_mainWorker; + private Task[] m_workers; + private BlockingCollection[] m_queues; + private BlockingCollection m_resultSet; + private BlockingCollection m_stealingQueue; + private int m_stealingWorkerCnt; + private volatile bool m_isDone; + private Exception m_workerException; + + private IEnumerator m_resultEnum; + private TResult[] m_currentItems; + private int m_index; + private bool m_disposed; + + public InnerEnumerator(ParallelSetOperation parent) + { + this.m_opName = parent.m_opName; + this.m_source = parent.m_source; + this.m_otherSource = parent.m_otherSource; + + if (this.m_opName == "Except") + { + this.m_source = parent.m_otherSource; + this.m_otherSource = parent.m_source; + } + else if (this.m_opName == "Intersect") + { + 
if ((parent.m_source is HpcVertexReader) && + (parent.m_otherSource is HpcVertexReader)) + { + Int64 len1 = ((HpcVertexReader)parent.m_source).GetTotalLength(); + Int64 len2 = ((HpcVertexReader)parent.m_otherSource).GetTotalLength(); + if (len2 >= 0 && len1 > len2) + { + this.m_source = parent.m_otherSource; + this.m_otherSource = parent.m_source; + } + DryadLinqLog.Add("Parallel " + this.m_opName + ": len1={0}, len2={1}", len1, len2); + } + } + + this.m_comparer = parent.m_comparer; + this.m_applyFunc = parent.m_applyFunc; + this.m_isPartial = parent.m_isPartial; + + this.m_isDone = false; + this.m_stealingWorkerCnt = 0; + this.m_stealingQueue = new BlockingCollection(); + this.m_workers = new Task[Environment.ProcessorCount]; + this.m_queues = new BlockingCollection[this.m_workers.Length]; + for (int i = 0; i < this.m_queues.Length; i++) + { + this.m_queues[i] = new BlockingCollection(2); + } + this.m_resultSet = new BlockingCollection(4); + for (int i = 0; i < this.m_workers.Length; i++) + { + this.m_workers[i] = this.CreateTask(i); + } + this.m_mainWorker = new Thread(this.ProcessAllItems); + this.m_mainWorker.Name = "SO.ProcessAllItem"; + this.m_mainWorker.Start(); + + this.m_resultEnum = this.m_resultSet.GetConsumingEnumerable().GetEnumerator(); + this.m_currentItems = new TResult[0]; + this.m_index = -1; + this.m_disposed = false; + } + + ~InnerEnumerator() + { + this.Dispose(false); + } + + public bool MoveNext() + { + this.m_index++; + if (this.m_index < this.m_currentItems.Length) + { + return true; + } + + while (this.m_resultEnum.MoveNext()) + { + if (this.m_workerException != null) break; + + this.m_currentItems = this.m_resultEnum.Current; + if (this.m_currentItems.Length > 0) + { + this.m_index = 0; + return true; + } + } + if (this.m_workerException != null) + { + throw new DryadLinqException("Failed while enumerating.", this.m_workerException); + } + return false; + } + + public TResult Current + { + get { return this.m_currentItems[this.m_index]; 
} + } + + object IEnumerator.Current + { + get { return this.m_currentItems[this.m_index]; } + } + + public void Reset() + { + throw new DryadLinqException("Internal error: Cannot reset this IEnumerator."); + } + + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (!this.m_disposed) + { + this.m_disposed = true; + this.m_isDone = true; + + foreach (var queue in this.m_queues) + { + foreach (var item in queue.GetConsumingEnumerable()) + { + // Always drain the queues + } + } + while (this.m_resultEnum.MoveNext()) + { + // Always drain the result queue + } + } + } + + // Body of the reader thread: hash-partitions m_source (and then m_otherSource, if any) + // into fixed-size buffers and hands them to the per-worker queues. + private void ProcessAllItems() + { + try + { + DryadLinqLog.Add("Parallel " + this.m_opName + " started reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + Int32 wlen = this.m_workers.Length; + TSource[][] buffers = new TSource[wlen][]; + Int32[] counts = new Int32[wlen]; + for (int i = 0; i < wlen; i++) + { + buffers[i] = new TSource[BufferSize]; + counts[i] = 0; + } + foreach (TSource item in this.m_source) + { + Int32 hashCode = this.m_comparer.GetHashCode(item); + Int32 idx = HpcLinqUtil.GetTaskIndex(hashCode, wlen); + if (counts[idx] == BufferSize) + { + if (this.m_isDone) break; + + this.m_queues[idx].Add(buffers[idx]); + buffers[idx] = new TSource[BufferSize]; + counts[idx] = 0; + } + buffers[idx][counts[idx]] = item; + counts[idx]++; + } + + if (this.m_opName == "Intersect" || this.m_opName == "Except") + { + // Add the final buffers of m_source to queues + for (int i = 0; i < wlen; i++) + { + if (!this.m_isDone && counts[i] > 0) + { + TSource[] lastBuffer = new TSource[counts[i]]; + Array.Copy(buffers[i], lastBuffer, counts[i]); + // Start a fresh buffer rather than nulling the slot: buffers[i] is written + // again below while partitioning m_otherSource, so null would throw there. + buffers[i] = new TSource[BufferSize]; + this.m_queues[i].Add(lastBuffer); + counts[i] = 0; + } + // The empty buffer separates the items of m_source from those of m_otherSource. + this.m_queues[i].Add(new TSource[0]); + } + } + if (this.m_otherSource != null) + { + foreach (TSource item in this.m_otherSource) + { + Int32 hashCode = this.m_comparer.GetHashCode(item); 
+ Int32 idx = HpcLinqUtil.GetTaskIndex(hashCode, wlen); + if (counts[idx] == BufferSize) + { + if (this.m_isDone) break; + + this.m_queues[idx].Add(buffers[idx]); + buffers[idx] = new TSource[BufferSize]; + counts[idx] = 0; + } + buffers[idx][counts[idx]] = item; + counts[idx]++; + } + } + // Add the final buffers to the queues and declare adding is complete + for (int i = 0; i < wlen; i++) + { + if (!this.m_isDone && counts[i] > 0) + { + TSource[] lastBuffer = new TSource[counts[i]]; + Array.Copy(buffers[i], lastBuffer, counts[i]); + buffers[i] = null; + this.m_queues[i].Add(lastBuffer); + } + this.m_queues[i].CompleteAdding(); + } + + DryadLinqLog.Add("Parallel " + this.m_opName + " ended reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + // Wait for all the workers to complete + Task.WaitAll(this.m_workers); + } + catch (Exception e) + { + this.m_isDone = true; + this.m_workerException = e; + } + finally + { + for (int i = 0; i < this.m_queues.Length; i++) + { + this.m_queues[i].CompleteAdding(); + } + this.m_resultSet.CompleteAdding(); + } + } + + private Task CreateTask(Int32 idx) + { + if (this.m_opName == "Intersect") + { + // using the TCO.LongRunning option, because this method is + // called a fix number of times to spawn worker tasks + return Task.Factory.StartNew(delegate { this.DoIntersect(idx); }, + TaskCreationOptions.LongRunning); + } + else if (this.m_opName == "Except") + { + // using the TCO.LongRunning option, because this method is + // called a fix number of times to spawn worker tasks + return Task.Factory.StartNew(delegate { this.DoExcept(idx); }, + TaskCreationOptions.LongRunning); + } + else + { + if (this.m_isPartial) + { + // using the TCO.LongRunning option, because this method is + // called a fix number of times to spawn worker tasks + return Task.Factory.StartNew(delegate { this.DoPartialDistinct(idx); }, + TaskCreationOptions.LongRunning); + } + else + { + // using the TCO.LongRunning option, because this 
method is + // called a fix number of times to spawn worker tasks + return Task.Factory.StartNew(delegate { this.DoFullDistinct(idx); }, + TaskCreationOptions.LongRunning); + } + } + } + + private void DoPartialDistinct(Int32 qidx) + { + try + { + DryadLinqLog.Add("Parallel Distinct (partial) worker {0} started at {1}", + qidx, DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + BlockingCollection queue = this.m_queues[qidx]; + Int32 wlen = this.m_workers.Length; + TSource[] resultBuffer = new TSource[BufferSize]; + Int32 count = 0; + Int32 myWorkCnt = 0; + Int32 otherWorkCnt = 0; + Int32 stealingWorkerCnt = 0; + + if (typeof(TSource).IsValueType) + { + Pair[] seenSet = new Pair[SetSize]; + foreach (var buffer in queue.GetConsumingEnumerable()) + { + if (this.m_isDone) return; + + for (int i = 0; i < buffer.Length; i++) + { + TSource item = buffer[i]; + Int32 idx = (this.m_comparer.GetHashCode(item) & 0x7fffffff) % SetSize; + Pair seenItem = seenSet[idx]; + if (!seenItem.Value || !this.m_comparer.Equals(item, seenItem.Key)) + { + seenSet[idx] = new Pair(item, true); + if (count == BufferSize) + { + if (this.m_applyFunc == null) + { + this.m_resultSet.Add((TResult[])(object)resultBuffer); + resultBuffer = new TSource[BufferSize]; + } + else + { + int newStealingWorkerCnt = this.m_stealingWorkerCnt; + if (newStealingWorkerCnt > stealingWorkerCnt) + { + myWorkCnt = 0; + otherWorkCnt = 0; + stealingWorkerCnt = newStealingWorkerCnt; + } + if ((stealingWorkerCnt * myWorkCnt) <= ((wlen - stealingWorkerCnt) * otherWorkCnt)) + { + this.m_resultSet.Add(this.m_applyFunc(resultBuffer).ToArray()); + myWorkCnt++; + } + else + { + this.m_stealingQueue.Add(resultBuffer); + resultBuffer = new TSource[BufferSize]; + otherWorkCnt++; + } + } + count = 0; + } + resultBuffer[count++] = item; + } + } + } + } + else + { + TSource[] seenSet = new TSource[SetSize]; + foreach (var buffer in queue.GetConsumingEnumerable()) + { + if (this.m_isDone) return; + + for (int i = 0; i < 
buffer.Length; i++) + { + TSource item = buffer[i]; + Int32 idx = (this.m_comparer.GetHashCode(item) & 0x7fffffff) % SetSize; + TSource seenItem = seenSet[idx]; + if (seenItem == null || !this.m_comparer.Equals(item, seenItem)) + { + seenSet[idx] = item; + if (count == BufferSize) + { + if (this.m_applyFunc == null) + { + this.m_resultSet.Add((TResult[])(object)resultBuffer); + resultBuffer = new TSource[BufferSize]; + } + else + { + int newStealingWorkerCnt = this.m_stealingWorkerCnt; + if (newStealingWorkerCnt > stealingWorkerCnt) + { + myWorkCnt = 0; + otherWorkCnt = 0; + stealingWorkerCnt = newStealingWorkerCnt; + } + if ((stealingWorkerCnt * myWorkCnt) <= ((wlen - stealingWorkerCnt) * otherWorkCnt)) + { + this.m_resultSet.Add(this.m_applyFunc(resultBuffer).ToArray()); + myWorkCnt++; + } + else + { + this.m_stealingQueue.Add(resultBuffer); + resultBuffer = new TSource[BufferSize]; + otherWorkCnt++; + } + } + count = 0; + } + resultBuffer[count++] = item; + } + } + } + } + + // Add the final buffer: + if (!this.m_isDone && count > 0) + { + TSource[] lastResultBuffer = new TSource[count]; + Array.Copy(resultBuffer, lastResultBuffer, count); + resultBuffer = null; + if (this.m_applyFunc == null) + { + this.m_resultSet.Add((TResult[])(object)lastResultBuffer); + } + else + { + this.m_resultSet.Add(this.m_applyFunc(lastResultBuffer).ToArray()); + } + } + + // Stealing work + if (this.m_applyFunc != null) + { + stealingWorkerCnt = Interlocked.Increment(ref this.m_stealingWorkerCnt); + if (stealingWorkerCnt == wlen) + { + this.m_stealingQueue.CompleteAdding(); + } + foreach (TSource[] buffer in this.m_stealingQueue.GetConsumingEnumerable()) + { + if (this.m_isDone) return; + this.m_resultSet.Add(this.m_applyFunc(buffer).ToArray()); + } + } + } + catch (Exception e) + { + this.m_isDone = true; + this.m_workerException = new DryadLinqException(HpcLinqErrorCode.FailureInDistinct, + SR.FailureInDistinct, e); + this.m_resultSet.Add(new TResult[0]); + 
this.m_stealingQueue.CompleteAdding(); + } + finally + { + DryadLinqLog.Add("Parallel Distinct (partial) worker {0} ended at {1}", + qidx, DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + } + } + + private void DoFullDistinct(Int32 qidx) + { + try + { + DryadLinqLog.Add("Parallel " + this.m_opName + " worker {0} started at {1}", + qidx, DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + BlockingCollection queue = this.m_queues[qidx]; + BigHashSet seenSet = new BigHashSet(this.m_comparer); + // HashSet seenSet = new HashSet(this.m_comparer); + + Int32 wlen = this.m_workers.Length; + TSource[] resultBuffer = new TSource[BufferSize]; + Int32 count = 0; + Int32 myWorkCnt = 0; + Int32 otherWorkCnt = 0; + Int32 stealingWorkerCnt = 0; + + foreach (TSource[] buffer in queue.GetConsumingEnumerable()) + { + if (this.m_isDone) return; + + for (int i = 0; i < buffer.Length; i++) + { + TSource item = buffer[i]; + if (seenSet.Add(item)) + { + if (count == BufferSize) + { + if (this.m_applyFunc == null) + { + this.m_resultSet.Add((TResult[])(object)resultBuffer); + resultBuffer = new TSource[BufferSize]; + } + else + { + int newStealingWorkerCnt = this.m_stealingWorkerCnt; + if (newStealingWorkerCnt > stealingWorkerCnt) + { + myWorkCnt = 0; + otherWorkCnt = 0; + stealingWorkerCnt = newStealingWorkerCnt; + } + if ((stealingWorkerCnt * myWorkCnt) <= ((wlen - stealingWorkerCnt) * otherWorkCnt)) + { + this.m_resultSet.Add(this.m_applyFunc(resultBuffer).ToArray()); + myWorkCnt++; + } + else + { + this.m_stealingQueue.Add(resultBuffer); + resultBuffer = new TSource[BufferSize]; + otherWorkCnt++; + } + } + count = 0; + } + resultBuffer[count++] = item; + } + } + } + + // Add the final buffer: + if (!this.m_isDone && count > 0) + { + TSource[] lastResultBuffer = new TSource[count]; + Array.Copy(resultBuffer, lastResultBuffer, count); + resultBuffer = null; + if (this.m_applyFunc == null) + { + this.m_resultSet.Add((TResult[])(object)lastResultBuffer); + } + else + { + 
this.m_resultSet.Add(this.m_applyFunc(lastResultBuffer).ToArray()); + } + } + + // Stealing work + if (this.m_applyFunc != null) + { + stealingWorkerCnt = Interlocked.Increment(ref this.m_stealingWorkerCnt); + if (stealingWorkerCnt == wlen) + { + this.m_stealingQueue.CompleteAdding(); + } + foreach (TSource[] buffer in this.m_stealingQueue.GetConsumingEnumerable()) + { + if (this.m_isDone) return; + this.m_resultSet.Add(this.m_applyFunc(buffer).ToArray()); + } + } + } + catch (Exception e) + { + this.m_isDone = true; + this.m_workerException = new DryadLinqException(HpcLinqErrorCode.FailureInOperator, + String.Format(SR.FailureInOperator, this.m_opName), e); + this.m_resultSet.Add(new TResult[0]); + this.m_stealingQueue.CompleteAdding(); + } + finally + { + DryadLinqLog.Add("Parallel " + this.m_opName + " worker {0} ended at {1}", + qidx, DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + } + } + + private void DoExcept(Int32 qidx) + { + try + { + DryadLinqLog.Add("Parallel Except worker {0} started at {1}", + qidx, DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + BlockingCollection queue = this.m_queues[qidx]; + BigHashSet seenSet = new BigHashSet(this.m_comparer); + + Int32 wlen = this.m_workers.Length; + TSource[] resultBuffer = new TSource[BufferSize]; + Int32 count = 0; + Int32 myWorkCnt = 0; + Int32 otherWorkCnt = 0; + Int32 stealingWorkerCnt = 0; + bool isFirst = false; + + foreach (TSource[] buffer in queue.GetConsumingEnumerable()) + { + if (this.m_isDone) return; + + if (isFirst) + { + for (int i = 0; i < buffer.Length; i++) + { + TSource item = buffer[i]; + if (seenSet.Add(item)) + { + if (count == BufferSize) + { + if (this.m_applyFunc == null) + { + this.m_resultSet.Add((TResult[])(object)resultBuffer); + resultBuffer = new TSource[BufferSize]; + } + else + { + int newStealingWorkerCnt = this.m_stealingWorkerCnt; + if (newStealingWorkerCnt > stealingWorkerCnt) + { + myWorkCnt = 0; + otherWorkCnt = 0; + stealingWorkerCnt = newStealingWorkerCnt; 
+ } + if ((stealingWorkerCnt * myWorkCnt) <= ((wlen - stealingWorkerCnt) * otherWorkCnt)) + { + this.m_resultSet.Add(this.m_applyFunc(resultBuffer).ToArray()); + myWorkCnt++; + } + else + { + this.m_stealingQueue.Add(resultBuffer); + resultBuffer = new TSource[BufferSize]; + otherWorkCnt++; + } + } + count = 0; + } + resultBuffer[count++] = item; + } + } + } + else + { + isFirst = (buffer.Length == 0); + for (int i = 0; i < buffer.Length; i++) + { + seenSet.Add(buffer[i]); + } + } + } + + // Add the final buffer: + if (!this.m_isDone && count > 0) + { + TSource[] lastResultBuffer = new TSource[count]; + Array.Copy(resultBuffer, lastResultBuffer, count); + resultBuffer = null; + if (this.m_applyFunc == null) + { + this.m_resultSet.Add((TResult[])(object)lastResultBuffer); + } + else + { + this.m_resultSet.Add(this.m_applyFunc(lastResultBuffer).ToArray()); + } + } + + // Stealing work + if (this.m_applyFunc != null) + { + stealingWorkerCnt = Interlocked.Increment(ref this.m_stealingWorkerCnt); + if (stealingWorkerCnt == wlen) + { + this.m_stealingQueue.CompleteAdding(); + } + foreach (TSource[] buffer in this.m_stealingQueue.GetConsumingEnumerable()) + { + if (this.m_isDone) return; + this.m_resultSet.Add(this.m_applyFunc(buffer).ToArray()); + } + } + } + catch (Exception e) + { + // Record the failure and post an empty result so a consumer blocked on m_resultSet + // wakes up and observes it, as the Distinct workers do, instead of faulting the task. + this.m_isDone = true; + this.m_workerException = new DryadLinqException(HpcLinqErrorCode.FailureInExcept, + String.Format(SR.FailureInExcept), e); + this.m_resultSet.Add(new TResult[0]); + this.m_stealingQueue.CompleteAdding(); + } + finally + { + DryadLinqLog.Add("Parallel Except worker {0} ended at {1}", + qidx, DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + } + } + + private void DoIntersect(Int32 qidx) + { + try + { + DryadLinqLog.Add("Parallel Intersect worker {0} started at {1}", + qidx, DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + BlockingCollection queue = this.m_queues[qidx]; + BigHashSet leftSet = new BigHashSet(this.m_comparer); + + Int32 wlen = this.m_workers.Length; + TSource[] resultBuffer = new TSource[BufferSize]; + 
Int32 count = 0; + Int32 myWorkCnt = 0; + Int32 otherWorkCnt = 0; + Int32 stealingWorkerCnt = 0; + bool isFirst = true; + + foreach (TSource[] buffer in queue.GetConsumingEnumerable()) + { + if (this.m_isDone) return; + + if (isFirst) + { + isFirst = (buffer.Length != 0); + for (int i = 0; i < buffer.Length; i++) + { + leftSet.Add(buffer[i]); + } + } + else + { + for (int i = 0; i < buffer.Length; i++) + { + TSource item = buffer[i]; + if (leftSet.Remove(item)) + { + if (count == BufferSize) + { + if (this.m_applyFunc == null) + { + this.m_resultSet.Add((TResult[])(object)resultBuffer); + resultBuffer = new TSource[BufferSize]; + } + else + { + int newStealingWorkerCnt = this.m_stealingWorkerCnt; + if (newStealingWorkerCnt > stealingWorkerCnt) + { + myWorkCnt = 0; + otherWorkCnt = 0; + stealingWorkerCnt = newStealingWorkerCnt; + } + if ((stealingWorkerCnt * myWorkCnt) <= ((wlen - stealingWorkerCnt) * otherWorkCnt)) + { + this.m_resultSet.Add(this.m_applyFunc(resultBuffer).ToArray()); + myWorkCnt++; + } + else + { + this.m_stealingQueue.Add(resultBuffer); + resultBuffer = new TSource[BufferSize]; + otherWorkCnt++; + } + } + count = 0; + } + resultBuffer[count++] = item; + } + } + } + } + + // Add the final buffer: + if (!this.m_isDone && count > 0) + { + TSource[] lastResultBuffer = new TSource[count]; + Array.Copy(resultBuffer, lastResultBuffer, count); + resultBuffer = null; + if (this.m_applyFunc == null) + { + this.m_resultSet.Add((TResult[])(object)lastResultBuffer); + } + else + { + this.m_resultSet.Add(this.m_applyFunc(lastResultBuffer).ToArray()); + } + } + + // Stealing work + if (this.m_applyFunc != null) + { + stealingWorkerCnt = Interlocked.Increment(ref this.m_stealingWorkerCnt); + if (stealingWorkerCnt == wlen) + { + this.m_stealingQueue.CompleteAdding(); + } + foreach (TSource[] buffer in this.m_stealingQueue.GetConsumingEnumerable()) + { + if (this.m_isDone) return; + this.m_resultSet.Add(this.m_applyFunc(buffer).ToArray()); + } + } + } + catch 
(Exception e) + { + // Record the failure and post an empty result so a consumer blocked on m_resultSet + // wakes up and observes it, as the Distinct workers do, instead of faulting the task. + this.m_isDone = true; + this.m_workerException = new DryadLinqException(HpcLinqErrorCode.FailureInIntersect, + String.Format(SR.FailureInIntersect), e); + this.m_resultSet.Add(new TResult[0]); + this.m_stealingQueue.CompleteAdding(); + } + finally + { + DryadLinqLog.Add("Parallel Intersect worker {0} ended at {1}", + qidx, DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + } + } + } + } + + internal class ParallelOrderedGroupBy : IParallelPipeline + { + private IEnumerable m_source; + private Func m_keySelector; + private Func m_elementSelector; + private Func, TResult> m_resultSelector; + private Func, IEnumerable> m_applyFunc; + private IEqualityComparer m_comparer; + + public ParallelOrderedGroupBy(IEnumerable source, + Func keySelector, + Func elementSelector, + Func, TResult> resultSelector, + IEqualityComparer comparer, + Func, IEnumerable> applyFunc) + { + if (resultSelector == null || elementSelector == null) + { + throw new InvalidOperationException(); + } + this.m_source = source; + this.m_keySelector = keySelector; + this.m_elementSelector = elementSelector; + this.m_resultSelector = resultSelector; + this.m_applyFunc = applyFunc; + this.m_comparer = comparer; + if (this.m_comparer == null) + { + this.m_comparer = EqualityComparer.Default; + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public IEnumerator GetEnumerator() + { + return new InnerEnumerator(this); + } + + public IParallelPipeline + Extend(Func, IEnumerable> func, bool orderPreserving) + { + if (this.m_applyFunc == null) + { + var applyFunc = (Func, IEnumerable>)(object)func; + return new ParallelOrderedGroupBy( + this.m_source, this.m_keySelector, this.m_elementSelector, + this.m_resultSelector, this.m_comparer, applyFunc); + } + else + { + return new ParallelOrderedGroupBy( + this.m_source, this.m_keySelector, this.m_elementSelector, + this.m_resultSelector, this.m_comparer, + s => func(this.m_applyFunc(s))); + } + } + + private class InnerEnumerator : IEnumerator + { + private 
const Int32 BufferSize = 16384; // (1 << 14); + + private IEnumerable m_source; + private Func m_keySelector; + private Func m_elementSelector; + private Func, TResult> m_resultSelector; + private Func, IEnumerable> m_applyFunc; + private IEqualityComparer m_comparer; + + private Thread m_mainWorker; + private Task[] m_workers; + private EventWaitHandle[] m_events; + private List[] m_workerResLists; + private BlockingCollection> m_resultQueue; + private volatile bool m_isDone; + private Exception m_workerException; + + private IEnumerator> m_resultEnum; + private List m_currentItems; + private int m_index; + private bool m_disposed; + + public InnerEnumerator(ParallelOrderedGroupBy parent) + { + this.m_source = parent.m_source; + this.m_keySelector = parent.m_keySelector; + this.m_elementSelector = parent.m_elementSelector; + this.m_resultSelector = parent.m_resultSelector; + this.m_applyFunc = parent.m_applyFunc; + this.m_comparer = parent.m_comparer; + + this.m_isDone = false; + this.m_workerException = null; + this.m_resultQueue = new BlockingCollection>(2); + this.m_workers = new Task[2 * Environment.ProcessorCount]; + this.m_events = new ManualResetEvent[this.m_workers.Length]; + this.m_workerResLists = new List[this.m_workers.Length]; + for (int i = 0; i < this.m_events.Length; i++) + { + this.m_events[i] = new ManualResetEvent(true); + } + this.m_mainWorker = new Thread(this.ProcessAllItems); + this.m_mainWorker.Name = "OGB.ProcessAllItem"; + this.m_mainWorker.Start(); + + this.m_resultEnum = this.m_resultQueue.GetConsumingEnumerable().GetEnumerator(); + this.m_currentItems = new List(); + this.m_index = -1; + this.m_disposed = false; + } + + ~InnerEnumerator() + { + this.Dispose(false); + } + + public bool MoveNext() + { + this.m_index++; + if (this.m_index < this.m_currentItems.Count) + { + return true; + } + + while (this.m_resultEnum.MoveNext()) + { + if (this.m_workerException != null) break; + + this.m_currentItems = this.m_resultEnum.Current; + if 
(this.m_currentItems.Count > 0) + { + this.m_index = 0; + return true; + } + } + if (this.m_workerException != null) + { + throw new DryadLinqException("Failed while enumerating.", this.m_workerException); + } + return false; + } + + public TFinal Current + { + get { return this.m_currentItems[this.m_index]; } + } + + object IEnumerator.Current + { + get { return this.m_currentItems[this.m_index]; } + } + + public void Reset() + { + throw new InvalidOperationException(); + } + + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (!this.m_disposed) + { + this.m_disposed = true; + this.m_isDone = true; + while (this.m_resultEnum.MoveNext()) + { + // Always drain the result queue + } + } + } + + private void ProcessAllItems() + { + try + { + // Read all the items + DryadLinqLog.Add("Parallel OrderedGroupBy started reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + Int32 wlen = this.m_workers.Length; + TSource[] elemBuffer = new TSource[BufferSize]; + Int32 taskIdx = 0; + Int32 elemCnt = 0; + IEnumerator sourceEnum = this.m_source.GetEnumerator(); + while (sourceEnum.MoveNext()) + { + TSource item = sourceEnum.Current; + if (elemCnt < BufferSize) + { + elemBuffer[elemCnt++] = item; + } + else + { + if (this.m_isDone) break; + + TKey lastKey = this.m_keySelector(elemBuffer[BufferSize - 1]); + bool hasMoreItems = true; + if (this.m_comparer.Equals(lastKey, this.m_keySelector(item))) + { + List lastItems = new List(8); + lastItems.Add(item); + while (hasMoreItems = sourceEnum.MoveNext()) + { + item = sourceEnum.Current; + if (!this.m_comparer.Equals(lastKey, this.m_keySelector(item))) + { + break; + } + lastItems.Add(item); + } + elemCnt = BufferSize + lastItems.Count(); + TSource[] newElemBuffer = new TSource[elemCnt]; + Array.Copy(elemBuffer, 0, newElemBuffer, 0, BufferSize); + lastItems.CopyTo(newElemBuffer, BufferSize); + elemBuffer = newElemBuffer; + } + + 
this.CreateTask(taskIdx, elemBuffer, elemCnt); + taskIdx = (taskIdx + 1) % wlen; + elemCnt = 0; + if (!hasMoreItems) break; + elemBuffer = new TSource[BufferSize]; + elemBuffer[elemCnt++] = item; + } + } + + // Create a task for the last buffer + if (!this.m_isDone && elemCnt > 0) + { + this.CreateTask(taskIdx, elemBuffer, elemCnt); + taskIdx = (taskIdx + 1) % wlen; + } + DryadLinqLog.Add("Parallel OrderedGroupBy ended reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + // Wait for all the workers to complete + for (int i = 0; i < this.m_workers.Length; i++) + { + this.m_events[taskIdx].WaitOne(); + if (this.m_workerResLists[taskIdx] != null) + { + this.m_resultQueue.Add(this.m_workerResLists[taskIdx]); + this.m_workerResLists[taskIdx] = null; + } + taskIdx = (taskIdx + 1) % wlen; + } + foreach (Task task in this.m_workers) + { + if (task != null) task.Wait(); + } + } + catch (Exception e) + { + this.m_isDone = true; + this.m_workerException = e; + } + finally + { + this.m_resultQueue.CompleteAdding(); + } + } + + private void CreateTask(Int32 taskIdx, TSource[] elems, Int32 cnt) + { + this.m_events[taskIdx].WaitOne(); + if (this.m_workerResLists[taskIdx] != null) + { + this.m_resultQueue.Add(this.m_workerResLists[taskIdx]); + this.m_workerResLists[taskIdx] = null; + } + this.m_events[taskIdx].Reset(); + + // NOT using the TCO.LongRunning option for this task, because it's spawned an arbitrary # of times for potentially shorter work + // which means it's best to leave this to the TP load balancing algorithm + this.m_workers[taskIdx] = Task.Factory.StartNew(delegate { this.OrderedGroupBy(taskIdx, elems, cnt); }); + } + + private void OrderedGroupBy(Int32 taskIdx, TSource[] elems, Int32 cnt) + { + try + { + List resList = new List(cnt / 8); + Grouping curGroup = new Grouping(this.m_keySelector(elems[0])); + curGroup.AddItem(this.m_elementSelector(elems[0])); + Int32 idx = 1; + while (idx < cnt) + { + if (this.m_comparer.Equals(curGroup.Key, 
this.m_keySelector(elems[idx]))) + { + curGroup.AddItem(this.m_elementSelector(elems[idx])); + } + else + { + resList.Add(this.m_resultSelector(curGroup.Key, curGroup)); + curGroup = new Grouping(this.m_keySelector(elems[idx])); + curGroup.AddItem(this.m_elementSelector(elems[idx])); + } + idx++; + } + resList.Add(this.m_resultSelector(curGroup.Key, curGroup)); + if (this.m_applyFunc == null) + { + this.m_workerResLists[taskIdx] = (List)(object)resList; + } + else + { + List finalResList = new List(cnt/8); + foreach (var elem in this.m_applyFunc(resList)) + { + finalResList.Add(elem); + } + this.m_workerResLists[taskIdx] = finalResList; + } + } + catch (Exception e) + { + this.m_isDone = true; + this.m_workerException = new DryadLinqException(HpcLinqErrorCode.FailureInOrderedGroupBy, + SR.FailureInOrderedGroupBy, e); + throw this.m_workerException; + } + finally + { + this.m_events[taskIdx].Set(); + } + } + } + } + + internal class ParallelOrderedGroupByAccumulate + : IParallelPipeline + { + private IEnumerable m_source; + private Func m_keySelector; + private Func m_elementSelector; + private Func m_seed; + private Func m_accumulator; + private Func>, IEnumerable> m_applyFunc; + private IEqualityComparer m_comparer; + + public ParallelOrderedGroupByAccumulate( + IEnumerable source, + Func keySelector, + Func elementSelector, + Func seed, + Func accumulator, + IEqualityComparer comparer, + Func>, IEnumerable> applyFunc) + { + if (seed == null || accumulator == null || elementSelector == null) + { + throw new DryadLinqException("Internal error: The accumulator and element selector can't be null"); + } + this.m_source = source; + this.m_keySelector = keySelector; + this.m_elementSelector = elementSelector; + this.m_seed = seed; + this.m_accumulator = accumulator; + this.m_applyFunc = applyFunc; + this.m_comparer = comparer; + if (this.m_comparer == null) + { + this.m_comparer = EqualityComparer.Default; + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return 
this.GetEnumerator(); + } + + public IEnumerator GetEnumerator() + { + return new InnerEnumerator(this); + } + + public IParallelPipeline + Extend(Func, IEnumerable> func, bool orderPreserving) + { + if (this.m_applyFunc == null) + { + var applyFunc = (Func>, IEnumerable>)(object)func; + return new ParallelOrderedGroupByAccumulate( + this.m_source, this.m_keySelector, this.m_elementSelector, + this.m_seed, this.m_accumulator, this.m_comparer, applyFunc); + } + else + { + return new ParallelOrderedGroupByAccumulate( + this.m_source, this.m_keySelector, this.m_elementSelector, + this.m_seed, this.m_accumulator, this.m_comparer, + s => func(this.m_applyFunc(s))); + } + } + + private class InnerEnumerator : IEnumerator + { + private const Int32 BufferSize = 16384; // (1 << 14); + + private IEnumerable m_source; + private Func m_keySelector; + private Func m_elementSelector; + private Func m_seed; + private Func m_accumulator; + private Func>, IEnumerable> m_applyFunc; + private IEqualityComparer m_comparer; + + private Thread m_mainWorker; + private Task[] m_workers; + private EventWaitHandle[] m_events; + private List[] m_workerResLists; + private BlockingCollection> m_resultQueue; + private volatile bool m_isDone; + private Exception m_workerException; + + private IEnumerator> m_resultEnum; + private List m_currentItems; + private int m_index; + private bool m_disposed; + + public InnerEnumerator(ParallelOrderedGroupByAccumulate parent) + { + this.m_source = parent.m_source; + this.m_keySelector = parent.m_keySelector; + this.m_elementSelector = parent.m_elementSelector; + this.m_seed = parent.m_seed; + this.m_accumulator = parent.m_accumulator; + this.m_applyFunc = parent.m_applyFunc; + this.m_comparer = parent.m_comparer; + + this.m_isDone = false; + this.m_workerException = null; + this.m_resultQueue = new BlockingCollection>(2); + this.m_workers = new Task[2 * Environment.ProcessorCount]; + this.m_events = new ManualResetEvent[this.m_workers.Length]; + 
this.m_workerResLists = new List[this.m_workers.Length]; + for (int i = 0; i < this.m_events.Length; i++) + { + this.m_events[i] = new ManualResetEvent(true); + } + this.m_mainWorker = new Thread(this.ProcessAllItems); + this.m_mainWorker.Name = "OGB.ProcessAllItem"; + this.m_mainWorker.Start(); + + this.m_resultEnum = this.m_resultQueue.GetConsumingEnumerable().GetEnumerator(); + this.m_currentItems = new List(); + this.m_index = -1; + this.m_disposed = false; + } + + ~InnerEnumerator() + { + this.Dispose(false); + } + + public bool MoveNext() + { + this.m_index++; + if (this.m_index < this.m_currentItems.Count) + { + return true; + } + + while (this.m_resultEnum.MoveNext()) + { + if (this.m_workerException != null) break; + + this.m_currentItems = this.m_resultEnum.Current; + if (this.m_currentItems.Count > 0) + { + this.m_index = 0; + return true; + } + } + if (this.m_workerException != null) + { + throw new DryadLinqException("Failed while enumerating.", this.m_workerException); + } + return false; + } + + public TFinal Current + { + get { return this.m_currentItems[this.m_index]; } + } + + object IEnumerator.Current + { + get { return this.m_currentItems[this.m_index]; } + } + + public void Reset() + { + throw new DryadLinqException("Internal error: Cannot reset this IEnumerator."); + } + + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (!this.m_disposed) + { + this.m_disposed = true; + this.m_isDone = true; + while (this.m_resultEnum.MoveNext()) + { + // Always drain the result queue + } + } + } + + private void ProcessAllItems() + { + try + { + // Read all the items + DryadLinqLog.Add("Parallel OrderedGroupBy (Acc) started reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + Int32 wlen = this.m_workers.Length; + TSource[] elemBuffer = new TSource[BufferSize]; + Int32 taskIdx = 0; + Int32 elemCnt = 0; + IEnumerator sourceEnum = 
this.m_source.GetEnumerator(); + while (sourceEnum.MoveNext()) + { + TSource item = sourceEnum.Current; + if (elemCnt < BufferSize) + { + elemBuffer[elemCnt++] = item; + } + else + { + if (this.m_isDone) break; + + TKey lastKey = this.m_keySelector(elemBuffer[BufferSize - 1]); + TResult lastValue = default(TResult); + bool moreLast = this.m_comparer.Equals(lastKey, this.m_keySelector(item)); + bool hasMoreItems = true; + if (moreLast) + { + Int32 idx = BufferSize - 2; + while (idx >= 0 && this.m_comparer.Equals(lastKey, this.m_keySelector(elemBuffer[idx]))) + { + idx--; + } + elemCnt = idx++; + lastValue = this.m_seed(this.m_elementSelector(elemBuffer[elemCnt])); + for (int i = elemCnt + 1; i < BufferSize; i++) + { + lastValue = this.m_accumulator(lastValue, this.m_elementSelector(elemBuffer[i])); + } + while (hasMoreItems = sourceEnum.MoveNext()) + { + item = sourceEnum.Current; + if (!this.m_comparer.Equals(lastKey, this.m_keySelector(item))) + { + break; + } + lastValue = this.m_accumulator(lastValue, this.m_elementSelector(item)); + } + } + + Pair last = new Pair(lastKey, lastValue); + this.CreateTask(taskIdx, elemBuffer, elemCnt, moreLast, last); + taskIdx = (taskIdx + 1) % wlen; + elemCnt = 0; + if (!hasMoreItems) break; + elemBuffer = new TSource[BufferSize]; + elemBuffer[elemCnt++] = item; + } + } + + // Create a task for the last buffer + if (!this.m_isDone && elemCnt > 0) + { + this.CreateTask(taskIdx, elemBuffer, elemCnt, false, + new Pair(default(TKey), default(TResult))); + taskIdx = (taskIdx + 1) % wlen; + } + + DryadLinqLog.Add("Parallel OrderedGroupBy (Acc) ended reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + // Wait for all the workers to complete + for (int i = 0; i < this.m_workers.Length; i++) + { + this.m_events[taskIdx].WaitOne(); + if (this.m_workerResLists[taskIdx] != null) + { + this.m_resultQueue.Add(this.m_workerResLists[taskIdx]); + this.m_workerResLists[taskIdx] = null; + } + taskIdx = (taskIdx + 1) % wlen; + 
} + foreach (Task task in this.m_workers) + { + if (task != null) task.Wait(); + } + } + catch (Exception e) + { + this.m_isDone = true; + this.m_workerException = e; + } + finally + { + this.m_resultQueue.CompleteAdding(); + } + } + + private void CreateTask(Int32 taskIdx, TSource[] elems, Int32 cnt, bool moreLast, Pair last) + { + this.m_events[taskIdx].WaitOne(); + if (this.m_workerResLists[taskIdx] != null) + { + this.m_resultQueue.Add(this.m_workerResLists[taskIdx]); + this.m_workerResLists[taskIdx] = null; + } + this.m_events[taskIdx].Reset(); + + // NOT using the TCO.LongRunning option for this task, because it's spawned + // an arbitrary # of times for potentially shorter work which means it's best + // to leave this to the TP load balancing algorithm + this.m_workers[taskIdx] = Task.Factory.StartNew( + delegate { this.OrderedGroupBy(taskIdx, elems, cnt, moreLast, last); }); + } + + private void OrderedGroupBy(Int32 taskIdx, TSource[] elems, Int32 cnt, bool moreLast, Pair last) + { + try + { + List> resList = new List>(cnt / 8); + if (cnt > 0) + { + TKey curKey = this.m_keySelector(elems[0]); + TResult curValue = this.m_seed(this.m_elementSelector(elems[0])); + Int32 idx = 1; + while (idx < cnt) + { + if (this.m_comparer.Equals(curKey, this.m_keySelector(elems[idx]))) + { + curValue = this.m_accumulator(curValue, this.m_elementSelector(elems[idx])); + } + else + { + resList.Add(new Pair(curKey, curValue)); + curKey = this.m_keySelector(elems[idx]); + curValue = this.m_seed(this.m_elementSelector(elems[idx])); + } + idx++; + } + resList.Add(new Pair(curKey, curValue)); + } + + // Add the last value: + if (moreLast) + { + resList.Add(new Pair(last.Key, last.Value)); + } + + // Apply applyFunc: + if (this.m_applyFunc == null) + { + this.m_workerResLists[taskIdx] = (List)(object)resList; + } + else + { + List finalResList = new List(cnt/8); + foreach (var elem in this.m_applyFunc(resList)) + { + finalResList.Add(elem); + } + this.m_workerResLists[taskIdx] = 
finalResList; + } + } + catch (Exception e) + { + this.m_isDone = true; + this.m_workerException = new DryadLinqException(HpcLinqErrorCode.FailureInOrderedGroupBy, + SR.FailureInOrderedGroupBy, e); + throw this.m_workerException; + } + finally + { + this.m_events[taskIdx].Set(); + } + } + } + } + + internal class ParallelSort : IEnumerable + { + private IEnumerable m_source; + private Func m_keySelector; + private IComparer m_comparer; + private bool m_isDescending; + private bool m_isIdKeySelector; + private HpcLinqFactory m_elemFactory; + + public ParallelSort(IEnumerable source, + Func keySelector, + IComparer comparer, + bool isDescending, + bool isIdKeySelector, + HpcLinqFactory elemFactory) + { + this.m_source = source; + this.m_keySelector = keySelector; + this.m_comparer = TypeSystem.GetComparer(comparer); + if (isDescending) + { + this.m_comparer = MinusComparer.Make(this.m_comparer); + this.m_isDescending = false; + } + if (this.m_comparer is MinusComparer) + { + this.m_isDescending = !this.m_isDescending; + this.m_comparer = ((MinusComparer)this.m_comparer).InnerComparer; + } + this.m_isIdKeySelector = isIdKeySelector; + this.m_elemFactory = elemFactory; + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public IEnumerator GetEnumerator() + { + return new InnerEnumerator(this); + } + + private class InnerEnumerator : IEnumerator + { + private const Int32 ChunkSize = (1 << 21); + private const Int32 MergeSize = 16; + + private IEnumerable m_source; + private Func m_keySelector; + private IComparer m_comparer; + private bool m_isDescending; + private bool m_isIdKeySelector; + private HpcLinqFactory m_elemFactory; + + private BlockingCollection m_sourceQueue; + private Thread m_mainWorker; + private Task[] m_sortWorkers; + + private List m_mergeSortWorkers; + private SortedChunkList m_chunkList; + private List m_mergeList; + private volatile bool m_isDone; + private Exception m_sorterException; + + private 
IEnumerator[] m_enumArray; + private IEnumerator m_resultEnum; + private bool m_disposed; + + public InnerEnumerator(ParallelSort parent) + { + this.m_source = parent.m_source; + this.m_keySelector = parent.m_keySelector; + this.m_isDescending = parent.m_isDescending; + this.m_comparer = parent.m_comparer; + this.m_isIdKeySelector = parent.m_isIdKeySelector; + this.m_elemFactory = parent.m_elemFactory; + + this.m_disposed = false; + + this.m_isDone = false; + this.m_sorterException = null; + this.m_sortWorkers = new Task[Environment.ProcessorCount]; + this.m_sourceQueue = new BlockingCollection(4); + + this.m_mergeSortWorkers = new List(8); + this.m_chunkList = new SortedChunkList(this); + this.m_mergeList = new List(8); + + // Start all the workers: + for (int i = 0; i < this.m_sortWorkers.Length; i++) + { + // using the TCO.LongRunning option, because we spawn a fixed number of tasks + // which means it is safe to request a decicated thread for each task + this.m_sortWorkers[i] = Task.Factory.StartNew(delegate { this.SortItemArray(); }, + TaskCreationOptions.LongRunning); + } + + // Start main worker, and wait for it + this.m_mainWorker = new Thread(this.ProcessAllItems); + this.m_mainWorker.Name = "S.ProcessAllItem"; + this.m_mainWorker.Start(); + this.m_mainWorker.Join(); + + DryadLinqLog.Add("Parallel mergesort workers all started at {0}, number of workers is {1}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff"), this.m_mergeList.Count); + + this.m_enumArray = new IEnumerator[this.m_mergeList.Count]; + for (int i = 0; i < this.m_enumArray.Length; i++) + { + this.m_enumArray[i] = this.m_mergeList[i].GetEnumerator(); + } + this.m_resultEnum = SortHelper.MergeSort(this.m_enumArray, this.m_keySelector, + this.m_comparer, this.m_isDescending); + } + + ~InnerEnumerator() + { + this.Dispose(false); + } + + public bool MoveNext() + { + return this.m_resultEnum.MoveNext(); + } + + public TElement Current + { + get { return this.m_resultEnum.Current; } + } + + object 
IEnumerator.Current + { + get { return this.m_resultEnum.Current; } + } + + public void Reset() + { + throw new DryadLinqException("Internal error: Cannot reset this IEnumerator."); + } + + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (!this.m_disposed) + { + this.m_disposed = true; + this.m_isDone = true; + } + + foreach (var item in this.m_sourceQueue.GetConsumingEnumerable()) + { + // Always drain the source queue + } + for (int i = 0; i < this.m_enumArray.Length; i++) + { + // Always drain the queues of mergesort workers + this.m_enumArray[i].MoveNext(); + } + } + + private void ProcessAllItems() + { + try + { + DryadLinqLog.Add("Parallel sort started reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + TElement[] itemArray = new TElement[ChunkSize]; + Int32 itemCnt = 0; + Int32 chunkCnt = 0; + foreach (TElement item in this.m_source) + { + if (itemCnt == ChunkSize) + { + if (this.m_isDone) break; + + this.m_sourceQueue.Add(itemArray); + chunkCnt++; + itemArray = new TElement[ChunkSize]; + itemCnt = 0; + } + itemArray[itemCnt++] = item; + } + + // Process the final buffer + if (!this.m_isDone && itemCnt > 0) + { + if (itemCnt != itemArray.Length) + { + TElement[] newItemArray = new TElement[itemCnt]; + Array.Copy(itemArray, 0, newItemArray, 0, itemCnt); + itemArray = newItemArray; + } + this.m_sourceQueue.Add(itemArray); + chunkCnt++; + } + this.m_sourceQueue.CompleteAdding(); + + DryadLinqLog.Add("Parallel sort ended reading at {0}, number of sorters is {1}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff"), chunkCnt); + + // Wait for all the sort workers to complete + Task.WaitAll(this.m_sortWorkers); + + // Mergesort the final chunk list + if (!this.m_isDone && this.m_chunkList.Count > 0) + { + this.MergeSortChunks(this.m_chunkList); + this.m_mergeList.Add(this.m_chunkList); // This is happening without a lock because all sort workers have now 
completed + } + } + catch (Exception e) + { + this.m_isDone = true; + this.m_sorterException = e; + lock (this.m_mergeList) + { + SortedChunkList dummyChunkList = new SortedChunkList(this); + this.m_mergeList.Add(dummyChunkList); + dummyChunkList.MergeSort(); + } + } + finally + { + this.m_sourceQueue.CompleteAdding(); + } + } + + private unsafe void SortItemArray() + { + try + { + DryadLinqLog.Add("Parallel sort worker started at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + foreach (var itemArray in this.m_sourceQueue.GetConsumingEnumerable()) + { + if (this.m_isIdKeySelector) + { + // TElement and TKey must be equal + Array.Sort(itemArray, 0, itemArray.Length, (IComparer)this.m_comparer); + } + else + { + Int32 itemCnt = itemArray.Length; + TKey[] keyArray = new TKey[itemCnt]; + for (int i = 0; i < itemCnt; i++) + { + keyArray[i] = this.m_keySelector(itemArray[i]); + } + Array.Sort(keyArray, itemArray, 0, itemArray.Length, this.m_comparer); + } + if (this.m_isDescending) + { + Array.Reverse(itemArray); + } + + // Flush the current chunk to disk if memory is low + IEnumerable sortedChunk = itemArray; + MEMORYSTATUSEX memStatus = new MEMORYSTATUSEX(); + memStatus.dwLength = (UInt32)sizeof(MEMORYSTATUSEX); + HpcLinqNative.GlobalMemoryStatusEx(ref memStatus); + if (this.m_elemFactory != null && + HpcLinqNative.GlobalMemoryStatusEx(ref memStatus) && + memStatus.ullAvailPhys < 1 * 1024 * 1024 * 1024UL) + { + sortedChunk = new FileEnumerable(itemArray, this.m_elemFactory); + + // It may be a while until we move on to the next foreach iteartion. + // Until that happens itemArray and all its entries remain rooted due + // to the loop scope and therfore won't be available for GC. 
However + // we've already backed them up in a file and no one else will refer + // to itemArray[*] any more, so we can reduce memory pressure by + // cleaning up the array + Array.Clear(itemArray, 0, itemArray.Length); + } + + if (this.m_isDone) return; + + lock (this.m_mergeList) + { + if (this.m_chunkList.Count == MergeSize) + { + this.MergeSortChunks(this.m_chunkList); + this.m_mergeList.Add(this.m_chunkList); + this.m_chunkList = new SortedChunkList(this); + } + this.m_chunkList.Add(sortedChunk); + } + } + } + catch (Exception e) + { + this.m_isDone = true; + throw new DryadLinqException(HpcLinqErrorCode.FailureInSort, + String.Format(SR.FailureInSort), e); + } + finally + { + DryadLinqLog.Add("Parallel sort worker ended at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + } + } + + private void MergeSortChunks(SortedChunkList chunkList) + { + // NOT using the TCO.LongRunning option for this task, because it's + // spawned an arbitrary # of times for potentially shorter work + // which means it's best to leave this to the TP load balancing algorithm + Task task = Task.Factory.StartNew(delegate { chunkList.MergeSort(); }); + this.m_mergeSortWorkers.Add(task); + } + + private class SortedChunkList : IEnumerable + { + private const Int32 ResultChunkSize = 4096; // (1 << 12) + + private InnerEnumerator m_parent; + private Func m_keySelector; + private IComparer m_comparer; + private bool m_isDescending; + private IEnumerable[] m_itemArrays; + private Int32 m_count; + private Exception m_mergerException; + private BlockingCollection m_resultQueue; + + public SortedChunkList(InnerEnumerator parent) + { + this.m_parent = parent; + this.m_keySelector = parent.m_keySelector; + this.m_comparer = parent.m_comparer; + this.m_isDescending = parent.m_isDescending; + this.m_itemArrays = new IEnumerable[MergeSize]; + this.m_count = 0; + this.m_mergerException = null; + this.m_resultQueue = new BlockingCollection(2); + } + + public void Add(IEnumerable itemArray) + 
{ + if (this.m_count == this.m_itemArrays.Length) + { + TElement[][] newItemArrays = new TElement[this.m_count * 2][]; + Array.Copy(this.m_itemArrays, 0, newItemArrays, 0, this.m_count); + this.m_itemArrays = newItemArrays; + } + this.m_itemArrays[this.m_count++] = itemArray; + } + + public Int32 Count + { + get { return this.m_count; } + } + + public void MergeSort() + { + try + { + if (this.m_parent.m_sorterException != null) + { + this.m_mergerException = this.m_parent.m_sorterException; + this.m_resultQueue.Add(new TElement[0]); + } + else + { + IEnumerator[] enumArray = new IEnumerator[this.m_count]; + TKey[] keys = new TKey[this.m_count]; + Int32 mergeCnt = this.m_count; + for (int i = 0; i < this.m_count; i++) + { + IEnumerator itemArrayEnum = this.m_itemArrays[i].GetEnumerator(); + if (!itemArrayEnum.MoveNext()) + { + throw new DryadLinqException(HpcLinqErrorCode.Internal, + SR.SortedChunkCannotBeEmpty); + } + enumArray[i] = itemArrayEnum; + keys[i] = this.m_keySelector(itemArrayEnum.Current); + } + + TElement[] resultChunk = new TElement[ResultChunkSize]; + Int32 resultChunkIdx = 0; + while (mergeCnt > 1) + { + TKey key = keys[0]; + int idx = 0; + for (int i = 1; i < mergeCnt; i++) + { + int cmp = this.m_comparer.Compare(key, keys[i]); + cmp = (this.m_isDescending) ? 
-cmp : cmp; + if (cmp > 0) + { + key = keys[i]; + idx = i; + } + } + + if (resultChunkIdx == ResultChunkSize) + { + if (this.m_parent.m_isDone) break; + + this.m_resultQueue.Add(resultChunk); + resultChunk = new TElement[ResultChunkSize]; + resultChunkIdx = 0; + } + resultChunk[resultChunkIdx++] = enumArray[idx].Current; + + if (enumArray[idx].MoveNext()) + { + keys[idx] = this.m_keySelector(enumArray[idx].Current); + } + else + { + mergeCnt--; + if (idx < mergeCnt) + { + enumArray[idx] = enumArray[mergeCnt]; + keys[idx] = keys[mergeCnt]; + } + } + } + + if (mergeCnt == 1) + { + IEnumerator enum0 = enumArray[0]; + do + { + if (resultChunkIdx == ResultChunkSize) + { + if (this.m_parent.m_isDone) break; + + this.m_resultQueue.Add(resultChunk); + resultChunk = new TElement[ResultChunkSize]; + resultChunkIdx = 0; + } + resultChunk[resultChunkIdx++] = enum0.Current; + } + while (enum0.MoveNext()); + } + + // Get rid of those item arrays + for (int i = 0; i < this.m_count; i++) + { + this.m_itemArrays[i] = null; + } + + // Add the final chunk + if (!this.m_parent.m_isDone && resultChunkIdx > 0) + { + if (resultChunkIdx != ResultChunkSize) + { + TElement[] lastResultChunk = new TElement[resultChunkIdx]; + Array.Copy(resultChunk, 0, lastResultChunk, 0, resultChunkIdx); + resultChunk = lastResultChunk; + } + this.m_resultQueue.Add(resultChunk); + } + } + } + catch (Exception e) + { + this.m_mergerException = e; + this.m_resultQueue.Add(new TElement[0]); + } + finally + { + // Always declare the adding is complete. 
+ this.m_resultQueue.CompleteAdding(); + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public IEnumerator GetEnumerator() + { + foreach (var res in this.m_resultQueue.GetConsumingEnumerable()) + { + if (res.Length == 0) + { + this.m_parent.m_isDone = true; + throw new DryadLinqException("ParallelSort.SortedChunkList failed.", this.m_mergerException); + } + yield return res; + } + } + } + } + } + + internal class ParallelMergeSort : IEnumerable + { + private IMultiEnumerable m_source; + private Func m_keySelector; + private IComparer m_comparer; + private bool m_isDescending; + + public ParallelMergeSort(IMultiEnumerable source, + Func keySelector, + IComparer comparer, + bool isDescending) + { + this.m_source = source; + this.m_keySelector = keySelector; + this.m_comparer = TypeSystem.GetComparer(comparer); + if (isDescending) + { + this.m_comparer = MinusComparer.Make(this.m_comparer); + this.m_isDescending = false; + } + if (this.m_comparer is MinusComparer) + { + this.m_isDescending = !this.m_isDescending; + this.m_comparer = ((MinusComparer)this.m_comparer).InnerComparer; + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public IEnumerator GetEnumerator() + { + return new InnerEnumerator(this); + } + + private class InnerEnumerator : IEnumerator + { + // This looks a lot like what I did in SkyServerQ18. 
+ private const Int32 MergeSize = 16; + + private IMultiEnumerable m_source; + private Func m_keySelector; + private IComparer m_comparer; + private bool m_isDescending; + private HpcRecordReader[] m_readers; + private List m_mergeSortWorkers; + private List m_mergeList; + + private volatile bool m_isDone; + private IEnumerator[] m_enumArray; + private IEnumerator m_resultEnum; + private bool m_disposed; + + public InnerEnumerator(ParallelMergeSort parent) + { + this.m_source = parent.m_source; + this.m_keySelector = parent.m_keySelector; + this.m_comparer = parent.m_comparer; + this.m_isDescending = parent.m_isDescending; + + this.m_readers = new HpcRecordReader[this.m_source.NumberOfInputs]; + for (int i = 0; i < this.m_readers.Length; i++) + { + this.m_readers[i] = (HpcRecordReader)this.m_source[i]; + } + this.m_mergeSortWorkers = new List(); + this.m_mergeList = new List(); + + // Start mergesort workers: + Int32 readerCnt = this.m_readers.Length; + Int32 startIdx = 0; + while (startIdx < readerCnt) + { + Int32 size = Math.Min(MergeSize, readerCnt - startIdx); + SubrangeReader subReaders = new SubrangeReader(this, startIdx, size); + this.m_mergeList.Add(subReaders); + this.StartMergeSortWorker(subReaders); +#if NET35 + System.Threading.Thread.Sleep(10); // YY: Hack for 3.5 +#endif + startIdx += size; + } + + this.m_enumArray = new IEnumerator[this.m_mergeList.Count]; + for (int i = 0; i < this.m_enumArray.Length; i++) + { + this.m_enumArray[i] = this.m_mergeList[i].GetEnumerator(); + } + this.m_resultEnum = SortHelper.MergeSort(this.m_enumArray, this.m_keySelector, + this.m_comparer, this.m_isDescending); + } + + ~InnerEnumerator() + { + this.Dispose(false); + } + + public bool MoveNext() + { + return this.m_resultEnum.MoveNext(); + } + + public TSource Current + { + get { + return this.m_resultEnum.Current; + } + } + + object IEnumerator.Current + { + get { + return this.m_resultEnum.Current; + } + } + + public void Reset() + { + throw new 
DryadLinqException("Internal error: Cannot reset this IEnumerator."); + } + + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (!this.m_disposed) + { + this.m_disposed = true; + this.m_isDone = true; + } + // Always drain the queues of mergesort workers + for (int i = 0; i < this.m_enumArray.Length; i++) + { + this.m_enumArray[i].MoveNext(); + } + } + + private void StartMergeSortWorker(SubrangeReader subReaders) + { + // NOT using the TCO.LongRunning option for this task, because it's + // spawned an arbitrary # of times for potentially shorter work + // which means it's best to leave this to the TP load balancing algorithm + Task task = Task.Factory.StartNew(delegate { subReaders.MergeSort(); }); + this.m_mergeSortWorkers.Add(task); + } + + private class SubrangeReader : IEnumerable + { + private const Int32 ChunkSize = 4096; // (1 << 12) + + private InnerEnumerator m_parent; + private Int32 m_startIdx; + private IEnumerator[] m_enumerators; + private Func m_keySelector; + private IComparer m_comparer; + private bool m_isDescending; + private Exception m_mergerException; + private BlockingCollection m_resultQueue; + + public SubrangeReader(InnerEnumerator parent, + Int32 startIdx, + Int32 len) + { + this.m_parent = parent; + this.m_startIdx = startIdx; + this.m_enumerators= new IEnumerator[len]; + for (int i = 0; i < len; i++) + { + this.m_enumerators[i] = parent.m_readers[startIdx + i].GetEnumerator(); + } + this.m_keySelector = parent.m_keySelector; + this.m_comparer = parent.m_comparer; + this.m_isDescending = parent.m_isDescending; + this.m_resultQueue = new BlockingCollection(2); + } + + public void MergeSort() + { + try + { + DryadLinqLog.Add("ParallelMergeSort.SubrangeReader({0}) started at {1}", + this.m_startIdx, + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + TSource[] elems = new TSource[this.m_enumerators.Length]; + TKey[] keys = new 
TKey[this.m_enumerators.Length]; + int lastIdx = m_enumerators.Length - 1; + int readerCnt = 0; + while (readerCnt <= lastIdx) + { + + if (this.m_enumerators[readerCnt].MoveNext()) + { + elems[readerCnt] = this.m_enumerators[readerCnt].Current; + keys[readerCnt] = this.m_keySelector(elems[readerCnt]); + readerCnt++; + } + else + { + this.m_enumerators[readerCnt].Dispose(); + + if (readerCnt == lastIdx) break; + this.m_enumerators[readerCnt] = this.m_enumerators[lastIdx]; + lastIdx--; + } + } + + TSource[] resultChunk = new TSource[ChunkSize]; + Int32 resultChunkIdx = 0; + while (readerCnt > 1) + { + TKey key = keys[0]; + int idx = 0; + for (int i = 1; i < readerCnt; i++) + { + int cmp = this.m_comparer.Compare(key, keys[i]); + cmp = (this.m_isDescending) ? -cmp : cmp; + if (cmp > 0) + { + key = keys[i]; + idx = i; + } + } + + if (resultChunkIdx == ChunkSize) + { + if (this.m_parent.m_isDone) break; + + this.m_resultQueue.Add(resultChunk); + resultChunk = new TSource[ChunkSize]; + resultChunkIdx = 0; + } + resultChunk[resultChunkIdx++] = elems[idx]; + + if (this.m_enumerators[idx].MoveNext()) + { + elems[idx] = this.m_enumerators[idx].Current; + keys[idx] = this.m_keySelector(elems[idx]); + } + else + { + this.m_enumerators[idx].Dispose(); + + readerCnt--; + if (idx < readerCnt) + { + this.m_enumerators[idx] = this.m_enumerators[readerCnt]; + elems[idx] = elems[readerCnt]; + keys[idx] = keys[readerCnt]; + } + } + } + + if (!this.m_parent.m_isDone && readerCnt == 1) + { + TSource elem = elems[0]; + IEnumerator enumerator = this.m_enumerators[0]; + do + { + if (resultChunkIdx == ChunkSize) + { + this.m_resultQueue.Add(resultChunk); + resultChunk = new TSource[ChunkSize]; + resultChunkIdx = 0; + } + resultChunk[resultChunkIdx++] = elem; + if(!enumerator.MoveNext()){ + break; + } + elem = enumerator.Current; + } + while (true); + enumerator.Dispose(); + } + + // Add the final chunk + if (!this.m_parent.m_isDone && resultChunkIdx > 0) + { + if (resultChunkIdx != 
ChunkSize) + { + TSource[] lastResultChunk = new TSource[resultChunkIdx]; + Array.Copy(resultChunk, 0, lastResultChunk, 0, resultChunkIdx); + resultChunk = lastResultChunk; + } + this.m_resultQueue.Add(resultChunk); + } + } + catch (Exception e) + { + this.m_mergerException = e; + this.m_resultQueue.Add(new TSource[0]); + } + finally + { + DryadLinqLog.Add("ParallelMergeSort.SubrangeReader({0}) ended at {1}", + this.m_startIdx, + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + // Always declare the adding is complete. + this.m_resultQueue.CompleteAdding(); + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public IEnumerator GetEnumerator() + { + foreach (var res in this.m_resultQueue.GetConsumingEnumerable()) + { + if (res.Length == 0) + { + this.m_parent.m_isDone = true; + throw new DryadLinqException("ParallelMergeSort failed.", this.m_mergerException); + } + yield return res; + } + } + } + } + } + + internal interface IParallelPipeline : IEnumerable + { + IParallelPipeline Extend(Func, IEnumerable> func, + bool orderPreserving); + } + + internal interface IParallelApply + { + IEnumerable> ExtendGroupBy( + Func keySelector, + Func elementSelector, + Func seed, + Func accumulator, + IEqualityComparer comparer); + } + + internal class ParallelApply : IParallelApply, IParallelPipeline + { + private IEnumerable m_source; + private Func, IEnumerable> m_procFunc; + private bool m_orderPreserving; + + public ParallelApply(IEnumerable source, + Func, IEnumerable> procFunc, + bool orderPreserving) + { + this.m_source = source; + this.m_procFunc = procFunc; + this.m_orderPreserving = orderPreserving; + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public IEnumerator GetEnumerator() + { + return new InnerEnumerator(this); + } + + public IParallelPipeline + Extend(Func, IEnumerable> func, bool orderPreserving) + { + return new ParallelApply(this.m_source, + s => 
func(this.m_procFunc(s)), + this.m_orderPreserving && orderPreserving); + } + + public IEnumerable> + ExtendGroupBy( + Func keySelector, + Func elementSelector, + Func seed, + Func accumulator, + IEqualityComparer comparer) + { + return new ParallelHashGroupByPartialAccumulate( + this.m_source, this.m_procFunc, keySelector, elementSelector, seed, accumulator, comparer); + + } + + private class InnerEnumerator : IEnumerator + { + private static Int32[] ChunkSizes = new Int32[] { 2, 4, 4, 8, 8, 64, 512, 1024, 2048, 4096 }; + private static Int32 ResultBufferSize = 4096; + + private IEnumerable m_source; + private Func, IEnumerable> m_procFunc; + private bool m_orderPreserving; + + private int m_maxQueueSize; + private BlockingCollection>> m_sourceQueue; + private BlockingCollection> m_resultQueue; + private Thread m_mainWorker; + private Task[] m_workers; + private volatile bool m_isDone; + private Exception m_workerException; + + private IEnumerator> m_resultEnum; + private TResult[] m_currentItems; + private int m_index; + private bool m_disposed; + + public InnerEnumerator(ParallelApply parent) + { + this.m_source = parent.m_source; + this.m_procFunc = parent.m_procFunc; + this.m_orderPreserving = parent.m_orderPreserving; + this.m_isDone = false; + this.m_workerException = null; + this.m_workers = new Task[Environment.ProcessorCount]; + this.m_maxQueueSize = Math.Max(4, this.m_workers.Length); + this.m_sourceQueue = new BlockingCollection>>(this.m_maxQueueSize); + this.m_resultQueue = new BlockingCollection>(this.m_workers.Length*2); + + // Start all the workers: + for (int i = 0; i < this.m_workers.Length; i++) + { + // using the TCO.LongRunning option, because we spawn a fixed number of tasks + // which means it is safe to request a decicated thread for each task + this.m_workers[i] = Task.Factory.StartNew(delegate { this.ApplyFunc(); }, + TaskCreationOptions.LongRunning); + } + this.m_mainWorker = new Thread(this.ProcessAllItems); + this.m_mainWorker.Name = 
"HA.ProcessAllItem"; + this.m_mainWorker.Start(); + + this.m_resultEnum = this.m_resultQueue.GetConsumingEnumerable().GetEnumerator(); + this.m_currentItems = new TResult[0]; + this.m_index = -1; + this.m_disposed = false; + } + + ~InnerEnumerator() + { + this.Dispose(false); + } + + public bool MoveNext() + { + this.m_index++; + if (this.m_index < this.m_currentItems.Length) + { + return true; + } + + while (this.m_resultEnum.MoveNext()) + { + var wrapperItem = this.m_resultEnum.Current; + if (wrapperItem.item == null) + { + lock (wrapperItem) + { + while (wrapperItem.item == null) + { + Monitor.Wait(wrapperItem); + } + } + } + + if (this.m_workerException != null) break; + + this.m_currentItems = wrapperItem.item; + if (this.m_currentItems.Length > 0) + { + this.m_index = 0; + return true; + } + } + if (this.m_workerException != null) + { + throw new DryadLinqException("ParallelApply failed.", this.m_workerException); + } + return false; + } + + public TResult Current + { + get { return this.m_currentItems[this.m_index]; } + } + + object IEnumerator.Current + { + get { return this.m_currentItems[this.m_index]; } + } + + public void Reset() + { + throw new DryadLinqException("Internal error: Cannot reset this IEnumerator."); + } + + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (!this.m_disposed) + { + this.m_disposed = true; + this.m_isDone = true; + + // Always drain the source queue + Pair> sourceElem; + for (int i = 0; i < this.m_maxQueueSize; i++) + { + if (!this.m_sourceQueue.TryTake(out sourceElem)) break; + } + + // Always drain the result queue + while (this.m_resultEnum.MoveNext()) + { + } + } + } + + private void ProcessAllItems() + { + try + { + DryadLinqLog.Add("ParallelApply started reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + // Read all the items + Int32 wlen = this.m_workers.Length; + Int32 bufferSize = ChunkSizes[0] * wlen; + Int32 
rateIdx = 1; + TSource[] elemBuffer = new TSource[bufferSize]; + Int32 elemIdx = 0; + foreach (TSource elem in this.m_source) + { + if (elemIdx == bufferSize) + { + if (this.m_isDone) break; + + Int32 chunkSize = bufferSize / wlen; + for (int i = 0; i < wlen; i++) + { + TSource[] chunk = new TSource[chunkSize]; + Array.Copy(elemBuffer, chunkSize * i, chunk, 0, chunkSize); + + if (this.m_orderPreserving) + { + Wrapper res = new Wrapper(null); + this.m_sourceQueue.Add(new Pair>(chunk, res)); + this.m_resultQueue.Add(res); + } + else + { + this.m_sourceQueue.Add(new Pair>(chunk, null)); + } + } + if (rateIdx < ChunkSizes.Length) + { + bufferSize = ChunkSizes[rateIdx] * wlen; + elemBuffer = new TSource[bufferSize]; + rateIdx++; + } + elemIdx = 0; + } + elemBuffer[elemIdx++] = elem; + } + + // Add the last buffer. + if (!this.m_isDone && elemIdx > 0) + { + Int32 chunkSize = elemIdx / wlen; + Int32 remainingCnt = elemIdx % wlen; + elemIdx = 0; + for (int i = 0; i < wlen; i++) + { + Int32 realChunkSize = (i < remainingCnt) ? chunkSize + 1 : chunkSize; + if (realChunkSize == 0) break; + + TSource[] chunk = new TSource[realChunkSize]; + Array.Copy(elemBuffer, elemIdx, chunk, 0, realChunkSize); + elemIdx += realChunkSize; + + if (this.m_orderPreserving) + { + Wrapper res = new Wrapper(null); + this.m_sourceQueue.Add(new Pair>(chunk, res)); + this.m_resultQueue.Add(res); + } + else + { + this.m_sourceQueue.Add(new Pair>(chunk, null)); + } + } + } + + // Now, the adding is complete. 
+ this.m_sourceQueue.CompleteAdding(); + + DryadLinqLog.Add("ParallelApply ended reading at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + // Wait for all the workers to complete + Task.WaitAll(this.m_workers); + } + catch (Exception e) + { + this.m_isDone = true; + this.m_workerException = e; + } + finally + { + this.m_sourceQueue.CompleteAdding(); + this.m_resultQueue.CompleteAdding(); + } + } + + private void ApplyFunc() + { + Wrapper wrapperItem = null; + try + { + DryadLinqLog.Add("Parallel Apply worker started at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + + foreach (var item in this.m_sourceQueue.GetConsumingEnumerable()) + { + wrapperItem = item.Value; + IEnumerable res = this.m_procFunc(item.Key); + TResult[] res1 = res as TResult[]; + + if (this.m_orderPreserving) + { + if (res1 == null) + { + res1 = res.ToArray(); + } + lock (wrapperItem) + { + wrapperItem.item = res1; + Monitor.Pulse(wrapperItem); + } + } + else + { + if (res1 == null) + { + TResult[] buffer = new TResult[ResultBufferSize]; + int cnt = 0; + foreach (var elem in res) + { + if (cnt == ResultBufferSize) + { + this.m_resultQueue.Add(new Wrapper(buffer)); + buffer = new TResult[ResultBufferSize]; + cnt = 0; + } + buffer[cnt++] = elem; + } + if (cnt > 0) + { + if (cnt != ResultBufferSize) + { + TResult[] buffer1 = new TResult[cnt]; + Array.Copy(buffer, buffer1, cnt); + buffer = buffer1; + } + this.m_resultQueue.Add(new Wrapper(buffer)); + } + } + else + { + this.m_resultQueue.Add(new Wrapper(res1)); + } + } + if (this.m_isDone) return; + } + } + catch (Exception e) + { + this.m_isDone = true; + this.m_workerException = new DryadLinqException(HpcLinqErrorCode.FailureInUserApplyFunction, + SR.FailureInUserApplyFunction, e); + if (this.m_orderPreserving) + { + if (wrapperItem.item == null) + { + lock (wrapperItem) + { + wrapperItem.item = new TResult[0]; + Monitor.Pulse(wrapperItem); + } + } + } + else + { + this.m_resultQueue.Add(new Wrapper(new TResult[0])); + 
} + } + finally + { + DryadLinqLog.Add("Parallel Apply worker ended at {0}", + DateTime.Now.ToString("MM/dd/yyyy HH:mm:ss.fff")); + } + } + } + } + + internal class SortHelper + { + internal static IEnumerator + MergeSort(List keyArrayList, + List itemArrayList, + List itemArraySizeList, + IComparer comparer) + { + Int32 itemArrayCnt = itemArrayList.Count; + TKey[][] keyArrays = new TKey[itemArrayCnt][]; + TElement[][] itemArrays = new TElement[itemArrayCnt][]; + TKey[] keys = new TKey[itemArrayCnt]; + Int32[] indexArray = new Int32[itemArrayCnt]; + for (int i = 0; i < itemArrayCnt; i++) + { + keyArrays[i] = keyArrayList[i]; + itemArrays[i] = itemArrayList[i]; + keys[i] = keyArrayList[i][0]; + indexArray[i] = 0; + } + + while (itemArrayCnt > 0) + { + TKey key = keys[0]; + int idx = 0; + for (int i = 1; i < itemArrayCnt; i++) + { + if (comparer.Compare(key, keys[i]) > 0) + { + key = keys[i]; + idx = i; + } + } + + yield return itemArrays[idx][indexArray[idx]++]; + + if (indexArray[idx] < itemArraySizeList[idx]) + { + keys[idx] = keyArrays[idx][indexArray[idx]]; + } + else + { + itemArrayCnt--; + if (idx < itemArrayCnt) + { + itemArrays[idx] = itemArrays[itemArrayCnt]; + keyArrays[idx] = keyArrays[itemArrayCnt]; + itemArraySizeList[idx] = itemArraySizeList[itemArrayCnt]; + indexArray[idx] = indexArray[itemArrayCnt]; + keys[idx] = keys[itemArrayCnt]; + } + } + } + } + + internal static IEnumerable + HeapMergeSort(IMultiEnumerable source, + Func keySelector, + IComparer comparer, + bool isDescending) + { + comparer = TypeSystem.GetComparer(comparer); + + // Initialize + IEnumerable[] readers = new IEnumerable[source.NumberOfInputs]; + for (int i = 0; i < readers.Length; i++) + { + readers[i] = source[i]; + } + return HpcLinqUtil.MergeSort(readers, keySelector, comparer, isDescending); + } + + internal static IEnumerator + MergeSort(IEnumerator[] enumArray, + Func keySelector, + IComparer comparer, + bool isDescending) + { + TElement[][] itemArrays = new 
TElement[enumArray.Length][]; + TKey[] keys = new TKey[enumArray.Length]; + Int32[] indexArray = new Int32[enumArray.Length]; + Int32 mergeCnt = 0; + + for (int i = 0; i < enumArray.Length; i++) + { + if (enumArray[i].MoveNext()) + { + enumArray[mergeCnt] = enumArray[i]; + itemArrays[mergeCnt] = enumArray[mergeCnt].Current; + keys[mergeCnt] = keySelector(itemArrays[mergeCnt][0]); + indexArray[mergeCnt] = 0; + mergeCnt++; + } + } + + while (mergeCnt > 1) + { + TKey key = keys[0]; + Int32 idx = 0; + for (int i = 1; i < mergeCnt; i++) + { + int cmp = comparer.Compare(key, keys[i]); + cmp = (isDescending) ? -cmp : cmp; + if (cmp > 0) + { + key = keys[i]; + idx = i; + } + } + + yield return itemArrays[idx][indexArray[idx]++]; + + if (indexArray[idx] < itemArrays[idx].Length) + { + keys[idx] = keySelector(itemArrays[idx][indexArray[idx]]); + } + else if (enumArray[idx].MoveNext()) + { + itemArrays[idx] = enumArray[idx].Current; + keys[idx] = keySelector(itemArrays[idx][0]); + indexArray[idx] = 0; + } + else + { + mergeCnt--; + if (idx < mergeCnt) + { + enumArray[idx] = enumArray[mergeCnt]; + itemArrays[idx] = itemArrays[mergeCnt]; + indexArray[idx] = indexArray[mergeCnt]; + keys[idx] = keys[mergeCnt]; + } + } + } + + if (mergeCnt == 1) + { + IEnumerator elems = enumArray[0]; + TElement[] itemArray = itemArrays[0]; + Int32 index = indexArray[0]; + while (true) + { + for (int i = index; i < itemArray.Length; i++) + { + yield return itemArray[i]; + } + if (!elems.MoveNext()) break; + itemArray = elems.Current; + index = 0; + } + } + } + } + + internal struct MinusComparer : IComparer, IEquatable> + { + private IComparer m_comparer; + + internal MinusComparer(IComparer comparer) + { + this.m_comparer = comparer; + } + + internal static IComparer Make(IComparer comparer) + { + if (comparer is MinusComparer) + { + return ((MinusComparer)comparer).InnerComparer; + } + return new MinusComparer(comparer); + } + + internal IComparer InnerComparer + { + get { return 
this.m_comparer; } + } + + public int Compare(T x, T y) + { + return this.m_comparer.Compare(y, x); + } + + public bool Equals(MinusComparer val) + { + return this.InnerComparer.Equals(val.InnerComparer); + } + } + + internal struct ReferenceEqualityComparer : IEqualityComparer + { + public bool Equals(T x, T y) + { + return Object.ReferenceEquals(x, y); + } + + public int GetHashCode(T x) + { + return x.GetHashCode(); + } + } + + public class FileEnumerable : IEnumerable + { + private HpcLinqFactory m_factory; + private string m_fileName; + private HpcRecordReader m_reader; + + internal FileEnumerable(T[] elems, HpcLinqFactory factory) + { + this.m_factory = factory; + this.m_fileName = HpcLinqUtil.MakeUniqueName(); + this.m_reader = null; + + //@@TODO[p3]: could potentially use compression to reduce I/O + //note: only sequential is used for writing. + NativeBlockStream ns = new HpcLinqFileStream(this.m_fileName, FileAccess.Write); + HpcRecordWriter writer = this.m_factory.MakeWriter(ns); + try + { + for (int i = 0; i < elems.Length; i++) + { + writer.WriteRecordSync(elems[i]); + } + } + finally + { + writer.Close(); + } + } + + ~FileEnumerable() + { + this.Dispose(false); + } + + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (this.m_reader != null) + { + this.m_reader.Close(); + } + File.Delete(this.m_fileName); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + private IEnumerator GetEnumerator() + { + NativeBlockStream ns = new HpcLinqFileStream(this.m_fileName, FileAccess.Read); + this.m_reader = this.m_factory.MakeReader(ns); + + T rec = default(T); + try + { + if (HpcLinqVertex.s_multiThreading) + { + this.m_reader.StartWorker(); + while (this.m_reader.ReadRecordAsync(ref rec)) + { + yield return rec; + } + } + else + { + while 
(this.m_reader.ReadRecordSync(ref rec)) + { + yield return rec; + } + } + } + finally + { + this.m_reader.Close(); + } + } + } +} diff --git a/LinqToDryad/DryadLinqVertexParams.cs b/LinqToDryad/DryadLinqVertexParams.cs new file mode 100644 index 0000000..71de101 --- /dev/null +++ b/LinqToDryad/DryadLinqVertexParams.cs @@ -0,0 +1,90 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. 
+//
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Microsoft.Research.DryadLinq.Internal
+{
+    /// <summary>
+    /// Mutable holder for the settings of one vertex: a stage name, the input and
+    /// output arity, a remote-architecture string and three boolean options.
+    /// The class only stores the values; the code that interprets them is not
+    /// part of this file.
+    /// </summary>
+    public class HpcLinqVertexParams
+    {
+        // NOTE(review): the two arity fields use a '_' prefix while every other
+        // field uses 'm_'; the difference is purely cosmetic.
+        private string m_vertexStageName;
+        private bool m_useLargeBuffer;
+        private bool m_keepInputPortOrder;
+        private string m_remoteArch;
+        private int _inputArity;
+        private int _outputArity;
+        private bool m_multiThreading;
+
+        /// <summary>
+        /// Creates the parameter object with the given arities. All other settings
+        /// keep their language defaults (null for strings, false for flags) until
+        /// they are assigned through the corresponding property.
+        /// </summary>
+        /// <param name="inputArity">Initial value of <see cref="InputArity"/>.</param>
+        /// <param name="outputArity">Initial value of <see cref="OutputArity"/>.</param>
+        public HpcLinqVertexParams(int inputArity, int outputArity)
+        {
+            _inputArity = inputArity;
+            _outputArity = outputArity;
+        }
+
+        /// <summary>Gets or sets the name of the vertex stage.</summary>
+        public string VertexStageName
+        {
+            get { return this.m_vertexStageName; }
+            set { this.m_vertexStageName = value; }
+        }
+
+        /// <summary>Gets or sets the input arity supplied to the constructor.</summary>
+        public int InputArity
+        {
+            get { return _inputArity; }
+            set { _inputArity = value; }
+        }
+
+        /// <summary>Gets or sets the output arity supplied to the constructor.</summary>
+        public int OutputArity
+        {
+            get { return _outputArity; }
+            set { _outputArity = value; }
+        }
+
+
+        /// <summary>
+        /// Gets or sets the remote architecture string.
+        /// NOTE(review): the accepted values are defined by the consumer - confirm there.
+        /// </summary>
+        public string RemoteArch
+        {
+            get { return this.m_remoteArch; }
+            set { this.m_remoteArch = value; }
+        }
+
+        /// <summary>Gets or sets the large-buffer flag (false unless assigned).</summary>
+        public bool UseLargeBuffer
+        {
+            get { return this.m_useLargeBuffer; }
+            set { this.m_useLargeBuffer = value; }
+        }
+
+        /// <summary>Gets or sets the keep-input-port-order flag (false unless assigned).</summary>
+        public bool KeepInputPortOrder
+        {
+            get { return this.m_keepInputPortOrder; }
+            set { this.m_keepInputPortOrder = value; }
+        }
+
+        /// <summary>Gets or sets the multi-threading flag (false unless assigned).</summary>
+        public bool MultiThreading
+        {
+            get { return m_multiThreading; }
+            set { m_multiThreading = value; }
+        }
+    }
+}
diff --git a/LinqToDryad/DryadQueryDoc.cs b/LinqToDryad/DryadQueryDoc.cs
new file mode 100644
index 0000000..2f6b822
--- /dev/null
+++ b/LinqToDryad/DryadQueryDoc.cs
@@ -0,0 +1,54 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//
+// © Microsoft Corporation. All rights reserved.
+//
+using System;
+using System.Collections.Generic;
+using System.Text;
+using System.Xml;
+using Microsoft.Research.DryadLinq.Internal;
+
+namespace Microsoft.Research.DryadLinq
+{
+    /// <summary>
+    /// Helper for building fragments of the XML query document.
+    /// </summary>
+    internal class DryadQueryDoc
+    {
+        /// <summary>
+        /// Create the XML element specifying the vertex code: an "Entry" element with
+        /// three children, "AssemblyName", "ClassName" and "MethodName", in that order.
+        /// </summary>
+        /// <param name="queryDoc">Document used as the factory for the new elements.</param>
+        /// <param name="dllFileName">Text stored in the "AssemblyName" child.</param>
+        /// <param name="vertexMethod">Text stored in the "MethodName" child.</param>
+        /// <returns>
+        /// The new "Entry" element. This method does not append it to
+        /// <paramref name="queryDoc"/>; placing it in the tree is left to the caller.
+        /// </returns>
+        internal static XmlElement CreateVertexEntryElem(XmlDocument queryDoc, string dllFileName, string vertexMethod)
+        {
+            XmlElement entryElem = queryDoc.CreateElement("Entry");
+
+            XmlElement elem = queryDoc.CreateElement("AssemblyName");
+            elem.InnerText = dllFileName;
+            entryElem.AppendChild(elem);
+
+            // The class name is not a parameter: it always comes from the code generator.
+            elem = queryDoc.CreateElement("ClassName");
+            elem.InnerText = HpcLinqCodeGen.VertexClassFullName;
+            entryElem.AppendChild(elem);
+
+            elem = queryDoc.CreateElement("MethodName");
+            elem.InnerText = vertexMethod;
+            entryElem.AppendChild(elem);
+
+            return entryElem;
+        }
+    }
+}
diff --git a/LinqToDryad/DryadQueryExplain.cs b/LinqToDryad/DryadQueryExplain.cs
new file mode 100644
index 0000000..3d8776d
--- /dev/null
+++ b/LinqToDryad/DryadQueryExplain.cs
@@ -0,0 +1,785 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//
+// © Microsoft Corporation. All rights reserved.
+//
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Text;
+using System.IO;
+using System.Linq;
+using System.Linq.Expressions;
+using System.Reflection;
+using System.Xml;
+using System.Diagnostics;
+using System.Xml.Linq;
+using System.Drawing;
+using System.Drawing.Drawing2D;
+using Microsoft.Research.DryadLinq.Internal;
+
+namespace Microsoft.Research.DryadLinq
+{
+    //@@TODO[P1]: update for new APIs.
+
+    /// <summary>
+    /// This class explains in detail the generated plan.
+    /// </summary>
+    internal sealed class DryadQueryExplain
+    {
+        /// <summary>
+        /// Visit the set of nodes in the query plan and build an explanation of the plan.
+        /// A single visited-set is shared across all roots, so a node reachable from
+        /// several of them is explained only once.
+        /// </summary>
+        /// <param name="plan">Return plan description here.</param>
+        /// <param name="nodes">Nodes to explain.</param>
+        internal void CodeShowVisit(StringBuilder plan, DryadQueryNode[] nodes)
+        {
+            HashSet<DryadQueryNode> visited = new HashSet<DryadQueryNode>();
+            foreach (DryadQueryNode n in nodes)
+            {
+                CodeShowVisit(plan, n, visited);
+            }
+        }
+
+        /// <summary>
+        /// Helper for CodeShowVisit: do not revisit a node twice.
+        /// </summary>
+        /// <param name="plan">Return plan here.</param>
+        /// <param name="n">Node to explain.</param>
+        /// <param name="visited">Set of nodes already visited.</param>
+        private void CodeShowVisit(StringBuilder plan, DryadQueryNode n, HashSet<DryadQueryNode> visited)
+        {
+            if (visited.Contains(n)) return;
+            visited.Add(n);
+
+            // Children are explained before the node itself, so the text for a
+            // node always follows the text for everything in n.Children.
+            foreach (DryadQueryNode c in n.Children)
+            {
+                CodeShowVisit(plan, c, visited);
+            }
+
+            ExplainNode(plan, n);
+        }
+
+        /// <summary>
+        /// Explain one query node.
+        /// </summary>
+        /// <param name="plan">Return plan here.</param>
+        /// <param name="n">Node to explain.</param>
+ internal static void ExplainNode(StringBuilder plan, DryadQueryNode n) + { + if (n is DryadTeeNode || n is DryadOutputNode) + { + return; + } + else if (n is DryadInputNode) + { + plan.AppendLine("Input:"); + plan.Append("\t"); + n.BuildString(plan); + plan.AppendLine(); + return; + } + + plan.Append(n.m_vertexEntryMethod); + plan.AppendLine(":"); + + HashSet allchildren = new HashSet(); + + if (n is DryadSuperNode) + { + DryadSuperNode sn = n as DryadSuperNode; + List tovisit = new List(); + + tovisit.Add(sn.RootNode); + + while (tovisit.Count > 0) + { + DryadQueryNode t = tovisit[0]; + tovisit.RemoveAt(0); + if (!(t is DryadSuperNode)) + allchildren.Add(t); + foreach (DryadQueryNode tc in t.Children) + { + if (!allchildren.Contains(tc) && sn.Contains(tc)) + tovisit.Add(tc); + } + } + } + else + allchildren.Add(n); + + foreach (DryadQueryNode nc in allchildren.Reverse()) + { + Expression expression = null; // expression to print + List additional = new List(); // additional arguments to print + int argsToSkip = 0; + string methodname = nc.OpName; + + plan.Append("\t"); + + if (nc is DryadMergeNode) + { + expression = ((DryadMergeNode)nc).ComparerExpression; + } + else if (nc is DryadHashPartitionNode) + { + DryadHashPartitionNode hp = (DryadHashPartitionNode)nc; + expression = hp.KeySelectExpression; + additional.Add(hp.NumberOfPartitions.ToString()); + } + else if (nc is DryadGroupByNode) + { + DryadGroupByNode gb = (DryadGroupByNode)nc; + expression = gb.KeySelectExpression; + if (gb.ElemSelectExpression != null) + additional.Add(HpcLinqExpression.Summarize(gb.ElemSelectExpression)); + if (gb.ResSelectExpression != null) + additional.Add(HpcLinqExpression.Summarize(gb.ResSelectExpression)); + if (gb.ComparerExpression != null) + additional.Add(HpcLinqExpression.Summarize(gb.ComparerExpression)); + if (gb.SeedExpression != null) + additional.Add(HpcLinqExpression.Summarize(gb.SeedExpression)); + if (gb.AccumulatorExpression != null) + 
additional.Add(HpcLinqExpression.Summarize(gb.AccumulatorExpression)); + } + else if (nc is DryadOrderByNode) + { + DryadOrderByNode ob = (DryadOrderByNode)nc; + expression = ob.KeySelectExpression; + if (ob.ComparerExpression != null) + additional.Add(HpcLinqExpression.Summarize(ob.ComparerExpression)); + } + else if (nc is DryadWhereNode) { + expression = ((DryadWhereNode)nc).WhereExpression; + } + else if (nc is DryadSelectNode) { + DryadSelectNode s = (DryadSelectNode)nc; + expression = s.SelectExpression; + if (s.ResultSelectExpression != null) + additional.Add(HpcLinqExpression.Summarize(s.ResultSelectExpression)); + } + else if (nc is DryadAggregateNode) + { + DryadAggregateNode a = (DryadAggregateNode)nc; + expression = a.FuncLambda; + if (a.SeedExpression != null) + additional.Add(HpcLinqExpression.Summarize(a.SeedExpression)); + if (a.ResultLambda != null) + additional.Add(HpcLinqExpression.Summarize(a.ResultLambda)); + } + else if (nc is DryadPartitionOpNode) { + expression = ((DryadPartitionOpNode)nc).ControlExpression; + } + else if (nc is DryadJoinNode) + { + DryadJoinNode j = (DryadJoinNode)nc; + expression = j.OuterKeySelectorExpression; + additional.Add(HpcLinqExpression.Summarize(j.InnerKeySelectorExpression)); + additional.Add(HpcLinqExpression.Summarize(j.ResultSelectorExpression)); + if (j.ComparerExpression != null) + additional.Add(HpcLinqExpression.Summarize(j.ComparerExpression)); + } + else if (nc is DryadDistinctNode) + { + expression = ((DryadDistinctNode)nc).ComparerExpression; + } + else if (nc is DryadContainsNode) + { + DryadContainsNode c = (DryadContainsNode)nc; + expression = c.ValueExpression; + if (c.ComparerExpression != null) + additional.Add(HpcLinqExpression.Summarize(c.ComparerExpression)); + } + else if (nc is DryadBasicAggregateNode) + { + expression = ((DryadBasicAggregateNode)nc).SelectExpression; + } + else if (nc is DryadConcatNode) + // nothing to do + { + } + else if (nc is DryadSetOperationNode) + { + expression = 
((DryadSetOperationNode)nc).ComparerExpression; + } + else if (nc is DryadRangePartitionNode) + { + DryadRangePartitionNode r = (DryadRangePartitionNode)nc; + expression = r.CountExpression; + // TODO: there's some other possible interesting info + } + else if (nc is DryadApplyNode) + { + expression = ((DryadApplyNode)nc).LambdaExpression; + } + + else if (nc is DryadForkNode) + { + expression = ((DryadForkNode)nc).ForkLambda; + } + + else if (nc is DryadTeeNode) + { + // nothing + } + else if (nc is DryadDynamicNode) + { + // nothing + } + else + { + expression = nc.QueryExpression; + } + + if (expression is MethodCallExpression) + { + MethodCallExpression mc = (MethodCallExpression)expression; + methodname = mc.Method.Name; // overwrite methodname + + // determine which arguments to skip + #region LINQMETHODS + switch (mc.Method.Name) + { + case "Aggregate": + case "AggregateAsQuery": + case "Select": + case "LongSelect": + case "SelectMany": + case "LongSelectMany": + case "OfType": + case "Where": + case "LongWhere": + case "First": + case "FirstOrDefault": + case "FirstAsQuery": + case "Single": + case "SingleOrDefault": + case "SingleAsQuery": + case "Last": + case "LastOrDefault": + case "LastAsQuery": + case "Distinct": + case "Any": + case "AnyAsQuery": + case "All": + case "AllAsQuery": + case "Count": + case "CountAsQuery": + case "LongCount": + case "LongCountAsQuery": + case "Sum": + case "SumAsQuery": + case "Min": + case "MinAsQuery": + case "Max": + case "MaxAsQuery": + case "Average": + case "AverageAsQuery": + case "GroupBy": + case "OrderBy": + case "OrderByDescending": + case "ThenBy": + case "ThenByDescending": + case "Take": + case "TakeWhile": + case "LongTakeWhile": + case "Skip": + case "SkipWhile": + case "LongSkipWhile": + case "Contains": + case "ContainsAsQuery": + case "Reverse": + case "Merge": + case "HashPartition": + case "RangePartition": + case "Fork": + case "ForkChoose": + case "AssumeHashPartition": + case 
"AssumeRangePartition":
                        case "AssumeOrderBy":
                        case "ToPartitionedTableLazy":
                        case "AddCacheEntry":
                        case "SlidingWindow":
                        case "SelectWithPartitionIndex":
                        case "ApplyWithPartitionIndex":
                            argsToSkip = 1;
                            break;
                        case "Join":
                        case "GroupJoin":
                        case "Concat":
                        case "MultiConcat":
                        case "Union":
                        case "Intersect":
                        case "Except":
                        case "SequenceEqual":
                        case "SequenceEqualAsQuery":
                        case "Zip":
                            argsToSkip = 2;
                            break;
                        case "Apply":
                        case "ApplyPerPartition":
                            if (mc.Arguments.Count < 3)
                                argsToSkip = 1;
                            else
                                argsToSkip = 2;
                            break;
                        default:
                            throw DryadLinqException.Create(HpcLinqErrorCode.OperatorNotSupported,
                                                            String.Format(SR.OperatorNotSupported, mc.Method.Name),
                                                            expression);
                    }
                    #endregion

                    plan.Append(methodname);
                    plan.Append("(");

                    int argno = 0;
                    foreach (var arg in mc.Arguments)
                    {
                        argno++;
                        if (argno <= argsToSkip) continue;
                        if (argno > argsToSkip + 1)
                        {
                            plan.Append(",");
                        }
                        plan.Append(HpcLinqExpression.Summarize(arg));
                    }
                    plan.AppendLine(")");
                }
                else
                {
                    // expression is not methodcall
                    plan.Append(methodname);
                    plan.Append("(");
                    if (expression != null)
                    {
                        plan.Append(HpcLinqExpression.Summarize(expression));
                    }
                    foreach (string e in additional)
                    {
                        plan.Append(",");
                        plan.Append(e);
                    }
                    plan.AppendLine(")");
                }
            }
        }

        /// <summary>
        /// Produces a textual explanation of a query plan: runs code generation on
        /// the generator, then renders every node of the resulting plan.
        /// </summary>
        /// <param name="gen">Query generator whose plan is explained.</param>
        /// <returns>The rendered plan text.</returns>
        internal string Explain(HpcLinqQueryGen gen)
        {
            gen.CodeGenVisit();

            StringBuilder text = new StringBuilder();
            this.CodeShowVisit(text, gen.QueryPlan());
            return text.ToString();
        }
    }

    /// <summary>
    /// Static summary of a job's query plan, loaded from the XML plan file.
    /// </summary>
    internal class DryadLinqJobStaticPlan
    {
        /// <summary>
        /// An edge between two stages of the plan.
        /// </summary>
        public class Connection
        {
            /// <summary>
            /// How the vertices of the two stages are wired together.
            /// </summary>
            public enum ConnectionType
            {
                /// <summary>
                /// Point-to-point connection between two stages.
                /// </summary>
                PointToPoint,
                /// <summary>
                /// Cross-product connection between two stages.
                /// </summary>
                AllToAll
            };

            /// <summary>
            /// Transport that backs the connection.
            /// </summary>
            public enum ChannelType
            {
                /// <summary>
                /// Persistent file.
                /// </summary>
                DiskFile,
                /// <summary>
                /// In-memory fifo.
                /// </summary>
                Fifo,
                /// <summary>
                /// TCP pipe.
                /// </summary>
                TCP
            }

            /// <summary>
            /// Stage where the connection starts.
            /// </summary>
            public Stage From { internal set; get; }
            /// <summary>
            /// Stage where the connection ends.
            /// </summary>
            public Stage To { internal set; get; }
            /// <summary>
            /// Wiring pattern of the connection.
            /// </summary>
            public ConnectionType Arity { get; internal set; }
            /// <summary>
            /// Transport that backs the connection.
            /// </summary>
            public ChannelType ChannelKind { get; internal set; }
            /// <summary>
            /// Name of the dynamic manager attached to the connection.
            /// </summary>
            public string ConnectionManager { get; internal set; }

            /// <summary>
            /// Maps the channel kind to the color used when drawing the connection.
            /// </summary>
            /// <returns>"black" for disk files, "red" for fifos, "yellow" for TCP.</returns>
            public string Color()
            {
                string color;
                switch (this.ChannelKind)
                {
                    case ChannelType.DiskFile:
                        color = "black";
                        break;
                    case ChannelType.Fifo:
                        color = "red";
                        break;
                    case ChannelType.TCP:
                        color = "yellow";
                        break;
                    default:
                        throw new Exception(String.Format(SR.UnknownChannelType, this.ChannelKind.ToString()));
                }
                return color;
            }
        }

        /// <summary>
        /// Connection attributes as recorded on a plan node (the plan stores them
        /// per node rather than per edge).
        /// </summary>
        struct ConnectionInformation
        {
            /// <summary>
            /// Wiring pattern of the connection.
            /// </summary>
            public Connection.ConnectionType Arity { get; internal set; }
            /// <summary>
            /// Transport that backs the connection.
            /// </summary>
            public Connection.ChannelType ChannelKind { get; internal set; }
            /// <summary>
            /// Name of the dynamic manager attached to the connection.
            /// </summary>
            public string ConnectionManager { get; internal set; }
        }

        /// <summary>
        /// Description of one stage of the plan.
        /// </summary>
        public class Stage
        {
            /// <summary>
            /// Stage name.
            /// </summary>
            public string Name { get; internal set; }
            /// <summary>
            /// Lines of code executed in the stage.
            /// </summary>
            public string[] Code { get; internal set; }
            /// <summary>
            /// DryadLINQ operator implemented by the stage.
            /// </summary>
            public string Operator { get; internal set; }
            /// <summary>
            /// Number of vertices in the stage.
            /// </summary>
            public int Replication { get; internal set; }
            /// <summary>
            /// Unique identifier.
            /// </summary>
            public int Id { get; set; }

            /// <summary>
            /// True if the stage is an input.
            /// </summary>
            public bool IsInput { get; internal set; }
            /// <summary>
            /// True if the stage is an output.
            /// </summary>
            public bool IsOutput { get; internal set; }
            /// <summary>
            /// True if the stage is a tee.
            /// </summary>
            public bool IsTee { get; internal set; }
            /// <summary>
            /// True if the stage is a concatenation.
            /// </summary>
            public bool IsConcat { get; internal set; }
            /// <summary>
            /// True if the stage is virtual (no real vertices synthesized), i.e. it is
            /// an input, an output, a tee or a concatenation.
            /// </summary>
            public bool IsVirtual
            {
                get { return this.IsInput || this.IsOutput || this.IsTee || this.IsConcat; }
            }
            /// <summary>
            /// Table location; only set for input and output tables.
            /// </summary>
            public string Uri { get; internal set; }
            /// <summary>
            /// Table storage type; only set for input and output tables.
            /// </summary>
            public string UriType { get; internal set; }
        }

        /// <summary>
        /// Path of the XML file holding the plan.
        /// </summary>
        string xmlPlanFile;
        /// <summary>
        /// Stages keyed by stage id.
        /// </summary>
        Dictionary<int, Stage> stages;
        /// <summary>
        /// All inter-stage connections of the plan.
        /// </summary>
        List<Connection> connections;
        /// <summary>
        /// Per-node connection information, keyed by node id.
        /// </summary>
        Dictionary<int, ConnectionInformation> perNodeConnectionInfo;

        /// <summary>
        /// Builds the static plan by parsing the given XML plan file.
        /// </summary>
        /// <param name="xmlPlanFile">Plan file to parse.</param>
        public DryadLinqJobStaticPlan(string xmlPlanFile)
        {
            this.xmlPlanFile = xmlPlanFile;
            this.stages = new Dictionary<int, Stage>();
            this.connections = new List<Connection>();
            this.perNodeConnectionInfo = new Dictionary<int, ConnectionInformation>();
            this.ParseQueryPlan();
        }

        /// <summary>
        /// Translates the plan's "ConnectionOperator" text into a connection arity.
        /// </summary>
        private static Connection.ConnectionType ParseArity(string connection)
        {
            switch (connection)
            {
                case "Pointwise":
                    return Connection.ConnectionType.PointToPoint;
                case "CrossProduct":
                    return Connection.ConnectionType.AllToAll;
                default:
                    throw new Exception(String.Format(SR.UnknownConnectionType, connection));
            }
        }

        /// <summary>
        /// Translates the plan's "ChannelType" text into a channel kind.
        /// </summary>
        private static Connection.ChannelType ParseChannelKind(string channelType)
        {
            switch (channelType)
            {
                case "DiskFile":
                    return Connection.ChannelType.DiskFile;
                case "TCPPipe":
                    return Connection.ChannelType.TCP;
                case "MemoryFIFO":
                    return Connection.ChannelType.Fifo;
                default:
                    throw new Exception(String.Format(SR.UnknownChannelType2, channelType));
            }
        }

        /// <summary>
        /// Reads the XML plan file and fills in the stage table, the per-node
        /// connection information and the list of connections.
        /// </summary>
        private void ParseQueryPlan()
        {
            if (!File.Exists(this.xmlPlanFile))
            {
                throw new Exception(String.Format(SR.CannotReadQueryPlan, this.xmlPlanFile));
            }

            XDocument planDocument = XDocument.Load(this.xmlPlanFile);
            XElement queryPlan = planDocument.Root.Elements().First(e => e.Name == "QueryPlan");

            foreach (XElement vertex in queryPlan.Elements().Where(e => e.Name == "Vertex"))
            {
                Stage stage = new Stage();
                stage.Id = int.Parse(vertex.Element("UniqueId").Value);
                stage.Replication = int.Parse(vertex.Element("Partitions").Value);
                stage.Operator = vertex.Element("Type").Value;
                stage.Name = vertex.Element("Name").Value;

                // The first line of the explanation is the stage name; the remaining
                // lines are trimmed to drop their leading tab.
                string explanation = vertex.Element("Explain").Value;
                stage.Code = explanation.Split('\n')
                                        .Skip(1)
                                        .Select(line => line.Trim())
                                        .ToArray();
                this.stages.Add(stage.Id, stage);

                // These should be connection attributes, not stage attributes.
                string channelText = vertex.Element("ChannelType").Value;
                string managerText = vertex.Element("DynamicManager").Element("Type").Value;
                string connectionText = vertex.Element("ConnectionOperator").Value;

                ConnectionInformation nodeInfo = new ConnectionInformation();
                nodeInfo.ConnectionManager = managerText;
                nodeInfo.Arity = ParseArity(connectionText);
                nodeInfo.ChannelKind = ParseChannelKind(channelText);
                this.perNodeConnectionInfo.Add(stage.Id, nodeInfo);

                switch (stage.Operator)
                {
                    case "InputTable":
                        stage.IsInput = true;
                        stage.UriType = vertex.Element("StorageSet").Element("Type").Value;
                        stage.Uri = vertex.Element("StorageSet").Element("SourceURI").Value;
                        break;
                    case "OutputTable":
                        stage.IsOutput = true;
                        stage.UriType = vertex.Element("StorageSet").Element("Type").Value;
                        stage.Uri = vertex.Element("StorageSet").Element("SinkURI").Value;
                        break;
                    case "Tee":
                        stage.IsTee = true;
                        break;
                    case "Concat":
                        stage.IsConcat = true;
                        break;
                    default:
                        break;
                }

                if (!vertex.Elements("Children").Any())
                {
                    continue;
                }

                bool isFirstChild = true;
                IEnumerable<XElement> children =
                    vertex.Element("Children").Elements().Where(e => e.Name == "Child");
                foreach (XElement child in children)
                {
                    // This code parallels the graphbuilder.cpp for XmlExecHost
                    int sourceId = int.Parse(child.Element("UniqueId").Value);
                    ConnectionInformation sourceInfo = this.perNodeConnectionInfo[sourceId];

                    Connection edge = new Connection();
                    edge.From = this.stages[sourceId];
                    edge.To = stage;
                    edge.ChannelKind = sourceInfo.ChannelKind;

                    switch (sourceInfo.ConnectionManager)
                    {
                        case "FullAggregator":
                        case "HashDistributor":
                        case "RangeDistributor":
                            // Only the first child carries the manager.
                            if (isFirstChild)
                            {
                                isFirstChild = false;
                                edge.ConnectionManager = sourceInfo.ConnectionManager;
                            }
                            else
                            {
                                edge.ConnectionManager = "";
                            }
                            break;
                        case "PartialAggregator":
                        case "Broadcast":
                            // All children have the same connection manager
                            edge.ConnectionManager = sourceInfo.ConnectionManager;
                            break;
                        case "Splitter":
                            // The connection manager depends on the number of children
                            if (isFirstChild)
                            {
                                isFirstChild = false;
                                edge.ConnectionManager = (children.Count() == 1)
                                    ? sourceInfo.ConnectionManager
                                    : "SemiSplitter";
                            }
                            else
                            {
                                edge.ConnectionManager = "";
                            }
                            break;
                        case "None":
                        case "":
                            break;
                    }

                    edge.Arity = sourceInfo.Arity;
                    this.connections.Add(edge);
                }
            }
        }
/// <summary>
        /// Looks up a stage by its static id given as text.
        /// </summary>
        /// <param name="stageId">Stage id, as a decimal string.</param>
        /// <returns>The stage registered under that id.</returns>
        public Stage GetStageByStaticId(string stageId)
        {
            return this.stages[int.Parse(stageId)];
        }

        /// <summary>
        /// Looks up a stage by name.
        /// </summary>
        /// <param name="name">Name of stage to return.</param>
        /// <returns>The first stage with the given name, or null if there is none.</returns>
        public Stage GetStageByName(string name)
        {
            foreach (Stage candidate in this.stages.Values)
            {
                if (!candidate.Name.Equals(name))
                {
                    continue;
                }
                return candidate;
            }

            return null;
        }

        /// <summary>
        /// All stages in the plan.
        /// </summary>
        /// <returns>An iterator over the stages.</returns>
        public IEnumerable<Stage> GetAllStages()
        {
            return this.stages.Values;
        }

        /// <summary>
        /// All connections in the plan.
        /// </summary>
        /// <returns>An iterator over the connections.</returns>
        public IEnumerable<Connection> GetAllConnections()
        {
            return this.connections;
        }
    }
}
diff --git a/LinqToDryad/DryadQueryGen.cs b/LinqToDryad/DryadQueryGen.cs
new file mode 100644
index 0000000..5da94a2
--- /dev/null
+++ b/LinqToDryad/DryadQueryGen.cs
@@ -0,0 +1,4658 @@
/*
Copyright (c) Microsoft Corporation

All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License.  You may obtain a copy of the License
at http://www.apache.org/licenses/LICENSE-2.0


THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.


See the Apache Version 2.0 License for specific language governing permissions and
limitations under the License.

*/

//
// © Microsoft Corporation. All rights reserved.
//
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Text;
using System.IO;
using System.Reflection;
using System.Threading;
using System.Linq;
using System.Linq.Expressions;
using System.CodeDom;
using System.Xml;
using System.Diagnostics;

using Microsoft.Research.DryadLinq.Internal;
using Microsoft.Hpc.Dryad;
using System.Globalization;

namespace Microsoft.Research.DryadLinq
{
    /// <summary>
    /// This class handles code generation for multiple queries and execution invocation.
    /// </summary>
    internal class HpcLinqQueryGen
    {
        // Phase id every query node is expected to carry before optimization starts;
        // each completed pass (phase 2, phase 3, code generation) advances it by one.
        internal const int StartPhaseId = -4;
        private const string DryadLinqProgram = "DryadLinqProgram__.xml";
        private const string QueryGraph = "QueryGraph__.txt";
        private const string VertexHostExe = "VertexHost.exe";

        private static int s_uniqueProgId = -1;
        private static object s_queryGenLock = new Object();

        private int m_currentPhaseId = StartPhaseId;
        private int m_nextVertexId = 0;
        private bool m_codeGenDone = false;
        private HpcLinqCodeGen m_codeGen;
        private Expression[] m_queryExprs;
        private DryadQueryNode[] m_queryPlan1;
        private DryadQueryNode[] m_queryPlan2;
        private DryadQueryNode[] m_queryPlan3;
        private string[] m_outputTableUris;
        private bool[] m_isTempOutput;
        private string[] m_outputDatapaths;
        private Type[] m_outputTypes;
        private QueryNodeInfo[] m_queryNodeInfos;
        private DryadLinqQuery[] m_outputTables;
        private string m_DryadLinqProgram;
        private string m_queryGraph;
        // NOTE(review): the generic type arguments of the four dictionaries below (and of
        // their constructions in Initialize) are not present in this copy of the file and
        // are left exactly as found; confirm against the original source.
        private Dictionary m_exprNodeInfoMap;
        private Dictionary m_referencedQueryMap;
        private Dictionary m_inputUriMap;
        private Dictionary m_outputUriMap;
        private JobExecutor queryExecutor;
        private HpcLinqContext m_context;

        /// <summary>
        /// Creates a generator for a single query with a single output table.
        /// </summary>
        /// <param name="context">Context the query runs in.</param>
        /// <param name="vertexCodeGen">Vertex code generator handed to the code generator.</param>
        /// <param name="queryExpr">Expression of the query.</param>
        /// <param name="tableUri">URI of the output table.</param>
        /// <param name="isTempOutput">Whether the output is temporary.</param>
        internal HpcLinqQueryGen(HpcLinqContext context,
                                 VertexCodeGen vertexCodeGen,
                                 Expression queryExpr,
                                 string tableUri,
                                 bool isTempOutput)
        {
            this.m_queryExprs = new Expression[] { queryExpr };
            this.m_outputTableUris = new string[] { tableUri };
            this.m_isTempOutput = new bool[] { isTempOutput };
            this.m_context = context;
            this.Initialize(vertexCodeGen);
        }

        // This constructor is specifically to support Materialize() calls.
        // it assumes that the Expressions all terminate with a ToDsc node.
        /// <summary>
        /// Creates a generator for several queries. Each expression must be a method
        /// call whose first argument is the actual query; the call itself determines
        /// the output URI and whether the output is temporary.
        /// </summary>
        /// <param name="context">Context the queries run in.</param>
        /// <param name="vertexCodeGen">Vertex code generator handed to the code generator.</param>
        /// <param name="qlist">Query expressions, each ending in an output call.</param>
        internal HpcLinqQueryGen(HpcLinqContext context,
                                 VertexCodeGen vertexCodeGen,
                                 Expression[] qlist)
        {
            this.m_queryExprs = new Expression[qlist.Length];
            this.m_outputTableUris = new string[qlist.Length];
            this.m_isTempOutput = new bool[qlist.Length];
            this.m_context = context;
            for (int i = 0; i < this.m_queryExprs.Length; i++)
            {
                MethodCallExpression mcExpr = (MethodCallExpression)qlist[i];
                string tableUri;
                this.m_queryExprs[i] = mcExpr.Arguments[0];

                //this block supports the scenario: q-nonToDsc
                if (mcExpr.Method.Name == ReflectedNames.DryadLinqIQueryable_AnonymousDscPlaceholder)
                {
                    ExpressionSimplifier e1 = new ExpressionSimplifier();
                    tableUri = e1.Eval(mcExpr.Arguments[1]);
                    this.m_isTempOutput[i] = true;
                }

                //this block supports the scenario: q.ToDsc()
                else if (mcExpr.Method.Name == ReflectedNames.DryadLinqIQueryable_ToDscWorker)
                {
                    DscService dsc = context.DscService;
                    ExpressionSimplifier e2 = new ExpressionSimplifier();
                    string streamName = e2.Eval(mcExpr.Arguments[2]);

                    tableUri = DataPath.MakeDscStreamUri(dsc, streamName);
                    this.m_isTempOutput[i] = false;
                }

                //this block supports the scenario: q.ToHdfs()
                else if (mcExpr.Method.Name == ReflectedNames.DryadLinqIQueryable_ToHdfsWorker)
                {
                    string hdfsHeadNode = context.HdfsService;
                    ExpressionSimplifier e2 = new ExpressionSimplifier();
                    string streamName = e2.Eval(mcExpr.Arguments[2]);

                    tableUri = DataPath.MakeHdfsStreamUri(hdfsHeadNode, streamName);
                    this.m_isTempOutput[i] = false;
                }
                else
                {
                    throw new InvalidOperationException(); // should not occur.
                }

                this.m_outputTableUris[i] = tableUri;
            }
            this.Initialize(vertexCodeGen);
        }

        /// <summary>
        /// Shared constructor tail: resets the plan state, creates the lookup maps and
        /// the job executor, builds the node-info graph of every query and validates
        /// that every output URI is a DSC or HDFS data path.
        /// </summary>
        /// <param name="vertexCodeGen">Vertex code generator handed to the code generator.</param>
        private void Initialize(VertexCodeGen vertexCodeGen)
        {
            this.m_codeGen = new HpcLinqCodeGen(this.m_context, vertexCodeGen);
            this.m_queryPlan1 = null;
            this.m_queryPlan2 = null;
            this.m_queryPlan3 = null;
            this.m_DryadLinqProgram = null;
            this.m_exprNodeInfoMap = new Dictionary();
            this.m_referencedQueryMap = new Dictionary();
            this.m_inputUriMap = new Dictionary();
            this.m_outputUriMap = new Dictionary();
            this.queryExecutor = new JobExecutor(this.m_context);

            // Initialize the data structures for the output tables
            this.m_outputTypes = new Type[this.m_queryExprs.Length];
            this.m_outputDatapaths = new string[this.m_queryExprs.Length];
            this.m_queryNodeInfos = new QueryNodeInfo[this.m_queryExprs.Length];

            for (int i = 0; i < this.m_queryExprs.Length; i++)
            {
                // Wrap the query's graph in a root node, so the query itself is always
                // children[0] of m_queryNodeInfos[i].
                QueryNodeInfo queryInfo = this.BuildNodeInfoGraph(this.m_queryExprs[i]);
                this.m_queryNodeInfos[i] = new QueryNodeInfo(this.m_queryExprs[i], false, queryInfo);

                this.m_outputDatapaths[i] = DataPath.GetDataPath(this.m_outputTableUris[i]);

                // The parsed arguments are not used here. NOTE(review): the call is kept
                // in case it validates the URI; confirm and remove it if it has no effect.
                DataPath.GetArguments(this.m_outputTableUris[i]);

                if (!(DataPath.IsDsc(this.m_outputDatapaths[i]) || DataPath.IsHdfs(this.m_outputDatapaths[i])))
                {
                    throw new DryadLinqException(HpcLinqErrorCode.UnrecognizedDataSource,
                                                 String.Format(SR.UnrecognizedDataSource, this.m_outputTableUris[i]));
                }
            }
        }

        /// <summary>
        /// Context the queries run in.
        /// </summary>
        internal HpcLinqContext Context
        {
            get { return this.m_context; }
        }

        /// <summary>
        /// Code generator used by this query generator.
        /// </summary>
        internal HpcLinqCodeGen CodeGen
        {
            get { return this.m_codeGen; }
        }

        /// <summary>
        /// Map of the queries referenced from inside other query expressions.
        /// </summary>
        internal Dictionary ReferencedQueryMap
        {
            get { return this.m_referencedQueryMap; }
        }

        /// <summary>
        /// Probes the running assembly and its dependencies, and throws an exception
        /// if any of them is targetted to x86. Returns silently if all managed assemblies
        /// in the list are x64 or AnyCPU. Native or unloadable binaries are ignored.
        /// </summary>
        private void CheckAssemblyArchitectures()
        {
            // First create the list of assemblies to probe
            // We use a stripped down version of the resource discovery logic in GenerateDryadProgram.
            // i) We start with the same set of currently loaded binaries
            //    (== client app + its dependencies + dynamically loaded assemblies)
            // ii) We take out user specified resource exclusions (this enables a workaround
            //     for x86 assemblies that must be loaded on the client side, think UI plugins,
            //     but aren't needed by the vertex code)
            //
            // The difference is we don't add the vertex DLL, or user resources.

            // Both sides are lower-cased with the invariant culture and then compared
            // ordinally; a hash set makes each membership test O(1).
            HashSet<string> excludedPaths = new HashSet<string>(
                this.m_context.Configuration.ResourcesToRemove.Select(x => x.ToLower(CultureInfo.InvariantCulture)),
                StringComparer.Ordinal);

            var assembliesToCheck = TypeSystem.GetLoadedNonSystemAssemblyPaths()
                                              .Select(x => x.ToLower(CultureInfo.InvariantCulture))
                                              .Where(path => !excludedPaths.Contains(path));
            foreach (string path in assembliesToCheck)
            {
                Assembly asm = null;
                try
                {
                    asm = Assembly.ReflectionOnlyLoadFrom(path);
                }
                catch
                {
                    // silently ignore load errors
                }

                if (asm == null)
                {
                    continue;
                }

                PortableExecutableKinds peKind;
                ImageFileMachine machine;
                asm.ManifestModule.GetPEKind(out peKind, out machine);

                // machine will always be reported as "I386" for both true x86 and AnyCPU assemblies
                // peKind will have the "Required32Bit" flag set only for x86 binaries. Therefore we use peKind to make our decision.
                if ((peKind & PortableExecutableKinds.Required32Bit) != 0)
                {
                    string offendingAssemblyName = Path.GetFileName(path);
                    throw new DryadLinqException(HpcLinqErrorCode.Binaries32BitNotSupported,
                                                 String.Format(SR.Binaries32BitNotSupported, offendingAssemblyName));
                }
            }
        }

        /// <summary>
        /// Generates the Dryad program, checks the architecture of the loaded
        /// assemblies, starts asynchronous execution and creates one output table
        /// object per output URI. The whole sequence runs under the class-wide lock.
        /// </summary>
        /// <returns>The output tables, each bound to this generator's executor.</returns>
        internal DryadLinqQuery[] InvokeDryad()
        {
            lock (s_queryGenLock)
            {
                this.GenerateDryadProgram();

                CheckAssemblyArchitectures();

                // Invoke the background execution
                this.queryExecutor.ExecuteAsync(this.m_DryadLinqProgram);

                // Create the resulting partitioned tables
                this.m_outputTables = new DryadLinqQuery[this.m_outputTableUris.Length];
                MethodInfo minfo = typeof(Microsoft.Research.DryadLinq.DataProvider).GetMethod(
                                       ReflectedNames.DataProvider_GetPartitionedTable,
                                       BindingFlags.NonPublic | BindingFlags.Static);
                for (int i = 0; i < this.m_outputTableUris.Length; i++)
                {
                    // The table's element type is only known at run time, so the generic
                    // factory method is closed over it through reflection.
                    MethodInfo minfo1 = minfo.MakeGenericMethod(this.m_outputTypes[i]);
                    object[] args = new object[] { this.m_context, this.m_outputTableUris[i] };
                    this.m_outputTables[i] = (DryadLinqQuery)minfo1.Invoke(null, args);
                    this.m_outputTables[i].QueryExecutor = this.queryExecutor;
                }
                return this.m_outputTables;
            }
        }

        // Phase 1 of the query optimization
        /// <summary>
        /// Builds the initial plan: applies the simple rewrite rules, visits every
        /// query and every referenced query, and attaches an output node to each
        /// query. Does nothing if the phase-1 plan already exists.
        /// </summary>
        internal void GenerateQueryPlanPhase1()
        {
            if (this.m_queryPlan1 != null) return;

            // Apply some simple rewrite rules
            SimpleRewriter rewriter = new SimpleRewriter(this.m_exprNodeInfoMap.Values.ToList());
            rewriter.Rewrite();

            // Generate the query plan of phase1
            var referencedNodes = this.m_referencedQueryMap.Values;
            this.m_queryPlan1 = new DryadQueryNode[this.m_queryExprs.Length + referencedNodes.Count];
            for (int i = 0; i < this.m_queryExprs.Length; i++)
            {
                this.m_queryPlan1[i] = this.Visit(this.m_queryNodeInfos[i].children[0].child);
            }
            int idx = this.m_queryExprs.Length;
            foreach (QueryNodeInfo nodeInfo in referencedNodes)
            {
                // Add a Tee'd Merge
                this.m_queryPlan1[idx] = this.Visit(nodeInfo.children[0].child);
                DryadQueryNode mergeNode = new DryadMergeNode(true, false, nodeInfo.queryExpression,
                                                              this.m_queryPlan1[idx]);
                this.m_queryPlan1[idx] = new DryadTeeNode(mergeNode.OutputTypes[0], true,
                                                          mergeNode.QueryExpression, mergeNode);
                nodeInfo.queryNode = this.m_queryPlan1[idx];
                idx++;
            }

            // Finally, add the output nodes.
            // forkCounts[n] = number of existing parents of n plus the number of outputs
            // that will be attached to n.
            Dictionary<DryadQueryNode, int> forkCounts = new Dictionary<DryadQueryNode, int>();
            for (int i = 0; i < this.m_queryExprs.Length; i++)
            {
                DryadQueryNode queryNode = this.m_queryPlan1[i];
                int cnt;
                if (!forkCounts.TryGetValue(queryNode, out cnt))
                {
                    cnt = queryNode.Parents.Count;
                }
                forkCounts[queryNode] = cnt + 1;
            }

            for (int i = 0; i < this.m_queryExprs.Length; i++)
            {
                HpcClientSideLog.Add("Query " + i + " Output: " + this.m_outputDatapaths[i]);

                DryadQueryNode queryNode = this.m_queryPlan1[i];
                if (TypeSystem.IsAnonymousType(queryNode.OutputTypes[0]))
                {
                    throw new DryadLinqException(HpcLinqErrorCode.OutputTypeCannotBeAnonymous,
                                                 SR.OutputTypeCannotBeAnonymous);
                }

                // Add dummy Apply to make Dryad happy (it doesn't like to hook inputs straight to outputs)
                if ((queryNode is DryadInputNode) || (forkCounts[queryNode] > 1))
                {
                    // Add a dummy Apply: the identity function x => x over the node's output.
                    Type paramType = typeof(IEnumerable<>).MakeGenericType(queryNode.OutputTypes[0]);
                    ParameterExpression param = Expression.Parameter(paramType, "x");
                    Type type = typeof(Func<,>).MakeGenericType(paramType, paramType);
                    LambdaExpression applyExpr = Expression.Lambda(type, param, param);
                    DryadQueryNode applyNode = new DryadApplyNode(applyExpr, this.m_queryExprs[i], queryNode);
                    applyNode.OutputDataSetInfo = queryNode.OutputDataSetInfo;
                    queryNode = applyNode;
                }

                if (queryNode is DryadConcatNode)
                {
                    // Again, we add dummy Apply in certain cases to make Dryad happy
                    ((DryadConcatNode)queryNode).FixInputs();
                }

                // Add the output node
                DscCompressionScheme outputScheme = this.m_context.Configuration.OutputDataCompressionScheme;
                DryadOutputNode outputNode = new DryadOutputNode(this.m_context,
                                                                 this.m_outputDatapaths[i],
                                                                 this.m_isTempOutput[i],
                                                                 outputScheme,
                                                                 this.m_queryExprs[i],
                                                                 queryNode);

                this.m_queryPlan1[i] = outputNode;

                // Output paths are compared case-insensitively by lower-casing the key.
                // NOTE(review): ToLower() is culture-sensitive; the same normalization must
                // be used wherever m_inputUriMap keys are produced, so it is left unchanged.
                string outputKey = this.m_outputDatapaths[i].ToLower();
                if (this.m_outputUriMap.ContainsKey(outputKey))
                {
                    throw new DryadLinqException(HpcLinqErrorCode.MultipleOutputsWithSameDscUri,
                                                 String.Format(SR.MultipleOutputsWithSameDscUri, this.m_outputDatapaths[i]));
                }

                this.m_outputUriMap.Add(outputKey, outputNode);
                this.m_outputTypes[i] = this.m_queryPlan1[i].OutputTypes[0];

                // Remove useless Tees to make Dryad happy
                if ((queryNode is DryadTeeNode) && (forkCounts[queryNode] == 1))
                {
                    DryadQueryNode teeChild = queryNode.Children[0];
                    teeChild.UpdateParent(queryNode, outputNode);
                    outputNode.UpdateChildren(queryNode, teeChild);
                }
            }
        }

        // Phase 2 of the query optimization
        /// <summary>
        /// Runs phase 1 if needed, then rewrites every plan root with
        /// <see cref="VisitPhase2"/> and advances the phase id. The result is cached.
        /// </summary>
        /// <returns>The phase-2 plan.</returns>
        internal DryadQueryNode[] GenerateQueryPlanPhase2()
        {
            if (this.m_queryPlan2 == null)
            {
                this.GenerateQueryPlanPhase1();
                this.m_queryPlan2 = new DryadQueryNode[this.m_queryPlan1.Length];
                for (int i = 0; i < this.m_queryPlan1.Length; i++)
                {
                    this.m_queryPlan2[i] = this.VisitPhase2(this.m_queryPlan1[i]);
                }
                this.m_currentPhaseId++;
            }
            return this.m_queryPlan2;
        }

        /// <summary>
        /// Phase-2 visit of one node. A node is processed only while its id still
        /// equals the current phase id; processing bumps the id of the resulting node.
        /// </summary>
        /// <param name="node">Node to visit.</param>
        /// <returns>The node that replaces <paramref name="node"/> in the plan.</returns>
        private DryadQueryNode VisitPhase2(DryadQueryNode node)
        {
            DryadQueryNode resNode = node;
            if (node.m_uniqueId == this.m_currentPhaseId)
            {
                if (node is DryadForkNode)
                {
                    // For now, we require every branch of a fork be used:
                    DryadForkNode forkNode = (DryadForkNode)node;
                    for (int i = 0; i < forkNode.Parents.Count; i++)
                    {
                        if ((forkNode.Parents[i] is DryadTeeNode) &&
                            (forkNode.Parents[i].Parents.Count == 0))
                        {
                            throw DryadLinqException.Create(HpcLinqErrorCode.BranchOfForkNotUsed,
                                                            string.Format(SR.BranchOfForkNotUsed, i),
                                                            node.QueryExpression);
                        }
                    }
                }

                resNode = node.SuperNode;
                if (resNode == null)
                {
                    for (int i = 0; i < node.Children.Length; i++)
                    {
                        node.Children[i] = this.VisitPhase2(node.Children[i]);
                    }
                    resNode = node.PipelineReduce();
                    resNode.m_uniqueId++;

                    // Insert a Tee node if needed:
                    DryadQueryNode outputNode = resNode.OutputNode;
                    if (outputNode.IsForked &&
                        !(outputNode is DryadForkNode) &&
                        !(outputNode is DryadTeeNode))
                    {
                        resNode = resNode.InsertTee(true);
                    }
                }
            }
            return resNode;
        }

        // Phase 3 of the query optimization
        /// <summary>
        /// Runs phase 2 if needed, then applies <see cref="VisitPhase3"/> to every
        /// plan root in place and advances the phase id. The result is cached and is
        /// the same array as the phase-2 plan.
        /// </summary>
        /// <returns>The phase-3 plan.</returns>
        internal DryadQueryNode[] GenerateQueryPlanPhase3()
        {
            if (this.m_queryPlan3 == null)
            {
                this.GenerateQueryPlanPhase2();
                this.m_queryPlan3 = this.m_queryPlan2;
                for (int i = 0; i < this.m_queryPlan2.Length; i++)
                {
                    this.VisitPhase3(this.m_queryPlan2[i]);
                }
                this.m_currentPhaseId++;
            }
            return this.m_queryPlan3;
        }

        /// <summary>
        /// Phase-3 visit of one node: unlinks unforked Tee children, unlinks a
        /// single-partition Merge, possibly attaches a broadcast manager to a Tee,
        /// then recurses into the children.
        /// </summary>
        /// <param name="node">Node to visit.</param>
        private void VisitPhase3(DryadQueryNode node)
        {
            if (node.m_uniqueId == this.m_currentPhaseId)
            {
                node.m_uniqueId++;

                // Remove some useless Tee nodes
                foreach (DryadQueryNode child in node.Children)
                {
                    if ((child is DryadTeeNode) && !child.IsForked)
                    {
                        DryadQueryNode teeChild = child.Children[0];
                        teeChild.UpdateParent(child, node);
                        node.UpdateChildren(child, teeChild);
                    }
                }

                // Remove some useless Merge nodes
                if ((node is DryadMergeNode) &&
                    !node.IsForked &&
                    !(node.Parents[0] is DryadOutputNode) &&
                    !node.Children[0].IsDynamic &&
                    node.Children[0].PartitionCount == 1)
                {
                    node.Children[0].UpdateParent(node, node.Parents[0]);
                    node.Parents[0].UpdateChildren(node, node.Children[0]);
                }

                // Add dynamic managers for tee nodes.
                if ((StaticConfig.DynamicOptLevel & StaticConfig.NoDynamicOpt) != 0 &&
                    node is DryadTeeNode &&
                    node.DynamicManager.ManagerType == DynamicManagerType.None)
                {
                    // insert a dynamic broadcast manager on Tee
                    node.DynamicManager = DynamicManager.Broadcast;
                }

                // Recurse on the children of node
                foreach (DryadQueryNode child in node.Children)
                {
                    this.VisitPhase3(child);
                }
            }
        }

        // The main method that generates the query plan.
internal DryadQueryNode[] QueryPlan()
        {
            // The final plan is the output of the last optimization phase.
            return this.GenerateQueryPlanPhase3();
        }

        // Generate the vertex code for all the query nodes.
        /// <summary>
        /// Generates vertex code for the whole optimized plan. Runs at most once;
        /// later calls return immediately.
        /// </summary>
        internal void CodeGenVisit()
        {
            if (this.m_codeGenDone)
            {
                return;
            }

            DryadQueryNode[] roots = this.QueryPlan();

            // make sure none of the outputs share a URI with inputs
            foreach (var outputEntry in this.m_outputUriMap)
            {
                string outputPath = outputEntry.Key;
                if (m_inputUriMap.ContainsKey(outputPath))
                {
                    throw new DryadLinqException(HpcLinqErrorCode.OutputUriAlsoQueryInput,
                                                 String.Format(SR.OutputUriAlsoQueryInput, outputPath));
                }
            }

            foreach (DryadQueryNode root in roots)
            {
                this.CodeGenVisit(root);
            }
            this.m_currentPhaseId++;
            this.m_codeGenDone = true;
        }

        /// <summary>
        /// Code-generation visit of one node. A node is handled only while its id
        /// still equals the current phase id, so each node is processed once.
        /// </summary>
        /// <param name="node">Node to generate code for.</param>
        private void CodeGenVisit(DryadQueryNode node)
        {
            if (node.m_uniqueId != this.m_currentPhaseId)
            {
                return;
            }
            node.m_uniqueId++;

            // We process the types first so that children will also know about all
            // proxies/mappings that should be used.
            node.CreateCodeAndMappingsForVertexTypes(false);

            // Recurse on the children
            foreach (DryadQueryNode child in node.Children)
            {
                this.CodeGenVisit(child);
            }

            if (node.NodeType == QueryNodeType.InputTable)
            {
                // not used as a vertex: named after the last segment of the source URI
                string uri = ((DryadInputNode)node).Table.DataSourceUri;
                int lastSeparator = Math.Max(uri.LastIndexOf('/'), uri.LastIndexOf('\\'));
                node.m_vertexEntryMethod = uri.Substring(lastSeparator + 1);
            }
            else if (node.NodeType == QueryNodeType.OutputTable)
            {
                // not used as a vertex: named after at most the first 8 characters of
                // the last segment of the metadata URI
                string uri = ((DryadOutputNode)node).MetaDataUri;
                int lastSeparator = Math.Max(uri.LastIndexOf('/'), uri.LastIndexOf('\\'));
                int nameLength = Math.Min(8, uri.Length - lastSeparator - 1);
                node.m_vertexEntryMethod = uri.Substring(lastSeparator + 1, nameLength);
            }
            else if (node.NodeType == QueryNodeType.Tee)
            {
                // not used as a vertex
                node.m_vertexEntryMethod = HpcLinqCodeGen.MakeUniqueName("Tee");
                // broadcast manager code generation
                if (node.DynamicManager.ManagerType != DynamicManagerType.None)
                {
                    node.DynamicManager.CreateVertexCode();
                }
            }
            else if (node.NodeType == QueryNodeType.Concat)
            {
                // not used as a vertex
                node.m_vertexEntryMethod = HpcLinqCodeGen.MakeUniqueName("Concat");
            }
            else
            {
                CodeMemberMethod entryMethod = this.m_codeGen.AddVertexMethod(node);
                node.m_vertexEntryMethod = entryMethod.Name;
                node.DynamicManager.CreateVertexCode();
            }
        }

        // Assign unique ids to all the query nodes
        /// <summary>
        /// Assigns a unique id to every node reachable from the plan roots. Requires
        /// the current phase id to be -1.
        /// </summary>
        private void AssignUniqueId()
        {
            if (this.m_currentPhaseId != -1)
            {
                //@@TODO: this should not be reachable. could change to Assert/InvalidOpEx
                throw new DryadLinqException(HpcLinqErrorCode.Internal,
                                             "Internal error: Optimization phase should be -1, not " +
                                             this.m_currentPhaseId);
            }

            DryadQueryNode[] roots = this.QueryPlan();
            foreach (DryadQueryNode root in roots)
            {
                this.AssignUniqueId(root);
            }
        }

        /// <summary>
        /// Assigns ids bottom-up: referenced queries first, then children, then the
        /// node itself. When the node's output is a fork, its not-yet-numbered parents
        /// are numbered right after it.
        /// </summary>
        /// <param name="node">Node to number.</param>
        private void AssignUniqueId(DryadQueryNode node)
        {
            if (node.m_uniqueId != this.m_currentPhaseId)
            {
                return;
            }

            // NOTE(review): the generic arguments of Pair are not present in this copy
            // of the file; the declaration is left as found.
            foreach (Pair refChild in node.GetReferencedQueries())
            {
                this.AssignUniqueId(refChild.Value);
            }
            foreach (DryadQueryNode child in node.Children)
            {
                this.AssignUniqueId(child);
            }

            // Re-check: the recursion above may already have numbered this node.
            if (node.m_uniqueId != this.m_currentPhaseId)
            {
                return;
            }
            node.m_uniqueId = this.m_nextVertexId++;

            if (node.OutputNode is DryadForkNode)
            {
                foreach (DryadQueryNode parent in node.Parents)
                {
                    if (parent.m_uniqueId == this.m_currentPhaseId)
                    {
                        parent.m_uniqueId = this.m_nextVertexId++;
                    }
                }
            }
        }

        /// <summary>
        /// Numbers all nodes and then lets every plan root add itself to the XML plan.
        /// </summary>
        /// <param name="queryDoc">Document holding the xml plan.</param>
        /// <param name="queryPlan">Element that receives the plan nodes.</param>
        private void CreateQueryPlan(XmlDocument queryDoc, XmlElement queryPlan)
        {
            this.AssignUniqueId();

            // Shared across roots and handed to every AddToQueryPlan call.
            // NOTE(review): the element type of this set is not present in this copy of
            // the file; the declaration is left as found.
            HashSet seen = new HashSet();
            foreach (DryadQueryNode root in this.QueryPlan())
            {
                root.AddToQueryPlan(queryDoc, queryPlan, seen);
            }
        }

        /// <summary>
        /// Find the pdb file associated with a given filename.
        /// </summary>
        /// <param name="filename">Filename with debugging information.</param>
        /// <returns>The associated pdb, or null if that file does not exist.</returns>
internal static string FindPDB(string filename)
        {
            // The pdb is expected next to the file, with the same base name.
            // Path.Combine is used instead of manual concatenation so that a file name
            // without a directory part yields "name.pdb" rather than the root-relative
            // "\name.pdb", and so that a null directory does not produce a bogus path.
            string basename = Path.GetFileNameWithoutExtension(filename);
            string directory = Path.GetDirectoryName(filename) ?? String.Empty;
            string pdbname = Path.Combine(directory, basename + ".pdb");
            return File.Exists(pdbname) ? pdbname : null;
        }

        /// <summary>
        /// Add a resource to the Xml plan. When compiling for vertex debugging, the
        /// resource's pdb (if one exists next to it) is added as well.
        /// </summary>
        /// <param name="queryDoc">Document holding the xml plan.</param>
        /// <param name="parent">Parent node.</param>
        /// <param name="resource">Resource to add.</param>
        /// <param name="resourcesToExclude">Resources that must not be added.</param>
        private void AddResourceToPlan(XmlDocument queryDoc,
                                       XmlElement parent,
                                       string resource,
                                       IEnumerable<string> resourcesToExclude)
        {
            AddResourceToPlan_Core(queryDoc, parent, resource, resourcesToExclude);

            if (this.m_context.Configuration.CompileForVertexDebugging)
            {
                string pdb = FindPDB(resource);
                if (pdb != null)
                {
                    AddResourceToPlan_Core(queryDoc, parent, pdb, resourcesToExclude);
                }
            }
        }

        // Add a resource to the plan unless it was specifically in the exclusions list.
        /// <summary>
        /// Registers the resource with the executor and appends a "Resource" element
        /// holding the executor's result, unless the resource is excluded
        /// (case-insensitive ordinal comparison).
        /// </summary>
        private void AddResourceToPlan_Core(XmlDocument queryDoc,
                                            XmlElement parent,
                                            string resource,
                                            IEnumerable<string> resourcesToExclude)
        {
            if (resourcesToExclude.Contains(resource, StringComparer.OrdinalIgnoreCase))
            {
                return;
            }
            XmlElement resourceElem = queryDoc.CreateElement("Resource");
            resourceElem.InnerText = queryExecutor.AddResource(resource);
            parent.AppendChild(resourceElem);
        }

        /// <summary>
        /// Writes the app config file for the VertexHost process.
        /// </summary>
        /// <param name="appConfigPath">Path of the file to write.</param>
        private void GenerateAppConfigResource(string appConfigPath)
        {
            // Generates an app config XML for the VertexHost which
            //
            // 1) specifies the server GC mode
            //
            // 2) requests a .NET runtime version equal to that of the client
            //    submitting the job (if the client has a .NET version higher than 3.5)
            //    also specifies the v2.0.50727 runtime as a fallback
            //    in case the cluster nodes don't have the client's newer .Net version
            //    (v2.0.50727 corresponds to .Net 3.5, both Sp1 and Sp2 pointing to this
            //    version ID, since it's actually the underlying CLR's version). This rule
            //    only kicks in if the query config has the MatchClientNetFrameworkVersion flag set.
            //
            // 3) disables Authenticode checks for the VH through a config element
            //    (its name is missing from this copy of the file). This is necessary
            //    because some cluster nodes don't have an Internet connection, in which
            //    case the Authenticode check leads to a 20sec delay during process startup.
            //
            // NOTE: useLegacyV2RuntimeActivationPolicy="true" is needed because the VH binary is mixed mode.
            //
            // NOTE(review): the XML text inside the string literals below is not present in
            // this copy of the file (only whitespace remains). The literals are reproduced
            // exactly as found and must be restored from the original source.
            string clientVersionString = "";
            if (Environment.Version.Major > 2 && this.m_context.Configuration.MatchClientNetFrameworkVersion)
            {
                clientVersionString = String.Format(CultureInfo.InvariantCulture,
                                                    @" ",
                                                    Environment.Version.Major,
                                                    Environment.Version.Minor);
                // NOTE: We use the "v4.0" syntax instead of the longer "v4.0.30319" because as of .NET4
                // the format of this app config tag has been simplified to "major.minor"
            }

            string appConfigBody =
@"








";

            appConfigBody += clientVersionString; // add the client specific version string if we generated one.
            appConfigBody +=
@"


";

            File.WriteAllText(appConfigPath, appConfigBody);
        }

        /// <summary>
        /// Generate the executable code: builds the DryadLINQ assembly, then writes
        /// the XML query plan (job settings, exec-host arguments, resources and the
        /// plan nodes) to disk.
        /// </summary>
        /// <returns>The path to the queryPlanXml location</returns>
        internal string GenerateDryadProgram()
        {
            HpcLinqObjectStore.Clear();

            // Any resource that we try to add will be tested against this list.
            IEnumerable<string> resourcesToExclude = this.m_context.Configuration.ResourcesToRemove;

            // BuildDryadLinqAssembly:
            //   1. Performs query optimizations
            //   2. Generate vertex code for all the nodes in the query plan
            this.m_codeGen.BuildDryadLinqAssembly(this);

            DryadQueryExplain explain = new DryadQueryExplain();
            HpcClientSideLog.Add("{0}", explain.Explain(this));

            // Finally, write out the query plan.
            if (this.m_DryadLinqProgram == null)
            {
                int progId = Interlocked.Increment(ref s_uniqueProgId);
                this.m_DryadLinqProgram = HpcLinqCodeGen.GetPathForGeneratedFile(DryadLinqProgram, progId);
                this.m_queryGraph = HpcLinqCodeGen.GetPathForGeneratedFile(QueryGraph, progId);
            }

            XmlDocument queryDoc = new XmlDocument();
            // NOTE(review): the root-element markup passed to LoadXml is not present in
            // this copy of the file; the literal is reproduced exactly as found.
            queryDoc.LoadXml("");

            // Write the assembly version information
            Assembly dryadlinqassembly = Assembly.GetExecutingAssembly();
            AssemblyName asmName = dryadlinqassembly.GetName();
            XmlElement elem = queryDoc.CreateElement("DryadLinqVersion");
            elem.InnerText = asmName.Version.ToString();
            queryDoc.DocumentElement.AppendChild(elem);

            // Add the cluster name element
            elem = queryDoc.CreateElement("ClusterName");
            elem.InnerText = this.m_context.Configuration.HeadNode;
            queryDoc.DocumentElement.AppendChild(elem);

            // Add the minimum number of nodes (1 when not configured)
            int minComputeNodes = 1;
            if (this.m_context.Configuration.JobMinNodes.HasValue)
            {
                minComputeNodes = this.m_context.Configuration.JobMinNodes.Value;
            }
            elem = queryDoc.CreateElement("MinimumComputeNodes");
            elem.InnerText = minComputeNodes.ToString();
            queryDoc.DocumentElement.AppendChild(elem);

            // Add the maximum number of nodes (-1 when not configured)
            int maxComputeNodes = -1;
            if (this.m_context.Configuration.JobMaxNodes.HasValue)
            {
                maxComputeNodes = this.m_context.Configuration.JobMaxNodes.Value;
            }
            elem = queryDoc.CreateElement("MaximumComputeNodes");
            elem.InnerText = maxComputeNodes.ToString();
            queryDoc.DocumentElement.AppendChild(elem);

            // intermediate data compression
            elem = queryDoc.CreateElement("IntermediateDataCompression");
            elem.InnerText = ((int)this.m_context.Configuration.IntermediateDataCompressionScheme).ToString();
            queryDoc.DocumentElement.AppendChild(elem);

            // Speculative Duplication Node
            elem = queryDoc.CreateElement("EnableSpeculativeDuplication");
            elem.InnerText = this.m_context.Configuration.EnableSpeculativeDuplication.ToString();
            queryDoc.DocumentElement.AppendChild(elem);

            // Add the visualization element
            //@@TODO[p2]: remove this element from the queryXML.
            elem = queryDoc.CreateElement("Visualization");
            elem.InnerText = "none";
            queryDoc.DocumentElement.AppendChild(elem);

            // Add the query name element (falls back to the assembly name)
            elem = queryDoc.CreateElement("QueryName");
            if (String.IsNullOrEmpty(this.m_context.Configuration.JobFriendlyName))
            {
                elem.InnerText = asmName.Name;
            }
            else
            {
                elem.InnerText = this.m_context.Configuration.JobFriendlyName;
            }
            queryDoc.DocumentElement.AppendChild(elem);

            // Add the XmlExecHostArgs element
            elem = queryDoc.CreateElement("XmlExecHostArgs");
            queryDoc.DocumentElement.AppendChild(elem);

            // Add an element for each argument
            string[] args = StaticConfig.XmlExecHostArgs.Split(new char[] {' '}, StringSplitOptions.RemoveEmptyEntries);
            List<string> xmlExecResources = new List<string>();
            for (int i = 0; i < args.Length; ++i)
            {
                string arg = args[i];
                XmlElement argElem = queryDoc.CreateElement("Argument");
                argElem.InnerText = arg;
                elem.AppendChild(argElem);

                // "-bw" and "-r" are followed by a path that must be shipped as a
                // resource. The bounds check keeps a trailing flag with no value from
                // throwing IndexOutOfRangeException; the flag itself is still passed on.
                if ((arg.Equals("-bw") || arg.Equals("-r")) && i + 1 < args.Length)
                {
                    xmlExecResources.Add(args[i + 1]);
                }
            }

            // Add the resources element
            elem = queryDoc.CreateElement("Resources");
            queryDoc.DocumentElement.AppendChild(elem);

            // Add resource item for this LINQ DLL
            AddResourceToPlan(queryDoc, elem, this.m_codeGen.GetTargetLocation(), resourcesToExclude);

            // Add resource item for each loaded DLL that isn't a system DLL
            IEnumerable<string> loadedAssemblyPaths = TypeSystem.GetLoadedNonSystemAssemblyPaths();
            foreach (string assemblyPath in loadedAssemblyPaths)
            {
                AddResourceToPlan(queryDoc, elem, assemblyPath, resourcesToExclude);
            }

            // Add the xmlExec resources
            foreach (string resourcePath in xmlExecResources)
            {
                AddResourceToPlan(queryDoc, elem, resourcePath, resourcesToExclude);
            }

            // Add codegen resources
            foreach (string resourcePath in this.m_codeGen.VertexCodeGen.GetResources())
            {
                AddResourceToPlan(queryDoc, elem, resourcePath, resourcesToExclude);
            }

            // Create an app config file for the VertexHost process, and add it to the resources
            string vertexHostAppConfigPath = HpcLinqCodeGen.GetPathForGeneratedFile(Path.ChangeExtension(VertexHostExe, "exe.config"), null);
            GenerateAppConfigResource(vertexHostAppConfigPath);
            AddResourceToPlan(queryDoc, elem, vertexHostAppConfigPath, resourcesToExclude);

            // Save and add the object store as a resource
            if (!HpcLinqObjectStore.IsEmpty)
            {
                HpcLinqObjectStore.Save();
                AddResourceToPlan(queryDoc, elem, HpcLinqObjectStore.GetClientSideObjectStorePath(), resourcesToExclude);
            }

            // Add resource item for user-added resources
            foreach (string userResource in this.m_context.Configuration.ResourcesToAdd.Distinct(StringComparer.OrdinalIgnoreCase))
            {
                AddResourceToPlan(queryDoc, elem, userResource, resourcesToExclude);
            }

            // Add the query plan element
            XmlElement queryPlanElem = queryDoc.CreateElement("QueryPlan");
            queryDoc.DocumentElement.AppendChild(queryPlanElem);

            // Add the query tree as a sequence of nodes
            this.CreateQueryPlan(queryDoc, queryPlanElem);

            // Finally, save the DryadQuery doc to a file
            queryDoc.Save(this.m_DryadLinqProgram);

            return this.m_DryadLinqProgram;
        }

        /// <summary>
        /// Collects the queries referenced inside an expression and records a wrapped
        /// node-info graph for each of them in the referenced-query map.
        /// </summary>
        /// <param name="expr">Expression to scan.</param>
        private void BuildReferencedQuery(Expression expr)
        {
            ExpressionQuerySet querySet = new ExpressionQuerySet();
            querySet.Visit(expr);
            foreach (Expression qexpr in querySet.QuerySet)
            {
                QueryNodeInfo nodeInfo = BuildNodeInfoGraph(qexpr);
                this.m_referencedQueryMap[qexpr] = new QueryNodeInfo(qexpr, false, nodeInfo);
            }
        }

        private void BuildReferencedQuery(int startIdx, ReadOnlyCollection exprs)
        {
            ExpressionQuerySet querySet = new ExpressionQuerySet();
            for (int i = startIdx; i < exprs.Count; i++)
            {
                querySet.Visit(exprs[i]);
            }
            foreach (Expression qexpr in querySet.QuerySet)
            {
                QueryNodeInfo nodeInfo =
BuildNodeInfoGraph(qexpr); + this.m_referencedQueryMap[qexpr] = new QueryNodeInfo(qexpr, false, nodeInfo); + } + } + + //@@TODO: document what the 'NodeInfo' and 'ReferencedQuery' system does + private QueryNodeInfo BuildNodeInfoGraph(Expression expression) + { + QueryNodeInfo resNodeInfo = null; + if (this.m_exprNodeInfoMap.TryGetValue(expression, out resNodeInfo)) + { + return resNodeInfo; + } + MethodCallExpression mcExpr = expression as MethodCallExpression; + if (mcExpr != null && mcExpr.Method.IsStatic && TypeSystem.IsQueryOperatorCall(mcExpr)) + { + switch (mcExpr.Method.Name) + { + case "Join": + case "GroupJoin": + case "Union": + case "Intersect": + case "Except": + case "Zip": + case "SequenceEqual": + case "SequenceEqualAsQuery": + { + QueryNodeInfo child1 = BuildNodeInfoGraph(mcExpr.Arguments[0]); + QueryNodeInfo child2 = BuildNodeInfoGraph(mcExpr.Arguments[1]); + resNodeInfo = new QueryNodeInfo(mcExpr, true, child1, child2); + this.BuildReferencedQuery(2, mcExpr.Arguments); + break; + } + case "Concat": + { + NewArrayExpression others = mcExpr.Arguments[1] as NewArrayExpression; + if (others == null) + { + QueryNodeInfo child1 = BuildNodeInfoGraph(mcExpr.Arguments[0]); + QueryNodeInfo child2 = BuildNodeInfoGraph(mcExpr.Arguments[1]); + resNodeInfo = new QueryNodeInfo(mcExpr, true, child1, child2); + } + else + { + QueryNodeInfo[] infos = new QueryNodeInfo[others.Expressions.Count + 1]; + infos[0] = BuildNodeInfoGraph(mcExpr.Arguments[0]); + for (int i = 0; i < others.Expressions.Count; ++i) + { + infos[i + 1] = BuildNodeInfoGraph(others.Expressions[i]); + } + resNodeInfo = new QueryNodeInfo(mcExpr, true, infos); + } + break; + } + case "Apply": + case "ApplyPerPartition": + { + QueryNodeInfo child1 = BuildNodeInfoGraph(mcExpr.Arguments[0]); + if (mcExpr.Arguments.Count == 2) + { + resNodeInfo = new QueryNodeInfo(mcExpr, true, child1); + this.BuildReferencedQuery(mcExpr.Arguments[1]); + } + else + { + LambdaExpression lambda = 
HpcLinqExpression.GetLambda(mcExpr.Arguments[2]); + if (lambda.Parameters.Count == 2) + { + // Apply with two sources + QueryNodeInfo child2 = BuildNodeInfoGraph(mcExpr.Arguments[1]); + resNodeInfo = new QueryNodeInfo(mcExpr, true, child1, child2); + this.BuildReferencedQuery(mcExpr.Arguments[2]); + } + else + { + // Apply with multiple sources of the same type + NewArrayExpression others = (NewArrayExpression)mcExpr.Arguments[1]; + QueryNodeInfo[] infos = new QueryNodeInfo[others.Expressions.Count + 1]; + infos[0] = child1; + for (int i = 0; i < others.Expressions.Count; ++i) + { + infos[i + 1] = BuildNodeInfoGraph(others.Expressions[i]); + } + resNodeInfo = new QueryNodeInfo(mcExpr, true, infos); + this.BuildReferencedQuery(mcExpr.Arguments[2]); + } + } + break; + } + case "RangePartition": + { + QueryNodeInfo child1 = BuildNodeInfoGraph(mcExpr.Arguments[0]); + if (mcExpr.Arguments.Count == 6) + { + // This is a key part of handling for RangePartition( ... , IQueryable<> keysQuery, ...) + // The keys expression is established as a child[1] nodeInfo for the rangePartition node. 
+                            QueryNodeInfo child2 = BuildNodeInfoGraph(mcExpr.Arguments[2]);
+                            resNodeInfo = new QueryNodeInfo(mcExpr, true, child1, child2);
+                        }
+                        else
+                        {
+                            resNodeInfo = new QueryNodeInfo(mcExpr, true, child1);
+                        }
+                        this.BuildReferencedQuery(mcExpr.Arguments[1]);
+                        break;
+                    }
+                    default:
+                    {
+                        QueryNodeInfo child1 = BuildNodeInfoGraph(mcExpr.Arguments[0]);
+                        resNodeInfo = new QueryNodeInfo(mcExpr, true, child1);
+                        this.BuildReferencedQuery(1, mcExpr.Arguments);
+                        break;
+                    }
+                }
+            }
+            if (resNodeInfo == null)
+            {
+                resNodeInfo = new QueryNodeInfo(expression, false);
+            }
+            this.m_exprNodeInfoMap.Add(expression, resNodeInfo);
+            return resNodeInfo;
+        }
+
+        // A merge node is needed after a repartitioning step when the node is dynamic
+        // or produces more than one partition.
+        private static bool IsMergeNodeNeeded(DryadQueryNode node)
+        {
+            return node.IsDynamic || node.PartitionCount > 1;
+        }
+
+        /// <summary>
+        /// Translates one node of the node-info graph into a query-plan node.
+        /// </summary>
+        /// <param name="nodeInfo">The node whose queryExpression is to be translated.</param>
+        /// <returns>
+        /// The node produced by VisitQueryOperatorCall for a static query-operator call,
+        /// or a new DryadInputNode for a constant expression.
+        /// </returns>
+        /// <remarks>
+        /// Throws a DryadLinqException (OperatorNotSupported) for any other method call and
+        /// (UnsupportedExpressionsType) for any other expression kind.
+        /// </remarks>
+        internal DryadQueryNode Visit(QueryNodeInfo nodeInfo)
+        {
+            Expression expression = nodeInfo.queryExpression;
+            if (expression.NodeType == ExpressionType.Call)
+            {
+                MethodCallExpression mcExpr = (MethodCallExpression)expression;
+                if (mcExpr.Method.IsStatic && TypeSystem.IsQueryOperatorCall(mcExpr))
+                {
+                    return this.VisitQueryOperatorCall(nodeInfo);
+                }
+
+                throw DryadLinqException.Create(HpcLinqErrorCode.OperatorNotSupported,
+                                                String.Format(SR.OperatorNotSupported, mcExpr.Method.Name),
+                                                expression);
+            }
+            else if (expression.NodeType == ExpressionType.Constant)
+            {
+                DryadInputNode inputNode = new DryadInputNode(this, (ConstantExpression)expression);
+
+                // Only the first input node seen for a given (lower-cased) URI is recorded.
+                // The key is computed once so the lookup and the insert cannot disagree.
+                // NOTE(review): ToLower() is culture-sensitive; kept as-is because other
+                // users of m_inputUriMap may build their keys the same way -- confirm.
+                string uriKey = inputNode.Table.DataSourceUri.ToLower();
+                if (!this.m_inputUriMap.ContainsKey(uriKey))
+                {
+                    this.m_inputUriMap.Add(uriKey, inputNode);
+                }
+                return inputNode;
+            }
+            else
+            {
+                throw DryadLinqException.Create(HpcLinqErrorCode.UnsupportedExpressionsType,
+                                                String.Format(SR.UnsupportedExpressionsType,expression.NodeType),
+                                                expression);
+            }
+        }
+
+        private static bool IsLambda(Expression expr, int n)
+        {
+            LambdaExpression lambdaExpr =
HpcLinqExpression.GetLambda(expr); + return (lambdaExpr != null && lambdaExpr.Parameters.Count == n); + } + + // Checks if the child of the source node is a groupby node without any result selectors. + private static bool IsGroupByWithoutResultSelector(Expression source) + { + bool isGroupBy = false; + if (source.NodeType == ExpressionType.Call) + { + MethodCallExpression expression = (MethodCallExpression)source; + if (expression.Method.IsStatic && + TypeSystem.IsQueryOperatorCall(expression) && + expression.Method.Name == "GroupBy") + { + if (expression.Arguments.Count == 2) + { + isGroupBy = true; + } + else if (expression.Arguments.Count == 3) + { + isGroupBy = !IsLambda(expression.Arguments[2], 2); + } + else if (expression.Arguments.Count == 4) + { + isGroupBy = !(IsLambda(expression.Arguments[2], 2) || + IsLambda(expression.Arguments[3], 2)); + } + } + } + return isGroupBy; + } + + private DryadQueryNode CreateOffset(bool isLong, Expression queryExpr, DryadQueryNode child) + { + // Count node + DryadQueryNode countNode = new DryadBasicAggregateNode(null, AggregateOpType.LongCount, + true, false, queryExpr, child); + + // Apply node for x => Offsets(x) + Type paramType = typeof(IEnumerable<>).MakeGenericType(typeof(long)); + ParameterExpression param = Expression.Parameter(paramType, "x"); + MethodInfo minfo = typeof(HpcLinqEnumerable).GetMethod("Offsets"); + Expression body = Expression.Call(minfo, param, Expression.Constant(isLong, typeof(bool))); + Type type = typeof(Func<,>).MakeGenericType(param.Type, body.Type); + LambdaExpression procFunc = Expression.Lambda(type, body, param); + DryadQueryNode mergeCountNode = new DryadMergeNode(true, true, queryExpr, countNode); + DryadQueryNode offsetsNode = new DryadApplyNode(procFunc, queryExpr, mergeCountNode); + + // HashPartition + LambdaExpression keySelectExpr = IdentityFunction.Instance(typeof(IndexedValue)); + int pcount = child.OutputPartition.Count; + DryadQueryNode hdistNode = new 
DryadHashPartitionNode(keySelectExpr, null, null, pcount, + false, queryExpr, offsetsNode); + DryadQueryNode resNode = new DryadMergeNode(false, true, queryExpr, hdistNode); + return resNode; + } + + private DryadQueryNode PromoteConcat(QueryNodeInfo source, + DryadQueryNode sourceNode, + Func func) + { + DryadQueryNode resNode = sourceNode; + if ((resNode is DryadConcatNode) && !source.IsForked) + { + DryadQueryNode[] children = resNode.Children; + DryadQueryNode[] newChildren = new DryadQueryNode[children.Length]; + for (int i = 0; i < children.Length; i++) + { + children[i].Parents.Remove(resNode); + newChildren[i] = func(children[i]); + } + resNode = new DryadConcatNode(source.queryExpression, newChildren); + } + else + { + resNode = func(resNode); + } + return resNode; + } + + private DryadQueryNode VisitSelect(QueryNodeInfo source, + QueryNodeType nodeType, + LambdaExpression selector, + LambdaExpression resultSelector, + MethodCallExpression queryExpr) + { + DryadQueryNode selectNode; + if (selector.Type.GetGenericArguments().Length == 2) + { + // If this select's child is a groupby node, push this select into its child, if + // 1. The groupby node is not tee'd, and + // 2. The groupby node has no result selector, and + // 3. 
The selector is decomposable + if (!source.IsForked && + IsGroupByWithoutResultSelector(source.queryExpression) && + Decomposition.GetDecompositionInfoList(selector, m_codeGen) != null) + { + MethodCallExpression expr = (MethodCallExpression)source.queryExpression; + LambdaExpression keySelectExpr = HpcLinqExpression.GetLambda(expr.Arguments[1]); + + // Figure out elemSelectExpr and comparerExpr + LambdaExpression elemSelectExpr = null; + Expression comparerExpr = null; + if (expr.Arguments.Count == 3) + { + elemSelectExpr = HpcLinqExpression.GetLambda(expr.Arguments[2]); + if (elemSelectExpr == null) + { + comparerExpr = expr.Arguments[2]; + } + } + else if (expr.Arguments.Count == 4) + { + elemSelectExpr = HpcLinqExpression.GetLambda(expr.Arguments[2]); + comparerExpr = expr.Arguments[3]; + } + + // Construct new query expression by building result selector expression + // and pushing it to groupby node. + selectNode = VisitGroupBy(source.children[0].child, keySelectExpr, + elemSelectExpr, selector, comparerExpr, queryExpr); + if (nodeType == QueryNodeType.SelectMany) + { + Type selectorRetType = selector.Type.GetGenericArguments()[1]; + LambdaExpression id = IdentityFunction.Instance(selectorRetType); + selectNode = new DryadSelectNode(nodeType, id, resultSelector, queryExpr, selectNode); + } + } + else + { + DryadQueryNode child = this.Visit(source); + selectNode = this.PromoteConcat( + source, child, + x => new DryadSelectNode(nodeType, selector, resultSelector, queryExpr, x)); + } + } + else + { + // The "indexed" version + DryadQueryNode child = this.Visit(source); + if (!child.IsDynamic && child.OutputPartition.Count == 1) + { + selectNode = this.PromoteConcat( + source, child, + x => new DryadSelectNode(nodeType, selector, resultSelector, queryExpr, x)); + } + else + { + child.IsForked = true; + + // Create (x, y) => Select(x, y, selector) + Type ptype1 = typeof(IEnumerable<>).MakeGenericType(child.OutputTypes[0]); + Type ptype2 = 
typeof(IEnumerable<>).MakeGenericType(typeof(IndexedValue)); + ParameterExpression param1 = Expression.Parameter(ptype1, HpcLinqCodeGen.MakeUniqueName("x")); + ParameterExpression param2 = Expression.Parameter(ptype2, HpcLinqCodeGen.MakeUniqueName("y")); + + string methodName = queryExpr.Method.Name; + Type[] selectorTypeArgs = selector.Type.GetGenericArguments(); + Type typeArg2 = selectorTypeArgs[selectorTypeArgs.Length - 1]; + if (nodeType == QueryNodeType.SelectMany) + { + if (resultSelector != null) + { + methodName += "Result"; + } + typeArg2 = typeArg2.GetGenericArguments()[0]; + } + + string targetMethodName = methodName + "WithStartIndex"; + MethodInfo minfo = typeof(HpcLinqEnumerable).GetMethod(targetMethodName); + Expression body; + if (resultSelector == null) + { + minfo = minfo.MakeGenericMethod(child.OutputTypes[0], typeArg2); + body = Expression.Call(minfo, param1, param2, selector); + } + else + { + minfo = minfo.MakeGenericMethod(child.OutputTypes[0], typeArg2, resultSelector.Body.Type); + body = Expression.Call(minfo, param1, param2, selector, resultSelector); + } + Type type = typeof(Func<,,>).MakeGenericType(ptype1, ptype2, body.Type); + LambdaExpression procFunc = Expression.Lambda(type, body, param1, param2); + + bool isLong = methodName.StartsWith("Long", StringComparison.Ordinal); + DryadQueryNode offsetNode = this.CreateOffset(isLong, queryExpr, child); + selectNode = new DryadApplyNode(procFunc, queryExpr, child, offsetNode); + } + } + return selectNode; + } + + private DryadQueryNode VisitWhere(QueryNodeInfo source, + LambdaExpression predicate, + MethodCallExpression queryExpr) + { + DryadQueryNode child = this.Visit(source); + DryadQueryNode whereNode; + if (predicate.Type.GetGenericArguments().Length == 2 || + (!child.IsDynamic && child.OutputPartition.Count == 1)) + { + whereNode = this.PromoteConcat(source, child, x => new DryadWhereNode(predicate, queryExpr, x)); + } + else + { + // The "indexed" version + // Create (x, y) => 
DryadWhere(x, y, predicate) + Type ptype1 = typeof(IEnumerable<>).MakeGenericType(child.OutputTypes[0]); + Type ptype2 = typeof(IEnumerable<>).MakeGenericType(typeof(IndexedValue)); + ParameterExpression param1 = Expression.Parameter(ptype1, HpcLinqCodeGen.MakeUniqueName("x")); + ParameterExpression param2 = Expression.Parameter(ptype2, HpcLinqCodeGen.MakeUniqueName("y")); + string targetMethod = queryExpr.Method.Name + "WithStartIndex"; + MethodInfo minfo = typeof(HpcLinqEnumerable).GetMethod(targetMethod); + minfo = minfo.MakeGenericMethod(child.OutputTypes[0]); + Expression body = Expression.Call(minfo, param1, param2, predicate); + Type type = typeof(Func<,,>).MakeGenericType(ptype1, ptype2, body.Type); + LambdaExpression procFunc = Expression.Lambda(type, body, param1, param2); + + child.IsForked = true; + bool isLong = (queryExpr.Method.Name == "LongWhere"); + DryadQueryNode offsetNode = this.CreateOffset(isLong, queryExpr, child); + whereNode = new DryadApplyNode(procFunc, queryExpr, child, offsetNode); + } + return whereNode; + } + + private DryadQueryNode VisitJoin(QueryNodeInfo outerSource, + QueryNodeInfo innerSource, + QueryNodeType nodeType, + LambdaExpression outerKeySelector, + LambdaExpression innerKeySelector, + LambdaExpression resultSelector, + Expression comparerExpr, + Expression queryExpr) + { + DryadQueryNode outerChild = this.Visit(outerSource); + DryadQueryNode innerChild = this.Visit(innerSource); + DryadQueryNode joinNode = null; + + Type keyType = outerKeySelector.Type.GetGenericArguments()[1]; + if (comparerExpr == null && !TypeSystem.HasDefaultEqualityComparer(keyType)) + { + throw DryadLinqException.Create(HpcLinqErrorCode.ComparerMustBeSpecifiedOrKeyTypeMustBeIEquatable, + string.Format(SR.ComparerMustBeSpecifiedOrKeyTypeMustBeIEquatable, keyType), + queryExpr); + } + + // The comparer object: + object comparer = null; + if (comparerExpr != null) + { + ExpressionSimplifier evaluator = new ExpressionSimplifier(); + comparer = 
evaluator.Eval(comparerExpr); + } + + if (outerChild.IsDynamic || innerChild.IsDynamic) + { + // Well, let us do the simplest thing for now + outerChild = new DryadHashPartitionNode(outerKeySelector, + comparerExpr, + StaticConfig.DefaultPartitionCount, + queryExpr, + outerChild); + if (IsMergeNodeNeeded(outerChild)) + { + outerChild = new DryadMergeNode(false, false, queryExpr, outerChild); + } + + innerChild = new DryadHashPartitionNode(innerKeySelector, + comparerExpr, + StaticConfig.DefaultPartitionCount, + queryExpr, + innerChild); + if (IsMergeNodeNeeded(innerChild)) + { + innerChild = new DryadMergeNode(false, false, queryExpr, innerChild); + } + + joinNode = new DryadJoinNode(nodeType, + "Hash" + nodeType, + outerKeySelector, + innerKeySelector, + resultSelector, + comparerExpr, + queryExpr, + outerChild, + innerChild); + return joinNode; + } + + + bool isOuterDescending = outerChild.OutputDataSetInfo.orderByInfo.IsDescending; + bool isInnerDescending = innerChild.OutputDataSetInfo.orderByInfo.IsDescending; + + // Partition outer and inner if needed + if (outerChild.OutputPartition.ParType == PartitionType.Range && + outerChild.OutputPartition.HasKeys && + outerChild.OutputPartition.IsPartitionedBy(outerKeySelector, comparer)) + { + if (innerChild.OutputPartition.ParType != PartitionType.Range || + !innerChild.OutputPartition.IsPartitionedBy(innerKeySelector, comparer) || + !outerChild.OutputPartition.IsSamePartition(innerChild.OutputPartition)) + { + // Range distribute inner using outer's partition. + innerChild = outerChild.OutputPartition.CreatePartitionNode(innerKeySelector, innerChild); + if (IsMergeNodeNeeded(innerChild)) + { + innerChild = new DryadMergeNode(false, false, queryExpr, innerChild); + } + } + } + else if (innerChild.OutputPartition.ParType == PartitionType.Range && + innerChild.OutputPartition.HasKeys && + innerChild.OutputPartition.IsPartitionedBy(innerKeySelector, comparer)) + { + // Range distribute outer using inner's partition. 
+ outerChild = innerChild.OutputPartition.CreatePartitionNode(outerKeySelector, outerChild); + if (IsMergeNodeNeeded(outerChild)) + { + outerChild = new DryadMergeNode(false, false, queryExpr, outerChild); + } + } + else if (outerChild.OutputPartition.ParType == PartitionType.Hash && + outerChild.OutputPartition.IsPartitionedBy(outerKeySelector, comparer)) + { + if (innerChild.OutputPartition.ParType != PartitionType.Hash || + !innerChild.OutputPartition.IsPartitionedBy(innerKeySelector, comparer) || + !outerChild.OutputPartition.IsSamePartition(innerChild.OutputPartition)) + { + innerChild = new DryadHashPartitionNode(innerKeySelector, + comparerExpr, + outerChild.OutputPartition.Count, + queryExpr, + innerChild); + if (IsMergeNodeNeeded(innerChild)) + { + innerChild = new DryadMergeNode(false, false, queryExpr, innerChild); + } + } + } + else if (innerChild.OutputPartition.ParType == PartitionType.Hash && + innerChild.OutputPartition.IsPartitionedBy(innerKeySelector, comparer)) + { + outerChild = new DryadHashPartitionNode(outerKeySelector, + comparerExpr, + innerChild.OutputPartition.Count, + queryExpr, + outerChild); + if (IsMergeNodeNeeded(outerChild)) + { + outerChild = new DryadMergeNode(false, false, queryExpr, outerChild); + } + } + else + { + // No luck. 
Hash partition both outer and inner + int parCnt = Math.Max(outerChild.OutputPartition.Count, innerChild.OutputPartition.Count); + if (parCnt > 1) + { + outerChild = new DryadHashPartitionNode(outerKeySelector, + comparerExpr, + parCnt, + queryExpr, + outerChild); + if (IsMergeNodeNeeded(outerChild)) + { + outerChild = new DryadMergeNode(false, false, queryExpr, outerChild); + } + + innerChild = new DryadHashPartitionNode(innerKeySelector, + comparerExpr, + parCnt, + queryExpr, + innerChild); + if (IsMergeNodeNeeded(innerChild)) + { + innerChild = new DryadMergeNode(false, false, queryExpr, innerChild); + } + } + } + + // Perform either merge or hash join + string opName = "Hash"; + if (outerChild.IsOrderedBy(outerKeySelector, comparer)) + { + if (!innerChild.IsOrderedBy(innerKeySelector, comparer) || + isOuterDescending != isInnerDescending) + { + // Sort inner if unsorted + innerChild = new DryadOrderByNode(innerKeySelector, comparerExpr, + isOuterDescending, queryExpr, innerChild); + } + opName = "Merge"; + } + else if (innerChild.IsOrderedBy(innerKeySelector, comparer)) + { + if (!outerChild.IsOrderedBy(outerKeySelector, comparer) || + isOuterDescending != isInnerDescending) + { + // Sort outer if unsorted + outerChild = new DryadOrderByNode(outerKeySelector, comparerExpr, + isInnerDescending, queryExpr, outerChild); + } + opName = "Merge"; + } + + joinNode = new DryadJoinNode(nodeType, + opName + nodeType, + outerKeySelector, + innerKeySelector, + resultSelector, + comparerExpr, + queryExpr, + outerChild, + innerChild); + return joinNode; + } + + private DryadQueryNode VisitDistinct(QueryNodeInfo source, + Expression comparerExpr, + Expression queryExpr) + { + DryadQueryNode child = this.Visit(source); + + Type keyType = child.OutputTypes[0]; + if (comparerExpr == null && !TypeSystem.HasDefaultEqualityComparer(keyType)) + { + throw DryadLinqException.Create(HpcLinqErrorCode.ComparerMustBeSpecifiedOrKeyTypeMustBeIEquatable, + 
string.Format(SR.ComparerMustBeSpecifiedOrKeyTypeMustBeIEquatable, keyType), + queryExpr); + } + + object comparer = null; + if (comparerExpr != null) + { + ExpressionSimplifier evaluator = new ExpressionSimplifier(); + comparer = evaluator.Eval(comparerExpr); + } + + LambdaExpression keySelectExpr = IdentityFunction.Instance(keyType); + if (!child.OutputPartition.IsPartitionedBy(keySelectExpr, comparer)) + { + if (child.IsDynamic || child.OutputPartition.Count > 1) + { + child = new DryadDistinctNode(true, comparerExpr, queryExpr, child); + bool isDynamic = (StaticConfig.DynamicOptLevel & StaticConfig.DynamicHashPartitionLevel) != 0; + child = new DryadHashPartitionNode(keySelectExpr, + comparerExpr, + child.OutputPartition.Count, + isDynamic, + queryExpr, + child); + child = new DryadMergeNode(false, false, queryExpr, child); + } + } + DryadQueryNode resNode = new DryadDistinctNode(false, comparerExpr, queryExpr, child); + return resNode; + } + + private DryadQueryNode VisitConcat(QueryNodeInfo source, MethodCallExpression queryExpr) + { + DryadQueryNode[] childs = new DryadQueryNode[source.children.Count]; + for (int i = 0; i < source.children.Count; ++i) + { + childs[i] = this.Visit(source.children[i].child); + } + DryadQueryNode resNode = new DryadConcatNode(queryExpr, childs); + + int parCount = resNode.OutputPartition.Count; + if (!resNode.IsDynamic && parCount > StaticConfig.MaxPartitionCount) + { + // Too many partitions, need to repartition + int newParCount = parCount / 2; + DryadQueryNode countNode = new DryadBasicAggregateNode(null, AggregateOpType.LongCount, + true, false, queryExpr, resNode); + DryadQueryNode mergeCountNode = new DryadMergeNode(true, false, queryExpr, countNode); + + // Apply node for s => IndexedCount(s) + Type paramType = typeof(IEnumerable<>).MakeGenericType(typeof(long)); + ParameterExpression param = Expression.Parameter(paramType, "s"); + MethodInfo minfo = typeof(HpcLinqHelper).GetMethod("IndexedCount"); + minfo = 
minfo.MakeGenericMethod(typeof(long)); + Expression body = Expression.Call(minfo, param); + Type funcType = typeof(Func<,>).MakeGenericType(param.Type, body.Type); + LambdaExpression indexedCountFunc = Expression.Lambda(funcType, body, param); + DryadQueryNode indexedCountNode = new DryadApplyNode(indexedCountFunc, queryExpr, mergeCountNode); + + // HashPartition(x => x.index, parCount) + param = Expression.Parameter(body.Type.GetGenericArguments()[0], "x"); + Expression keySelectBody = Expression.Property(param, "Index"); + funcType = typeof(Func<,>).MakeGenericType(param.Type, keySelectBody.Type); + LambdaExpression keySelectExpr = Expression.Lambda(funcType, keySelectBody, param); + DryadQueryNode distCountNode = new DryadHashPartitionNode(keySelectExpr, + null, + parCount, + queryExpr, + indexedCountNode); + + // Apply node for (x, y) => AddPartitionIndex(x, y, newParCount) + ParameterExpression param1 = Expression.Parameter(body.Type, "x"); + Type paramType2 = typeof(IEnumerable<>).MakeGenericType(resNode.OutputTypes[0]); + ParameterExpression param2 = Expression.Parameter(paramType2, "y"); + minfo = typeof(HpcLinqHelper).GetMethod("AddPartitionIndex"); + minfo = minfo.MakeGenericMethod(resNode.OutputTypes[0]); + body = Expression.Call(minfo, param1, param2, Expression.Constant(newParCount)); + funcType = typeof(Func<,,>).MakeGenericType(param1.Type, param2.Type, body.Type); + LambdaExpression addIndexFunc = Expression.Lambda(funcType, body, param1, param2); + DryadQueryNode addIndexNode = new DryadApplyNode(addIndexFunc, queryExpr, distCountNode, resNode); + + // HashPartition(x => x.index, x => x.value, newParCount) + param = Expression.Parameter(body.Type.GetGenericArguments()[0], "x"); + body = Expression.Property(param, "Index"); + funcType = typeof(Func<,>).MakeGenericType(param.Type, body.Type); + keySelectExpr = Expression.Lambda(funcType, body, param); + body = Expression.Property(param, "Value"); + funcType = 
typeof(Func<,>).MakeGenericType(param.Type, body.Type);
+                LambdaExpression resultSelectExpr = Expression.Lambda(funcType, body, param);
+                resNode = new DryadHashPartitionNode(keySelectExpr,
+                                                     resultSelectExpr,
+                                                     null,
+                                                     newParCount,
+                                                     false,
+                                                     queryExpr,
+                                                     addIndexNode);
+                resNode = new DryadMergeNode(true, true, queryExpr, resNode);
+            }
+            return resNode;
+        }
+
+        /// <summary>
+        /// Builds the plan node for a binary set operation over two inputs.
+        /// The two inputs are co-partitioned on the element itself (identity key) and,
+        /// when one of them is already ordered on that key, the other is sorted to match
+        /// so that the "Ordered" variant of the operator can be used.
+        /// </summary>
+        /// <param name="source1">First input of the set operation.</param>
+        /// <param name="source2">Second input of the set operation.</param>
+        /// <param name="nodeType">Kind of set operation; also used to form the operator name.</param>
+        /// <param name="comparerExpr">Optional equality comparer expression; may be null.</param>
+        /// <param name="queryExpr">The query expression the created nodes are attached to.</param>
+        /// <returns>A DryadSetOperationNode over the (possibly repartitioned/sorted) children.</returns>
+        /// <remarks>
+        /// Throws a DryadLinqException when no comparer is supplied and the element type
+        /// has no default equality comparer.
+        /// </remarks>
+        private DryadQueryNode VisitSetOperation(QueryNodeInfo source1,
+                                                 QueryNodeInfo source2,
+                                                 QueryNodeType nodeType,
+                                                 Expression comparerExpr,
+                                                 Expression queryExpr)
+        {
+            DryadQueryNode child1 = this.Visit(source1);
+            DryadQueryNode child2 = this.Visit(source2);
+            DryadQueryNode resNode = null;
+
+            Type keyType = child1.OutputTypes[0];
+            if (comparerExpr == null && !TypeSystem.HasDefaultEqualityComparer(keyType))
+            {
+                throw DryadLinqException.Create(HpcLinqErrorCode.ComparerMustBeSpecifiedOrKeyTypeMustBeIEquatable,
+                                                string.Format(SR.ComparerMustBeSpecifiedOrKeyTypeMustBeIEquatable, keyType),
+                                                queryExpr);
+            }
+
+            // The comparer object:
+            object comparer = null;
+            if (comparerExpr != null)
+            {
+                ExpressionSimplifier evaluator = new ExpressionSimplifier();
+                comparer = evaluator.Eval(comparerExpr);
+            }
+
+            LambdaExpression keySelectExpr = IdentityFunction.Instance(keyType);
+            if (child1.IsDynamic || child2.IsDynamic)
+            {
+                // Well, let us do the simplest thing for now: hash partition both sides.
+                // The partitioning must use the same comparer as the set operator itself;
+                // otherwise elements that are equal under a custom comparer could end up
+                // in different partitions (previously null was passed here).
+                child1 = new DryadHashPartitionNode(keySelectExpr,
+                                                    comparerExpr,
+                                                    StaticConfig.DefaultPartitionCount,
+                                                    queryExpr,
+                                                    child1);
+                if (IsMergeNodeNeeded(child1))
+                {
+                    child1 = new DryadMergeNode(false, false, queryExpr, child1);
+                }
+
+                child2 = new DryadHashPartitionNode(keySelectExpr,
+                                                    comparerExpr,
+                                                    StaticConfig.DefaultPartitionCount,
+                                                    queryExpr,
+                                                    child2);
+                if (IsMergeNodeNeeded(child2))
+                {
+                    child2 = new DryadMergeNode(false, false, queryExpr, child2);
+                }
+
+                resNode = new DryadSetOperationNode(nodeType, nodeType.ToString(), comparerExpr,
+                                                    queryExpr, child1, child2);
+                return resNode;
+            }
+
+            bool isDescending1 = child1.OutputDataSetInfo.orderByInfo.IsDescending;
+            bool isDescending2 = child2.OutputDataSetInfo.orderByInfo.IsDescending;
+
+            // Partition child1 and child2 if needed
+            if (child1.OutputPartition.ParType == PartitionType.Range &&
+                child1.OutputPartition.HasKeys &&
+                child1.OutputPartition.IsPartitionedBy(keySelectExpr, comparer))
+            {
+                // child2 must be redistributed unless it is already range partitioned on
+                // the key in exactly the same way as child1. (The IsSamePartition test was
+                // previously not negated, which repartitioned matching inputs and left
+                // mismatched ones untouched.)
+                if (child2.OutputPartition.ParType != PartitionType.Range ||
+                    !child2.OutputPartition.IsPartitionedBy(keySelectExpr, comparer) ||
+                    !child1.OutputPartition.IsSamePartition(child2.OutputPartition))
+                {
+                    // Range distribute child2 using child1's partition
+                    child2 = child1.OutputPartition.CreatePartitionNode(keySelectExpr, child2);
+                    if (IsMergeNodeNeeded(child2))
+                    {
+                        child2 = new DryadMergeNode(false, false, queryExpr, child2);
+                    }
+                }
+            }
+            else if (child2.OutputPartition.ParType == PartitionType.Range &&
+                     child2.OutputPartition.HasKeys &&
+                     child2.OutputPartition.IsPartitionedBy(keySelectExpr, comparer))
+            {
+                // Range distribute child1 using child2's partition
+                child1 = child2.OutputPartition.CreatePartitionNode(keySelectExpr, child1);
+                if (IsMergeNodeNeeded(child1))
+                {
+                    child1 = new DryadMergeNode(false, false, queryExpr, child1);
+                }
+            }
+            else if (child1.OutputPartition.ParType == PartitionType.Hash &&
+                     child1.OutputPartition.IsPartitionedBy(keySelectExpr, comparer))
+            {
+                if (child2.OutputPartition.ParType != PartitionType.Hash ||
+                    !child2.OutputPartition.IsPartitionedBy(keySelectExpr, comparer) ||
+                    !child1.OutputPartition.IsSamePartition(child2.OutputPartition))
+                {
+                    // Hash distribute child2:
+                    child2 = new DryadHashPartitionNode(keySelectExpr,
+                                                        comparerExpr,
+                                                        child1.OutputPartition.Count,
+                                                        queryExpr,
+                                                        child2);
+                    if (IsMergeNodeNeeded(child2))
+                    {
+                        child2 = new DryadMergeNode(false, false, queryExpr, child2);
+                    }
+                }
+            }
+            else if (child2.OutputPartition.ParType == PartitionType.Hash &&
+                     child2.OutputPartition.IsPartitionedBy(keySelectExpr, comparer))
+            {
+                child1 = new DryadHashPartitionNode(keySelectExpr,
+                                                    comparerExpr,
+                                                    child2.OutputPartition.Count,
+                                                    queryExpr,
+                                                    child1);
+                if (IsMergeNodeNeeded(child1))
+                {
+                    child1 = new DryadMergeNode(false, false, queryExpr, child1);
+                }
+            }
+            else
+            {
+                // No luck. Hash distribute both child1 and child2, then perform hash operation
+                int parCnt = Math.Max(child1.OutputPartition.Count, child2.OutputPartition.Count);
+                if (parCnt > 1)
+                {
+                    child1 = new DryadHashPartitionNode(keySelectExpr, comparerExpr, parCnt, queryExpr, child1);
+                    if (IsMergeNodeNeeded(child1))
+                    {
+                        child1 = new DryadMergeNode(false, false, queryExpr, child1);
+                    }
+
+                    child2 = new DryadHashPartitionNode(keySelectExpr, comparerExpr, parCnt, queryExpr, child2);
+                    if (IsMergeNodeNeeded(child2))
+                    {
+                        child2 = new DryadMergeNode(false, false, queryExpr, child2);
+                    }
+                }
+            }
+
+            // Perform either hash or ordered operation
+            string opName = "";
+            if (child1.IsOrderedBy(keySelectExpr, comparer))
+            {
+                // child1 is already ordered; child2 must be ordered the same way.
+                // (This previously re-tested child1, so an unsorted child2 was never sorted.)
+                if (!child2.IsOrderedBy(keySelectExpr, comparer) ||
+                    isDescending1 != isDescending2)
+                {
+                    // Sort child2 to match child1's order
+                    child2 = new DryadOrderByNode(keySelectExpr, comparerExpr, isDescending1, queryExpr, child2);
+                }
+                opName = "Ordered";
+            }
+            else if (child2.IsOrderedBy(keySelectExpr, comparer))
+            {
+                if (!child1.IsOrderedBy(keySelectExpr, comparer) ||
+                    isDescending1 != isDescending2)
+                {
+                    // Sort child1 to match child2's order
+                    child1 = new DryadOrderByNode(keySelectExpr, comparerExpr, isDescending2, queryExpr, child1);
+                }
+                opName = "Ordered";
+            }
+
+            resNode = new DryadSetOperationNode(nodeType, opName + nodeType, comparerExpr, queryExpr, child1, child2);
+            return resNode;
+        }
+
+        private DryadQueryNode VisitContains(QueryNodeInfo source,
+                                             Expression valueExpr,
+                                             Expression comparerExpr,
+                                             bool isQuery,
+                                             Expression queryExpr)
+        {
+            DryadQueryNode child = this.Visit(source);
+
+            Type keyType = child.OutputTypes[0];
+            if (comparerExpr == null && !TypeSystem.HasDefaultEqualityComparer(keyType))
+            {
+                throw
DryadLinqException.Create(HpcLinqErrorCode.ComparerMustBeSpecifiedOrKeyTypeMustBeIEquatable, + string.Format(SR.ComparerMustBeSpecifiedOrKeyTypeMustBeIEquatable, keyType), + queryExpr); + } + + DryadQueryNode resNode = this.PromoteConcat( + source, child, + x => new DryadContainsNode(valueExpr, comparerExpr, queryExpr, x)); + resNode = new DryadBasicAggregateNode(null, AggregateOpType.Any, false, isQuery, queryExpr, resNode); + return resNode; + } + + private DryadQueryNode VisitQuantifier(QueryNodeInfo source, + LambdaExpression lambda, + AggregateOpType aggType, + bool isQuery, + Expression queryExpr) + { + DryadQueryNode child = this.Visit(source); + DryadQueryNode resNode = this.PromoteConcat( + source, child, + x => new DryadBasicAggregateNode( + lambda, aggType, true, isQuery, queryExpr, x)); + resNode = new DryadBasicAggregateNode(null, aggType, false, isQuery, queryExpr, resNode); + return resNode; + } + + private DryadQueryNode VisitAggregate(QueryNodeInfo source, + Expression seed, + LambdaExpression funcLambda, + LambdaExpression resultLambda, + bool isQuery, + Expression queryExpr) + { + DryadQueryNode child = this.Visit(source); + DryadQueryNode resNode = child; + if (HpcLinqExpression.IsAssociative(funcLambda)) + { + LambdaExpression combinerLambda = HpcLinqExpression.GetAssociativeCombiner(funcLambda); + ResourceAttribute attrib = AttributeSystem.GetResourceAttrib(funcLambda); + bool funcIsExpensive = (attrib != null && attrib.IsExpensive); + resNode = this.PromoteConcat( + source, child, + delegate(DryadQueryNode x) { + DryadQueryNode y = new DryadAggregateNode("AssocAggregate", seed, + funcLambda, combinerLambda, resultLambda, + 1, isQuery, queryExpr, x, false); + return new DryadAggregateNode("AssocTreeAggregate", seed, + funcLambda, combinerLambda, resultLambda, + 2, isQuery, queryExpr, y, false); + }); + resNode = new DryadAggregateNode("AssocAggregate", seed, funcLambda, combinerLambda, resultLambda, + 3, isQuery, queryExpr, resNode, 
funcIsExpensive); + } + else + { + resNode = new DryadAggregateNode("Aggregate", seed, funcLambda, null, resultLambda, 3, + isQuery, queryExpr, child, false); + } + return resNode; + } + + private DryadQueryNode VisitBasicAggregate(QueryNodeInfo source, + LambdaExpression lambda, + AggregateOpType aggType, + bool isQuery, + Expression queryExpr) + { + DryadQueryNode child = this.Visit(source); + if (aggType == AggregateOpType.Min || aggType == AggregateOpType.Max) + { + Type elemType = child.OutputTypes[0]; + if (lambda != null) + { + elemType = lambda.Body.Type; + } + if (!TypeSystem.HasDefaultComparer(elemType)) + { + throw DryadLinqException.Create(HpcLinqErrorCode.AggregationOperatorRequiresIComparable, + String.Format(SR.AggregationOperatorRequiresIComparable, aggType ), + queryExpr); + } + } + DryadQueryNode resNode = this.PromoteConcat( + source, child, + x => new DryadBasicAggregateNode(lambda, aggType, true, isQuery, queryExpr, x)); + + switch (aggType) + { + case AggregateOpType.Count: + case AggregateOpType.LongCount: + { + resNode = new DryadBasicAggregateNode(null, AggregateOpType.Sum, false, + isQuery, queryExpr, resNode); + break; + } + case AggregateOpType.Sum: + case AggregateOpType.Min: + case AggregateOpType.Max: + case AggregateOpType.Average: + { + resNode = new DryadBasicAggregateNode(null, aggType, false, + isQuery, queryExpr, resNode); + break; + } + default: + { + throw DryadLinqException.Create(HpcLinqErrorCode.OperatorNotSupported, + String.Format(SR.OperatorNotSupported, aggType), + queryExpr); + } + } + return resNode; + } + + private DryadQueryNode VisitGroupBy(QueryNodeInfo source, + LambdaExpression keySelectExpr, + LambdaExpression elemSelectExpr, + LambdaExpression resultSelectExpr, + Expression comparerExpr, + Expression queryExpr) + { + DryadQueryNode child = this.Visit(source); + + ExpressionInfo einfo = new ExpressionInfo(keySelectExpr); + if (einfo.IsExpensive) + { + // Any method call that is not tagged as "expensive=false" 
will be deemed expensive. + // if the keySelector is expensive, we rewrite the query so that the key-function is invoked only once + // and the record key passed around via a Pair. + // keyFunc becomes pair=>pair.Key + // elementSelector must be rewritten so that references to (record) become (pair.Value) + + Type[] vkTypes = keySelectExpr.Type.GetGenericArguments(); + Type pairType = typeof(Pair<,>).MakeGenericType(vkTypes[1], vkTypes[0]); + ParameterExpression pairParam = Expression.Parameter(pairType, "e"); + + // Add Select(x => new Pair(key(x), x)) + ParameterExpression valueParam = keySelectExpr.Parameters[0]; + Expression body = Expression.New(pairType.GetConstructors()[0], keySelectExpr.Body, valueParam); + Type delegateType = typeof(Func<,>).MakeGenericType(valueParam.Type, body.Type); + LambdaExpression selectExpr = Expression.Lambda(delegateType, body, valueParam); + child = new DryadSelectNode(QueryNodeType.Select, selectExpr, null, queryExpr, child); + + // Change keySelector to e => e.Key + PropertyInfo keyInfo = pairParam.Type.GetProperty("Key"); + body = Expression.Property(pairParam, keyInfo); + delegateType = typeof(Func<,>).MakeGenericType(pairParam.Type, body.Type); + keySelectExpr = Expression.Lambda(delegateType, body, pairParam); + + // Add or change elementSelector with e.Value + PropertyInfo valueInfo = pairParam.Type.GetProperty("Value"); + body = Expression.Property(pairParam, valueInfo); + if (elemSelectExpr != null) + { + ParameterSubst subst = new ParameterSubst(elemSelectExpr.Parameters[0], body); + body = subst.Visit(elemSelectExpr.Body); + } + delegateType = typeof(Func<,>).MakeGenericType(pairParam.Type, body.Type); + elemSelectExpr = Expression.Lambda(delegateType, body, pairParam); + } + + Type keyType = keySelectExpr.Type.GetGenericArguments()[1]; + if (comparerExpr == null && !TypeSystem.HasDefaultEqualityComparer(keyType)) + { + throw DryadLinqException.Create(HpcLinqErrorCode.ComparerMustBeSpecifiedOrKeyTypeMustBeIEquatable, 
+ string.Format(SR.ComparerMustBeSpecifiedOrKeyTypeMustBeIEquatable, keyType), + queryExpr); + } + Type elemType; + if (elemSelectExpr == null) + { + elemType = keySelectExpr.Type.GetGenericArguments()[0]; + } + else + { + elemType = elemSelectExpr.Type.GetGenericArguments()[1]; + } + + // The comparer object: + object comparer = null; + if (comparerExpr != null) + { + ExpressionSimplifier evaluator = new ExpressionSimplifier(); + comparer = evaluator.Eval(comparerExpr); + } + + LambdaExpression keySelectExpr1 = keySelectExpr; + LambdaExpression elemSelectExpr1 = elemSelectExpr; + LambdaExpression resultSelectExpr1 = resultSelectExpr; + LambdaExpression seedExpr1 = null; + LambdaExpression accumulateExpr1 = null; + + List dInfoList = null; + if (resultSelectExpr != null) + { + dInfoList = Decomposition.GetDecompositionInfoList(resultSelectExpr, this.m_codeGen); + } + + String groupByOpName = "GroupBy"; + DryadQueryNode groupByNode = child; + bool isPartitioned = child.OutputPartition.IsPartitionedBy(keySelectExpr, comparer); + if (dInfoList != null) + { + // ** Decomposable GroupBy-Reduce + // This block creates the first GroupByNode and does some preparation for subsequent nodes. 
+ if (child.IsOrderedBy(keySelectExpr, comparer)) + { + groupByOpName = "OrderedGroupBy"; + } + + int dcnt = dInfoList.Count; + ParameterExpression keyParam; + if (resultSelectExpr.Parameters.Count == 1) + { + keyParam = Expression.Parameter(keyType, HpcLinqCodeGen.MakeUniqueName("k")); + } + else + { + keyParam = resultSelectExpr.Parameters[0]; + } + + // Seed: + ParameterExpression param2 = Expression.Parameter( + elemType, HpcLinqCodeGen.MakeUniqueName("e")); + Expression zeroExpr = Expression.Constant(0, typeof(int)); + Expression seedBody = zeroExpr; + if (dcnt != 0) + { + LambdaExpression seed = dInfoList[dcnt-1].Seed; + ParameterSubst subst = new ParameterSubst(seed.Parameters[0], param2); + seedBody = subst.Visit(seed.Body); + for (int i = dcnt - 2; i >= 0; i--) + { + seed = dInfoList[i].Seed; + subst = new ParameterSubst(seed.Parameters[0], param2); + Expression firstExpr = subst.Visit(seed.Body); + Type newPairType = typeof(Pair<,>).MakeGenericType(firstExpr.Type, seedBody.Type); + seedBody = Expression.New(newPairType.GetConstructors()[0], firstExpr, seedBody); + } + } + LambdaExpression seedExpr = Expression.Lambda(seedBody, param2); + + // Accumulate: + ParameterExpression param1 = Expression.Parameter( + seedBody.Type, HpcLinqCodeGen.MakeUniqueName("a")); + Expression accumulateBody = zeroExpr; + if (dcnt != 0) + { + accumulateBody = Decomposition.AccumulateList(param1, param2, dInfoList, 0); + } + LambdaExpression accumulateExpr = Expression.Lambda(accumulateBody, param1, param2); + + // Now prepare for the merge-aggregator and/or in the secondary group-by. 
+ // keySelectExpr1: e => e.Key + Type reducerResType = typeof(Pair<,>).MakeGenericType(keyParam.Type, accumulateBody.Type); + ParameterExpression reducerResParam = Expression.Parameter(reducerResType, "e"); + PropertyInfo keyInfo = reducerResParam.Type.GetProperty("Key"); + Expression body = Expression.Property(reducerResParam, keyInfo); + Type delegateType = typeof(Func<,>).MakeGenericType(reducerResParam.Type, body.Type); + keySelectExpr1 = Expression.Lambda(delegateType, body, reducerResParam); + + // elemSelectExpr1: e => e.Value + PropertyInfo valueInfo = reducerResParam.Type.GetProperty("Value"); + body = Expression.Property(reducerResParam, valueInfo); + delegateType = typeof(Func<,>).MakeGenericType(reducerResParam.Type, body.Type); + elemSelectExpr1 = Expression.Lambda(delegateType, body, reducerResParam); + + // SeedExpr1 + param2 = Expression.Parameter(elemSelectExpr1.Body.Type, + HpcLinqCodeGen.MakeUniqueName("e")); + seedExpr1 = Expression.Lambda(param2, param2); + + // AccumulateExpr1 + Expression recursiveAccumulateBody = zeroExpr; + if (dcnt != 0) + { + recursiveAccumulateBody = Decomposition.RecursiveAccumulateList(param1, param2, dInfoList, 0); + } + accumulateExpr1 = Expression.Lambda(recursiveAccumulateBody, param1, param2); + + // resultSelectExpr1 + resultSelectExpr1 = null; + + // The first groupByNode. + // If the input was already correctly partitioned, this will be the only groupByNode. + bool isPartial = StaticConfig.GroupByLocalAggregationIsPartial && !isPartitioned; + groupByNode = new DryadGroupByNode( + groupByOpName, keySelectExpr, elemSelectExpr, null, + seedExpr, accumulateExpr, accumulateExpr1, comparerExpr, + isPartial, queryExpr, child); + } + else + { + // Can't do partial aggregation. + // Use sort, mergesort, and ordered groupby, if TKey implements IComparable. 
+ if ((comparer != null && TypeSystem.IsComparer(comparer, keyType)) || + (comparer == null && TypeSystem.HasDefaultComparer(keyType))) + { + if (!child.IsOrderedBy(keySelectExpr, comparer)) + { + groupByNode = new DryadOrderByNode(keySelectExpr, comparerExpr, true, queryExpr, child); + } + groupByOpName = "OrderedGroupBy"; + } + + // Add a GroupByNode if it is partitioned or has elementSelector. + // If the input was already correctly partitioned, this will be the only groupByNode. + if (isPartitioned) + { + groupByNode = new DryadGroupByNode(groupByOpName, + keySelectExpr, + elemSelectExpr, + resultSelectExpr, + null, // seed + null, // accumulate + null, // recursiveAccumulate + comparerExpr, + false, // isPartial + queryExpr, + groupByNode); + } + else if (elemSelectExpr != null) + { + // Local GroupBy without resultSelector: + groupByNode = new DryadGroupByNode(groupByOpName, + keySelectExpr, + elemSelectExpr, + null, // resultSelect + null, // seed + null, // accumulate + null, // recursiveAccumulate + comparerExpr, + StaticConfig.GroupByLocalAggregationIsPartial, // isPartial + queryExpr, + groupByNode); + + // keySelectExpr1: g => g.Key + ParameterExpression groupParam = Expression.Parameter(groupByNode.OutputTypes[0], "g"); + PropertyInfo keyInfo = groupParam.Type.GetProperty("Key"); + Expression body = Expression.Property(groupParam, keyInfo); + Type delegateType = typeof(Func<,>).MakeGenericType(groupParam.Type, body.Type); + keySelectExpr1 = Expression.Lambda(delegateType, body, groupParam); + + // No elementSelector + elemSelectExpr1 = null; + + // resultSelectExpr1 + ParameterExpression keyParam; + Type groupType = typeof(IEnumerable<>).MakeGenericType(groupByNode.OutputTypes[0]); + groupParam = Expression.Parameter(groupType, HpcLinqCodeGen.MakeUniqueName("g")); + if (resultSelectExpr == null) + { + // resultSelectExpr1: (k, g) => MakeDryadLinqGroup(k, g) + keyParam = Expression.Parameter(keySelectExpr1.Body.Type, HpcLinqCodeGen.MakeUniqueName("k")); 
+ MethodInfo groupingInfo = typeof(HpcLinqEnumerable).GetMethod("MakeHpcLinqGroup"); + groupingInfo = groupingInfo.MakeGenericMethod(keyParam.Type, elemType); + body = Expression.Call(groupingInfo, keyParam, groupParam); + } + else + { + // resultSelectExpr1: (k, g) => resultSelectExpr(k, FlattenGroups(g)) + keyParam = resultSelectExpr.Parameters[0]; + MethodInfo flattenInfo = typeof(HpcLinqEnumerable).GetMethod("FlattenGroups"); + flattenInfo = flattenInfo.MakeGenericMethod(keyParam.Type, elemType); + Expression groupExpr = Expression.Call(flattenInfo, groupParam); + ParameterSubst subst = new ParameterSubst(resultSelectExpr.Parameters[1], groupExpr); + body = subst.Visit(resultSelectExpr.Body); + } + delegateType = typeof(Func<,,>).MakeGenericType(keyParam.Type, groupParam.Type, body.Type); + resultSelectExpr1 = Expression.Lambda(delegateType, body, keyParam, groupParam); + } + } + + // At this point, the first GroupByNode has been created. + DryadMergeNode mergeNode = null; + DryadQueryNode groupByNode1 = groupByNode; + if (!isPartitioned) + { + // Create HashPartitionNode, MergeNode, and second GroupByNode + + // Note, if we are doing decomposable-GroupByReduce, there is still some work to go after this + // - attach the combiner to the first merge-node + // - attach the combiner to the merge-node + // - attach finalizer to second GroupBy + int parCount = (groupByNode.IsDynamic) ? 
StaticConfig.DefaultPartitionCount : groupByNode.OutputPartition.Count; + bool isDynamic = (StaticConfig.DynamicOptLevel & StaticConfig.DynamicHashPartitionLevel) != 0; + DryadQueryNode hdistNode = new DryadHashPartitionNode(keySelectExpr1, comparerExpr, parCount, + isDynamic, queryExpr, groupByNode); + + // Create the Merge Node + if (groupByOpName == "OrderedGroupBy") + { + // Mergesort using the same keySelector as the hash partition + mergeNode = new DryadMergeNode(keySelectExpr1, comparerExpr, true, false, queryExpr, hdistNode); + } + else + { + // Random merge + mergeNode = new DryadMergeNode(false, false, queryExpr, hdistNode); + } + groupByNode1 = new DryadGroupByNode(groupByOpName, keySelectExpr1, elemSelectExpr1, + resultSelectExpr1, seedExpr1, accumulateExpr1, + accumulateExpr1, comparerExpr, false, queryExpr, + mergeNode); + } + + // Final tidy-up for decomposable GroupBy-Reduce pattern. + // - attach combiner to first GroupByNode + // - attach combiner to MergeNode as an aggregator + // - build a SelectNode to project out results and call finalizer on them. 
+ if (dInfoList != null) + { + // Add dynamic aggregator to the merge-node, if applicable + if (StaticConfig.GroupByDynamicReduce && !isPartitioned) + { + mergeNode.AddAggregateNode(groupByNode1); + } + + // Add the final Select node + Type keyResultPairType = typeof(Pair<,>).MakeGenericType(keyType, seedExpr1.Body.Type); + ParameterExpression keyResultPairParam = Expression.Parameter(keyResultPairType, + HpcLinqCodeGen.MakeUniqueName("e")); + PropertyInfo valuePropInfo_1 = keyResultPairType.GetProperty("Value"); + Expression combinedValueExpr = Expression.Property(keyResultPairParam, valuePropInfo_1); + + // First, build the combinerList + int dcnt = dInfoList.Count; + Expression[] combinerList = new Expression[dcnt]; + for (int i = 0; i < dcnt; i++) + { + if (i + 1 == dcnt) + { + combinerList[i] = combinedValueExpr; + } + else + { + PropertyInfo keyPropInfo = combinedValueExpr.Type.GetProperty("Key"); + combinerList[i] = Expression.Property(combinedValueExpr, keyPropInfo); + PropertyInfo valuePropInfo = combinedValueExpr.Type.GetProperty("Value"); + combinedValueExpr = Expression.Property(combinedValueExpr, valuePropInfo); + } + LambdaExpression finalizerExpr = dInfoList[i].FinalReducer; + if (finalizerExpr != null) + { + ParameterSubst subst = new ParameterSubst(finalizerExpr.Parameters[0], combinerList[i]); + combinerList[i] = subst.Visit(finalizerExpr.Body); + } + } + + // Build the funcList + Expression[] funcList = new Expression[dcnt]; + for (int i = 0; i < dcnt; i++) + { + funcList[i] = dInfoList[i].Func; + } + + // Apply the substitutions + CombinerSubst combinerSubst = new CombinerSubst(resultSelectExpr, keyResultPairParam, funcList, combinerList); + Expression finalizerSelectBody = combinerSubst.Visit(); + + // Finally, the Select node + Type delegateType = typeof(Func<,>).MakeGenericType(keyResultPairType, finalizerSelectBody.Type); + LambdaExpression selectExpr = Expression.Lambda(delegateType, finalizerSelectBody, keyResultPairParam); + groupByNode1 
= new DryadSelectNode(QueryNodeType.Select, selectExpr, null, queryExpr, groupByNode1); + } + return groupByNode1; + } + + // Creates an "auto-sampling range-partition sub-query" + private DryadQueryNode CreateRangePartition(bool isDynamic, + LambdaExpression keySelectExpr, + LambdaExpression resultSelectExpr, + Expression comparerExpr, + Expression isDescendingExpr, + Expression queryExpr, + Expression partitionCountExpr, + DryadQueryNode child) + { + // Make child a Tee node + child.IsForked = true; + + // The partition count + Expression countExpr = null; + + if (isDescendingExpr == null) + { + isDescendingExpr = Expression.Constant(false, typeof(bool)); //default for isDescending is false. + } + + // NOTE: for MayRTM, isDynamic should never be true + if (!isDynamic) + { + if (partitionCountExpr != null) + { + countExpr = partitionCountExpr; + } + else + { + // If partitionCount was not explicitly set, use the child's partition count. + countExpr = Expression.Constant(child.OutputPartition.Count); + } + } + + Type recordType = child.OutputTypes[0]; + Type keyType = keySelectExpr.Type.GetGenericArguments()[1]; + + // Create x => Phase1Sampling(x_1, keySelector, denv) + Type lambdaParamType1 = typeof(IEnumerable<>).MakeGenericType(recordType); + ParameterExpression lambdaParam1 = Expression.Parameter(lambdaParamType1, "x_1"); + + ParameterExpression denvParam = Expression.Parameter(typeof(HpcLinqVertexEnv), "denv"); + + MethodInfo minfo = typeof(HpcLinqSampler).GetMethod("Phase1Sampling"); + Expression body = Expression.Call(minfo.MakeGenericMethod(recordType, keyType), + lambdaParam1, keySelectExpr, denvParam); + Type type = typeof(Func<,>).MakeGenericType(lambdaParam1.Type, body.Type); + LambdaExpression samplingExpr = Expression.Lambda(type, body, lambdaParam1); + + // Create the Sampling node + DryadApplyNode samplingNode = new DryadApplyNode(samplingExpr, queryExpr, child); + + // Create x => RangeSampler(x, keySelectExpr, comparer, isDescendingExpr) + Type 
lambdaParamType = typeof(IEnumerable<>).MakeGenericType(keyType); + ParameterExpression lambdaParam = Expression.Parameter(lambdaParamType, "x_2"); + + //For RTM, isDynamic should never be true. + //string methodName = (isDynamic) ? "RangeSampler_Dynamic" : "RangeSampler_Static"; + Debug.Assert(isDynamic == false, "Internal error: isDynamic is true."); + string methodName = "RangeSampler_Static"; + + minfo = typeof(HpcLinqSampler).GetMethod(methodName); + minfo = minfo.MakeGenericMethod(keyType); + Expression comparerArgExpr = comparerExpr; + if (comparerExpr == null) + { + if (!TypeSystem.HasDefaultComparer(keyType)) + { + throw DryadLinqException.Create(HpcLinqErrorCode.ComparerMustBeSpecifiedOrKeyTypeMustBeIComparable, + string.Format(SR.ComparerMustBeSpecifiedOrKeyTypeMustBeIComparable, keyType), + queryExpr); + } + comparerArgExpr = Expression.Constant(null, typeof(IComparer<>).MakeGenericType(keyType)); + } + + Expression lastArg; + if (isDynamic) + { + lastArg = denvParam; + } + else + { + lastArg = countExpr; + } + + body = Expression.Call(minfo, lambdaParam, comparerArgExpr, isDescendingExpr, lastArg); + type = typeof(Func<,>).MakeGenericType(lambdaParam.Type, body.Type); + LambdaExpression samplerExpr = Expression.Lambda(type, body, lambdaParam); + + // Create the sample node + DryadQueryNode sampleDataNode = new DryadMergeNode(false, true, queryExpr, samplingNode); + DryadQueryNode sampleNode = new DryadApplyNode(samplerExpr, queryExpr, sampleDataNode); + sampleNode.IsForked = true; + + // Create the range distribute node + DryadQueryNode resNode = new DryadRangePartitionNode(keySelectExpr, + resultSelectExpr, + null, + comparerExpr, + isDescendingExpr, + countExpr, + queryExpr, + child, + sampleNode); + resNode = new DryadMergeNode(false, true, queryExpr, resNode); + + // Set the dynamic manager for sampleNode + if (isDynamic) + { + sampleDataNode.DynamicManager = new DynamicRangeDistributor(resNode); + } + + return resNode; + } + + private 
DryadQueryNode VisitOrderBy(QueryNodeInfo source, + LambdaExpression keySelectExpr, + Expression comparerExpr, + bool isDescending, + Expression queryExpr) + { + DryadQueryNode child = this.Visit(source); + + Type keyType = keySelectExpr.Type.GetGenericArguments()[1]; + if (comparerExpr == null && !TypeSystem.HasDefaultComparer(keyType)) + { + throw DryadLinqException.Create(HpcLinqErrorCode.ComparerMustBeSpecifiedOrKeyTypeMustBeIComparable, + string.Format(SR.ComparerMustBeSpecifiedOrKeyTypeMustBeIComparable, keyType), + queryExpr); + } + + DryadQueryNode resNode = child; + if (child.OutputPartition.ParType == PartitionType.Range && + child.OutputPartition.IsPartitionedBy(keySelectExpr, comparerExpr, isDescending)) + { + // Only need to sort each partition + resNode = new DryadOrderByNode(keySelectExpr, comparerExpr, isDescending, queryExpr, child); + } + else + { + Expression isDescendingExpr = Expression.Constant(isDescending); + bool dynamicOptEnabled = (StaticConfig.DynamicOptLevel & StaticConfig.DynamicRangePartitionLevel) != 0; + if (dynamicOptEnabled || child.IsDynamic) + { + // NOTE: for MayRTM, this path should not be taken. 
+ resNode = this.CreateRangePartition(true, keySelectExpr, null, comparerExpr, + isDescendingExpr, queryExpr, null, child); + resNode = new DryadOrderByNode(keySelectExpr, comparerExpr, isDescending, queryExpr, resNode); + } + else + { + if (child.OutputPartition.Count > 1) + { + resNode = this.CreateRangePartition(false, keySelectExpr, null, comparerExpr, + isDescendingExpr, queryExpr, null, child); + } + resNode = new DryadOrderByNode(keySelectExpr, comparerExpr, isDescending, queryExpr, resNode); + } + } + return resNode; + } + + private DryadQueryNode FirstStagePartitionOp(string opName, + QueryNodeType nodeType, + Expression controlExpr, + MethodCallExpression queryExpr, + DryadQueryNode child) + { + if (nodeType == QueryNodeType.TakeWhile) + { + Type ptype = typeof(IEnumerable<>).MakeGenericType(child.OutputTypes[0]); + ParameterExpression param = Expression.Parameter(ptype, HpcLinqCodeGen.MakeUniqueName("x")); + MethodInfo minfo = typeof(HpcLinqEnumerable).GetMethod("GroupTakeWhile"); + minfo = minfo.MakeGenericMethod(child.OutputTypes[0]); + Expression body = Expression.Call(minfo, param, controlExpr); + Type type = typeof(Func<,>).MakeGenericType(ptype, body.Type); + LambdaExpression procFunc = Expression.Lambda(type, body, param); + return new DryadApplyNode(procFunc, queryExpr, child); + } + else + { + return new DryadPartitionOpNode(opName, nodeType, controlExpr, true, queryExpr, child); + } + } + + private DryadQueryNode VisitPartitionOp(string opName, + QueryNodeInfo source, + QueryNodeType nodeType, + Expression controlExpr, + MethodCallExpression queryExpr) + { + DryadQueryNode resNode; + if (nodeType == QueryNodeType.TakeWhile && + controlExpr.Type.GetGenericArguments().Length != 2) + { + // The "indexed" version. 
+ resNode = this.Visit(source); + + // The following block used to be skipped for resNode.OutputPartition.Count == 1, + // which causes compilation error (bug 13593) + // @@TODO[p3] : implement a working optimization for nPartition==1 that calls + // directly to Linq TakeWhile. + // Note: the test is: if (resNode.IsDynamic || resNode.OutputPartition.Count > 1) + { + resNode.IsForked = true; + + bool isLong = (queryExpr.Method.Name == "LongTakeWhile"); + DryadQueryNode offsetNode = this.CreateOffset(isLong, queryExpr, resNode); + + // Create (x, y) => GroupIndexedTakeWhile(x, y, controlExpr) + Type ptype1 = typeof(IEnumerable<>).MakeGenericType(resNode.OutputTypes[0]); + Type ptype2 = typeof(IEnumerable<>).MakeGenericType(typeof(IndexedValue)); + ParameterExpression param1 = Expression.Parameter(ptype1, HpcLinqCodeGen.MakeUniqueName("x")); + ParameterExpression param2 = Expression.Parameter(ptype2, HpcLinqCodeGen.MakeUniqueName("y")); + string methodName = "GroupIndexed" + queryExpr.Method.Name; + MethodInfo minfo = typeof(HpcLinqEnumerable).GetMethod(methodName); + minfo = minfo.MakeGenericMethod(resNode.OutputTypes[0]); + Expression body = Expression.Call(minfo, param1, param2, controlExpr); + Type type = typeof(Func<,,>).MakeGenericType(ptype1, ptype2, body.Type); + LambdaExpression procFunc = Expression.Lambda(type, body, param1, param2); + resNode = new DryadApplyNode(procFunc, queryExpr, resNode, offsetNode); + } + } + else if (!source.IsForked && + (nodeType == QueryNodeType.Take || nodeType == QueryNodeType.TakeWhile) && + (source.OperatorName == "OrderBy" || source.OperatorName == "OrderByDescending")) + { + resNode = this.Visit(source.children[0].child); + + bool isDescending = (source.OperatorName == "OrderByDescending"); + MethodCallExpression sourceQueryExpr = (MethodCallExpression)source.queryExpression; + LambdaExpression keySelectExpr = HpcLinqExpression.GetLambda(sourceQueryExpr.Arguments[1]); + Expression comparerExpr = null; + if 
(sourceQueryExpr.Arguments.Count == 3) + { + comparerExpr = sourceQueryExpr.Arguments[2]; + } + resNode = this.PromoteConcat( + source.children[0].child, + resNode, + delegate(DryadQueryNode x) { + DryadQueryNode y = new DryadOrderByNode(keySelectExpr, comparerExpr, + isDescending, sourceQueryExpr, x); + return FirstStagePartitionOp(opName, nodeType, controlExpr, queryExpr, y); + }); + if (resNode.IsDynamic || resNode.OutputPartition.Count > 1) + { + // Need a mergesort + resNode = new DryadMergeNode(keySelectExpr, comparerExpr, isDescending, false, sourceQueryExpr, resNode); + } + } + else + { + resNode = this.Visit(source); + if (nodeType == QueryNodeType.Take || nodeType == QueryNodeType.TakeWhile) + { + resNode = this.PromoteConcat( + source, resNode, + x => FirstStagePartitionOp(opName, nodeType, controlExpr, queryExpr, x)); + } + } + resNode = new DryadPartitionOpNode(opName, nodeType, controlExpr, false, queryExpr, resNode); + return resNode; + } + + private DryadQueryNode VisitZip(QueryNodeInfo first, + QueryNodeInfo second, + LambdaExpression resultSelector, + MethodCallExpression queryExpr) + { + DryadQueryNode child1 = this.Visit(first); + DryadQueryNode child2 = this.Visit(second); + + if (child1.IsDynamic || child2.IsDynamic) + { + // Well, let us for now do it on a single machine + child1 = new DryadMergeNode(true, false, queryExpr, child1); + child2 = new DryadMergeNode(true, false, queryExpr, child2); + + // Apply node for (x, y) => Zip(x, y, resultSelector) + Type paramType1 = typeof(IEnumerable<>).MakeGenericType(child1.OutputTypes[0]); + ParameterExpression param1 = Expression.Parameter(paramType1, "s1"); + Type paramType2 = typeof(IEnumerable<>).MakeGenericType(child2.OutputTypes[0]); + ParameterExpression param2 = Expression.Parameter(paramType2, "s2"); + MethodInfo minfo = typeof(HpcLinqHelper).GetMethod("Zip"); + minfo = minfo.MakeGenericMethod(child1.OutputTypes[0]); + Expression body = Expression.Call(minfo, param1, param2, resultSelector); 
+ // procFunc takes two input sequences and returns one, so it needs the three-argument Func<,,> (same as the ZipCount lambda below). + Type funcType = typeof(Func<,,>).MakeGenericType(param1.Type, param2.Type, body.Type); + LambdaExpression procFunc = Expression.Lambda(funcType, body, param1, param2); + return new DryadApplyNode(procFunc, queryExpr, child1, child2); + } + else + { + int parCount1 = child1.OutputPartition.Count; + int parCount2 = child2.OutputPartition.Count; + + // Count nodes + DryadQueryNode countNode1 = new DryadBasicAggregateNode(null, AggregateOpType.LongCount, + true, false, queryExpr, child1); + DryadQueryNode countNode2 = new DryadBasicAggregateNode(null, AggregateOpType.LongCount, + true, false, queryExpr, child2); + countNode1 = new DryadMergeNode(true, false, queryExpr, countNode1); + countNode2 = new DryadMergeNode(true, false, queryExpr, countNode2); + + // Apply node for (x, y) => ZipCount(x, y) + Type paramType1 = typeof(IEnumerable<>).MakeGenericType(typeof(long)); + ParameterExpression param1 = Expression.Parameter(paramType1, "x"); + ParameterExpression param2 = Expression.Parameter(paramType1, "y"); + MethodInfo minfo = typeof(HpcLinqHelper).GetMethod("ZipCount"); + Expression body = Expression.Call(minfo, param1, param2); + Type funcType = typeof(Func<,,>).MakeGenericType(param1.Type, param2.Type, body.Type); + LambdaExpression zipCount = Expression.Lambda(funcType, body, param1, param2); + DryadQueryNode indexedCountNode = new DryadApplyNode(zipCount, queryExpr, countNode1, countNode2); + + // HashPartition(x => x.index, parCount2) + ParameterExpression param = Expression.Parameter(body.Type.GetGenericArguments()[0], "x"); + Expression keySelectBody = Expression.Property(param, "Index"); + funcType = typeof(Func<,>).MakeGenericType(param.Type, keySelectBody.Type); + LambdaExpression keySelectExpr = Expression.Lambda(funcType, keySelectBody, param); + DryadQueryNode distCountNode = new DryadHashPartitionNode(keySelectExpr, + null, + parCount2, + queryExpr, + indexedCountNode); + + // Apply node for (x, y) => AssignPartitionIndex(x, y) + param1 = 
Expression.Parameter(body.Type, "x"); + Type paramType2 = typeof(IEnumerable<>).MakeGenericType(child2.OutputTypes[0]); + param2 = Expression.Parameter(paramType2, "y"); + minfo = typeof(HpcLinqHelper).GetMethod("AssignPartitionIndex"); + minfo = minfo.MakeGenericMethod(child2.OutputTypes[0]); + body = Expression.Call(minfo, param1, param2); + funcType = typeof(Func<,,>).MakeGenericType(param1.Type, param2.Type, body.Type); + LambdaExpression assignIndex = Expression.Lambda(funcType, body, param1, param2); + DryadQueryNode addIndexNode = new DryadApplyNode(assignIndex, queryExpr, distCountNode, child2); + + // HashPartition(x => x.index, x => x.value, parCount1) + param = Expression.Parameter(body.Type.GetGenericArguments()[0], "x"); + body = Expression.Property(param, "Index"); + funcType = typeof(Func<,>).MakeGenericType(param.Type, body.Type); + keySelectExpr = Expression.Lambda(funcType, body, param); + body = Expression.Property(param, "Value"); + funcType = typeof(Func<,>).MakeGenericType(param.Type, body.Type); + LambdaExpression resultSelectExpr = Expression.Lambda(funcType, body, param); + DryadQueryNode newChild2 = new DryadHashPartitionNode(keySelectExpr, + resultSelectExpr, + null, + parCount1, + false, + queryExpr, + addIndexNode); + newChild2 = new DryadMergeNode(true, true, queryExpr, newChild2); + + // Finally the zip node + return new DryadZipNode(resultSelector, queryExpr, child1, newChild2); + } + } + + // Basic plan: (reverse all partitions) then (reverse data in each partition) + // The main complication is to perform the first step. + // Approach: + // - tee the input. + // - have a dummy apply node that produces the singleton {0} at each partition + // - merge to get a seq {0,0,..} whose length = nPartition. + // - convert that seq to { (0,n), (1,n), ...} + // - hash-partition to send one item to each of the n workers. 
+ // - use binary-apply to attach targetIndex to each source item + // Apply( seq1 = indexCountPair, seq2 = original data) => ({tgt, item0}, {tgt, item1}, .. ) + // - hash-partition to move items to target partition. + // - use local LINQ reverse to do the local data reversal. + private DryadQueryNode VisitReverse(QueryNodeInfo source, Expression queryExpr) + { + DryadQueryNode child = this.Visit(source); + if (child.IsDynamic) + { + throw new DryadLinqException("Reverse is only supported for static partition count"); + } + + child.IsForked = true; + + // Apply node for s => ValueZero(s) + Type paramType = typeof(IEnumerable<>).MakeGenericType(child.OutputTypes[0]); + ParameterExpression param = Expression.Parameter(paramType, "s"); + MethodInfo minfo = typeof(HpcLinqHelper).GetMethod("ValueZero"); + minfo = minfo.MakeGenericMethod(child.OutputTypes[0]); + Expression body = Expression.Call(minfo, param); + Type funcType = typeof(Func<,>).MakeGenericType(param.Type, body.Type); + LambdaExpression procFunc = Expression.Lambda(funcType, body, param); + DryadQueryNode valueZeroNode = new DryadApplyNode(procFunc, queryExpr, child); + + // Apply node for s => MakeIndexCountPairs(s), run on the merged output of valueZeroNode + paramType = typeof(IEnumerable<>).MakeGenericType(typeof(int)); + param = Expression.Parameter(paramType, "s"); + minfo = typeof(HpcLinqHelper).GetMethod("MakeIndexCountPairs"); + body = Expression.Call(minfo, param); + funcType = typeof(Func<,>).MakeGenericType(param.Type, body.Type); + procFunc = Expression.Lambda(funcType, body, param); + DryadQueryNode mergeZeroNode = new DryadMergeNode(true, true, queryExpr, valueZeroNode); + DryadQueryNode indexCountNode = new DryadApplyNode(procFunc, queryExpr, mergeZeroNode); + + // HashPartition to distribute the indexCounts -- one to each partition. + // each partition will receive (myPartitionID, pcount). 
+ int pcount = child.OutputPartition.Count; + param = Expression.Parameter(body.Type.GetGenericArguments()[0], "x"); + Expression keySelectBody = Expression.Property(param, "Index"); + funcType = typeof(Func<,>).MakeGenericType(param.Type, keySelectBody.Type); + LambdaExpression keySelectExpr = Expression.Lambda(funcType, keySelectBody, param); + DryadQueryNode hdistNode = new DryadHashPartitionNode(keySelectExpr, + null, + pcount, + queryExpr, + indexCountNode); + + // Apply node for (x, y) => AddIndexForReverse(x, y) + ParameterExpression param1 = Expression.Parameter(body.Type, "x"); + Type paramType2 = typeof(IEnumerable<>).MakeGenericType(child.OutputTypes[0]); + ParameterExpression param2 = Expression.Parameter(paramType2, "y"); + minfo = typeof(HpcLinqHelper).GetMethod("AddIndexForReverse"); + minfo = minfo.MakeGenericMethod(child.OutputTypes[0]); + body = Expression.Call(minfo, param1, param2); + funcType = typeof(Func<,,>).MakeGenericType(param1.Type, param2.Type, body.Type); + LambdaExpression addIndexFunc = Expression.Lambda(funcType, body, param1, param2); + DryadQueryNode addIndexNode = new DryadApplyNode(addIndexFunc, queryExpr, hdistNode, child); + + // HashPartition(x => x.index, x => x.value, pcount) + // Moves all data to correct target partition. 
(each worker will direct all its items to one target partition) + param = Expression.Parameter(body.Type.GetGenericArguments()[0], "x"); + body = Expression.Property(param, "Index"); + funcType = typeof(Func<,>).MakeGenericType(param.Type, body.Type); + keySelectExpr = Expression.Lambda(funcType, body, param); + body = Expression.Property(param, "Value"); + funcType = typeof(Func<,>).MakeGenericType(param.Type, body.Type); + LambdaExpression resultSelectExpr = Expression.Lambda(funcType, body, param); + DryadQueryNode reversePartitionNode = new DryadHashPartitionNode( + keySelectExpr, resultSelectExpr, null, + pcount, false, queryExpr, addIndexNode); + + // Reverse node + paramType = typeof(IEnumerable<>).MakeGenericType(reversePartitionNode.OutputTypes[0]); + param = Expression.Parameter(paramType, "x"); + minfo = typeof(HpcLinqVertex).GetMethod("Reverse"); + minfo = minfo.MakeGenericMethod(child.OutputTypes[0]); + body = Expression.Call(minfo, param); + funcType = typeof(Func<,>).MakeGenericType(param.Type, body.Type); + procFunc = Expression.Lambda(funcType, body, param); + DryadQueryNode resNode = new DryadMergeNode(true, true, queryExpr, reversePartitionNode); + resNode = new DryadApplyNode(procFunc, queryExpr, resNode); + + return resNode; + } + + private DryadQueryNode VisitSequenceEqual(QueryNodeInfo source1, + QueryNodeInfo source2, + Expression comparerExpr, + Expression queryExpr) + { + DryadQueryNode child1 = this.Visit(source1); + DryadQueryNode child2 = this.Visit(source2); + + Type elemType = child1.OutputTypes[0]; + if (comparerExpr == null && !TypeSystem.HasDefaultEqualityComparer(elemType)) + { + throw DryadLinqException.Create(HpcLinqErrorCode.ComparerExpressionMustBeSpecifiedOrElementTypeMustBeIEquatable, + String.Format(SR.ComparerExpressionMustBeSpecifiedOrElementTypeMustBeIEquatable, elemType), + queryExpr); + } + + // Well, let us do it on a single machine for now + child1 = new DryadMergeNode(true, false, queryExpr, child1); + child2 = new 
DryadMergeNode(true, false, queryExpr, child2); + + // Apply node for (x, y) => SequenceEqual(x, y, c) + Type paramType = typeof(IEnumerable<>).MakeGenericType(elemType); + ParameterExpression param1 = Expression.Parameter(paramType, "s1"); + ParameterExpression param2 = Expression.Parameter(paramType, "s2"); + MethodInfo minfo = typeof(HpcLinqHelper).GetMethod("SequenceEqual"); + minfo = minfo.MakeGenericMethod(elemType); + if (comparerExpr == null) + { + comparerExpr = Expression.Constant(null, typeof(IEqualityComparer<>).MakeGenericType(elemType)); + } + Expression body = Expression.Call(minfo, param1, param2, comparerExpr); + Type funcType = typeof(Func<,,>).MakeGenericType(param1.Type, param2.Type, body.Type); + LambdaExpression procFunc = Expression.Lambda(funcType, body, param1, param2); + return new DryadApplyNode(procFunc, queryExpr, child1, child2); + } + + private DryadQueryNode VisitHashPartition(QueryNodeInfo source, + LambdaExpression keySelectExpr, + Expression comparerExpr, + Expression countExpr, + Expression queryExpr) + { + DryadQueryNode child = this.Visit(source); + Type keyType = keySelectExpr.Type.GetGenericArguments()[1]; + + if (comparerExpr == null && !TypeSystem.HasDefaultEqualityComparer(keyType)) + { + throw DryadLinqException.Create(HpcLinqErrorCode.ComparerMustBeSpecifiedOrKeyTypeMustBeIEquatable, + string.Format(SR.ComparerMustBeSpecifiedOrKeyTypeMustBeIEquatable, keyType), + queryExpr); + } + + + bool isDynamic = (StaticConfig.DynamicOptLevel & StaticConfig.DynamicHashPartitionLevel) != 0; + + int nOutputPartitions; + if (countExpr != null) + { + ExpressionSimplifier evaluator = new ExpressionSimplifier(); + nOutputPartitions = evaluator.Eval(countExpr); + isDynamic = false; + } + else + { + // Note: For MayRTM, isDynamic will never be true. + nOutputPartitions = (isDynamic) ? 
1 : child.OutputPartition.Count; + } + + DryadQueryNode resNode = new DryadHashPartitionNode( + keySelectExpr, null, comparerExpr, nOutputPartitions, + isDynamic, queryExpr, child); + resNode = new DryadMergeNode(false, true, queryExpr, resNode); + return resNode; + } + + private DryadQueryNode VisitRangePartition(QueryNodeInfo source, + LambdaExpression keySelectExpr, + Expression keysExpr, + Expression comparerExpr, + Expression isDescendingExpr, + Expression partitionCountExpr, + Expression queryExpr) + { + DryadQueryNode child = this.Visit(source); + Type keyType = keySelectExpr.Type.GetGenericArguments()[1]; + + if (comparerExpr == null && !TypeSystem.HasDefaultComparer(keyType)) + { + throw DryadLinqException.Create(HpcLinqErrorCode.ComparerMustBeSpecifiedOrKeyTypeMustBeIComparable, + string.Format(SR.ComparerMustBeSpecifiedOrKeyTypeMustBeIComparable, keyType), + queryExpr); + } + + DryadQueryNode resNode; + if (keysExpr == null) + { + // case: no keys are provided -- create range partitioner with auto-separator-selection + bool dynamicOptEnabled = (StaticConfig.DynamicOptLevel & StaticConfig.DynamicRangePartitionLevel) != 0; + bool useDynamic = (dynamicOptEnabled || child.IsDynamic); + + // NOTE: for MayRTM, useDynamic should always be false + resNode = this.CreateRangePartition(useDynamic, keySelectExpr, null, comparerExpr, isDescendingExpr, + queryExpr, partitionCountExpr, child); + } + else + { + // case: keys are local enum (eg an array) -- create range partitioner with keys input via Object-store. 
+ resNode = new DryadRangePartitionNode(keySelectExpr, null, keysExpr, comparerExpr, + isDescendingExpr, null, queryExpr, child); + resNode = new DryadMergeNode(false, true, queryExpr, resNode); + } + return resNode; + } + + private DryadQueryNode VisitMultiApply(QueryNodeInfo source, + LambdaExpression procLambda, + bool perPartition, + bool isFirstOnly, + MethodCallExpression queryExpr) + { + DryadQueryNode[] childs = new DryadQueryNode[source.children.Count]; + for (int i = 0; i < source.children.Count; ++i) + { + childs[i] = this.Visit(source.children[i].child); + } + + bool isDynamic = childs.Any(x => x.IsDynamic); + if (perPartition && !isDynamic) + { + // Homomorphic case. + if (isFirstOnly) + { + for (int i = 1; i < childs.Length; ++i) + { + childs[i] = new DryadTeeNode(childs[i].OutputTypes[0], true, queryExpr, childs[i]); + childs[i].ConOpType = ConnectionOpType.CrossProduct; + childs[i] = new DryadMergeNode(childs[0].OutputPartition.Count, queryExpr, childs[i]); + } + } + else + { + int count = childs[0].OutputPartition.Count; + for (int i = 1; i < childs.Length; ++i) + { + if (childs[i].OutputPartition.Count != count) + { + throw DryadLinqException.Create(HpcLinqErrorCode.HomomorphicApplyNeedsSamePartitionCount, + SR.HomomorphicApplyNeedsSamePartitionCount, + queryExpr); + } + } + } + } + else + { + // Non-homomorphic case. 
+                for (int i = 0; i < childs.Length; ++i)
+                {
+                    if (childs[i].IsDynamic || childs[i].OutputPartition.Count > 1)
+                    {
+                        childs[i] = new DryadMergeNode(true, false, queryExpr, childs[i]);
+                    }
+                }
+            }
+            DryadQueryNode applyNode = new DryadApplyNode(procLambda, true, queryExpr, childs);
+            return applyNode;
+        }
+
+        // Builds the plan node for an Apply with one input (source2 == null) or two inputs.
+        // procLambda is handed unchanged to the DryadApplyNode; perPartition and isFirstOnly
+        // only decide how the inputs are prepared before the apply:
+        //   - unary, perPartition:        the apply is created through PromoteConcat, no merge is added.
+        //   - unary, otherwise:           the input is merged first if it is dynamic or has > 1 partition.
+        //   - binary, perPartition && isFirstOnly:
+        //                                 the first input is left as is; the second input is tee'd
+        //                                 and/or merged as described by the inline comments.
+        //   - binary, perPartition && !isFirstOnly, neither input dynamic:
+        //                                 both inputs are passed to the apply untouched.
+        //   - binary, otherwise:          each input that is dynamic or has > 1 partition is merged first.
+        // Returns: the DryadApplyNode (or, in the unary per-partition case, whatever PromoteConcat returns).
+        private DryadQueryNode VisitApply(QueryNodeInfo source1,
+                                          QueryNodeInfo source2,
+                                          LambdaExpression procLambda,
+                                          bool perPartition,
+                                          bool isFirstOnly,
+                                          Expression queryExpr)
+        {
+            DryadQueryNode child1 = this.Visit(source1);
+
+            DryadQueryNode applyNode;
+            if (source2 == null)
+            {
+                // Unary-apply case:
+                if (perPartition)
+                {
+                    //homomorphic
+                    applyNode = this.PromoteConcat(source1, child1, x => new DryadApplyNode(procLambda, queryExpr, x));
+                }
+                else
+                {
+                    //non-homomorphic
+                    if (child1.IsDynamic || child1.OutputPartition.Count > 1)
+                    {
+                        child1 = new DryadMergeNode(true, false, queryExpr, child1);
+                    }
+                    applyNode = new DryadApplyNode(procLambda, queryExpr, child1);
+                }
+            }
+            else
+            {
+                // Binary-apply case:
+                DryadQueryNode child2 = this.Visit(source2);
+
+                if (perPartition && isFirstOnly)
+                {
+                    // The function is left homomorphic:
+                    if (!child2.IsForked && (child1.IsDynamic || child1.OutputPartition.Count > 1))
+                    {
+                        // The normal cases..
+                        if (IsMergeNodeNeeded(child2))
+                        {
+                            if (child1.IsDynamic)
+                            {
+                                child2 = new DryadMergeNode(true, false, queryExpr, child2);
+                                child2.IsForked = true;
+                            }
+                            else
+                            {
+                                // Rather than do full merge and broadcast, which has lots of data-movement
+                                //   1. Tee output2 with output cross-product
+                                //   2. Do a merge-stage which will have input1.nPartition nodes each performing a merge.
+                                // This achieves a distribution of the entire input2 to the Apply nodes with least data-movement.
+                                child2 = new DryadTeeNode(child2.OutputTypes[0], true, queryExpr, child2);
+                                child2.ConOpType = ConnectionOpType.CrossProduct;
+                                child2 = new DryadMergeNode(child1.OutputPartition.Count, queryExpr, child2);
+                            }
+                        }
+                        else
+                        {
+                            // the right-data is already a single partition, so just tee it.
+                            // this will provide a copy to each of the apply nodes.
+                            child2 = new DryadTeeNode(child2.OutputTypes[0], true, queryExpr, child2);
+                        }
+                    }
+                    else
+                    {
+                        // Less common cases..
+                        // a full merge of the right-data may be necessary.
+                        if (child2.IsDynamic || child2.OutputPartition.Count > 1)
+                        {
+                            child2 = new DryadMergeNode(true, false, queryExpr, child2);
+                            if (child1.IsDynamic || child1.OutputPartition.Count > 1)
+                            {
+                                child2.IsForked = true;
+                            }
+                        }
+                    }
+                    applyNode = new DryadApplyNode(procLambda, queryExpr, child1, child2);
+                }
+                else if (perPartition && !isFirstOnly && !child1.IsDynamic && !child2.IsDynamic)
+                {
+                    // Full homomorphic
+                    // No merging occurs.
+                    // NOTE: We generally expect that both the left and right datasets have matching partitionCount.
+                    //       however, we don't test for it yet as users might know what they are doing, and it makes
+                    //       LocalDebug inconsistent as LocalDebug doesn't throw in that situation.
+                    applyNode = new DryadApplyNode(procLambda, queryExpr, child1, child2);
+                }
+                else
+                {
+                    // Non-homomorphic
+                    // Full merges of both data sets is necessary.
+                    if (child1.IsDynamic || child1.OutputPartition.Count > 1)
+                    {
+                        child1 = new DryadMergeNode(true, false, queryExpr, child1);
+                    }
+                    if (child2.IsDynamic || child2.OutputPartition.Count > 1)
+                    {
+                        child2 = new DryadMergeNode(true, false, queryExpr, child2);
+                    }
+                    applyNode = new DryadApplyNode(procLambda, queryExpr, child1, child2);
+                }
+            }
+            return applyNode;
+        }
+
+        // Visits the source and wraps the resulting node in a DryadForkNode built from
+        // forkLambda and keysExpr.
+        private DryadQueryNode VisitFork(QueryNodeInfo source,
+                                         LambdaExpression forkLambda,
+                                         Expression keysExpr,
+                                         Expression queryExpr)
+        {
+            DryadQueryNode child = this.Visit(source);
+            return new DryadForkNode(forkLambda, keysExpr, queryExpr, child);
+        }
+
+        // Evaluates indexExpr to an int and returns the Parents[index] entry of the visited
+        // source node. queryExpr is not referenced in this method.
+        // NOTE(review): the visited node is cast to DryadForkNode without a type check and the
+        // index is not range-checked here; presumably callers only reach this with a fork
+        // source and a valid index -- confirm against the call site.
+        private DryadQueryNode VisitForkChoose(QueryNodeInfo source,
+                                               Expression indexExpr,
+                                               Expression queryExpr)
+        {
+            DryadQueryNode child = this.Visit(source);
+            ExpressionSimplifier evaluator = new ExpressionSimplifier();
+            int index = evaluator.Eval(indexExpr);
+            return ((DryadForkNode)child).Parents[index];
+        }
+
+        // Records a hash-partitioning assumption on the visited source; no new plan node is created.
+        // The child's DataSetInfo is copied, its partitionInfo is replaced with
+        // PartitionInfo.CreateHash(keySelectExpr, <current partition count>, comparer, keyType),
+        // the copy is assigned back to the child, and the child itself is returned.
+        // Throws (via DryadLinqException.Create) when no comparer expression is given and the key
+        // type has no default equality comparer. keysExpr is not referenced in this method.
+        private DryadQueryNode VisitAssumeHashPartition(QueryNodeInfo source,
+                                                        LambdaExpression keySelectExpr,
+                                                        Expression keysExpr,
+                                                        Expression comparerExpr,
+                                                        Expression queryExpr)
+        {
+            DryadQueryNode child = this.Visit(source);
+
+            // The key type is the second generic argument of the key selector's delegate type.
+            Type keyType = keySelectExpr.Type.GetGenericArguments()[1];
+            if (comparerExpr == null && !TypeSystem.HasDefaultEqualityComparer(keyType))
+            {
+                throw DryadLinqException.Create(HpcLinqErrorCode.ComparerMustBeSpecifiedOrKeyTypeMustBeIEquatable,
+                                                String.Format(SR.ComparerMustBeSpecifiedOrKeyTypeMustBeIEquatable, keyType.FullName),
+                                                queryExpr);
+            }
+
+            // comparer stays null when no comparer expression was supplied.
+            object comparer = null;
+            if (comparerExpr != null)
+            {
+                ExpressionSimplifier evaluator = new ExpressionSimplifier();
+                comparer = evaluator.Eval(comparerExpr);
+            }
+            DataSetInfo outputInfo = new DataSetInfo(child.OutputDataSetInfo);
+            outputInfo.partitionInfo = PartitionInfo.CreateHash(keySelectExpr,
+                                                                child.OutputPartition.Count,
+                                                                comparer,
+                                                                keyType);
+            child.OutputDataSetInfo = outputInfo;
+            return child;
+        }
+
+        private DryadQueryNode
VisitAssumeRangePartition(QueryNodeInfo source, + LambdaExpression keySelectExpr, + Expression keysExpr, + Expression comparerExpr, + Expression isDescendingExpr, + Expression queryExpr) + { + DryadQueryNode child = this.Visit(source); + + Type keyType = keySelectExpr.Type.GetGenericArguments()[1]; + if (comparerExpr == null && !TypeSystem.HasDefaultComparer(keyType)) + { + throw DryadLinqException.Create(HpcLinqErrorCode.ComparerMustBeSpecifiedOrKeyTypeMustBeIComparable, + String.Format(SR.ComparerMustBeSpecifiedOrKeyTypeMustBeIComparable, keyType), + queryExpr); + } + + + object comparer = null; + if (comparerExpr != null) + { + ExpressionSimplifier evaluator = new ExpressionSimplifier(); + comparer = evaluator.Eval(comparerExpr); + } + + object keys = null; + if (keysExpr != null) + { + ExpressionSimplifier evaluator = new ExpressionSimplifier(); + keys = evaluator.Eval(keysExpr); + } + + //count the number of keys provided. + if (keys != null) + { + int nSeparators = 0; + var ie = ((IEnumerable)keys).GetEnumerator(); + while (ie.MoveNext()) + { + nSeparators++; + } + + if (!child.IsDynamic && nSeparators != child.PartitionCount - 1) + { + throw DryadLinqException.Create( + HpcLinqErrorCode.BadSeparatorCount, + String.Format(SR.BadSeparatorCount, nSeparators, child.PartitionCount - 1), + queryExpr); + } + } + + + bool? 
isDescending = null; + if (isDescendingExpr != null) + { + ExpressionSimplifier evaluator = new ExpressionSimplifier(); + isDescending = evaluator.Eval(isDescendingExpr); + } + + DataSetInfo outputInfo = new DataSetInfo(child.OutputDataSetInfo); + outputInfo.partitionInfo = PartitionInfo.CreateRange(keySelectExpr, + keys, + comparer, + isDescending, + child.OutputPartition.Count, + keyType); + child.OutputDataSetInfo = outputInfo; + return child; + } + + private DryadQueryNode VisitAssumeOrderBy(QueryNodeInfo source, + LambdaExpression keySelectExpr, + Expression comparerExpr, + Expression isDescendingExpr, + Expression queryExpr) + { + DryadQueryNode child = this.Visit(source); + + Type keyType = keySelectExpr.Type.GetGenericArguments()[1]; + if (comparerExpr == null && !TypeSystem.HasDefaultComparer(keyType)) + { + throw DryadLinqException.Create(HpcLinqErrorCode.ComparerMustBeSpecifiedOrKeyTypeMustBeIComparable, + String.Format(SR.ComparerMustBeSpecifiedOrKeyTypeMustBeIComparable, keyType), + queryExpr); + } + + object comparer = null; + if (comparerExpr != null) + { + ExpressionSimplifier evaluator = new ExpressionSimplifier(); + comparer = evaluator.Eval(comparerExpr); + } + + ExpressionSimplifier bevaluator = new ExpressionSimplifier(); + bool isDescending = bevaluator.Eval(isDescendingExpr); + + DataSetInfo outputInfo = new DataSetInfo(child.OutputDataSetInfo); + outputInfo.orderByInfo = OrderByInfo.Create(keySelectExpr, comparer, isDescending, keyType); + child.OutputDataSetInfo = outputInfo; + return child; + } + + private DryadQueryNode VisitSlidingWindow(QueryNodeInfo source, + LambdaExpression procLambda, + Expression windowSizeExpr, + Expression queryExpr) + { + // var windows = source.Apply(s => HpcLinqHelper.Last(s, windowSize)); + // var slided = windows.Apply(s => HpcLinqHelper.Slide(s)).HashPartition(x => x.Index); + // slided.Apply(source, (x, y) => HpcLinqHelper.ProcessWindows(x, y, procFunc, windowSize)); + DryadQueryNode child = 
this.Visit(source); + if (child.IsDynamic) + { + throw new DryadLinqException("SlidingWindow is only supported for static partition count"); + } + + ExpressionSimplifier evaluator = new ExpressionSimplifier(); + Expression windowSize = Expression.Constant(evaluator.Eval(windowSizeExpr), typeof(int)); + + child.IsForked = true; + + // Apply node for s => Last(s, windowSize) + Type paramType = typeof(IEnumerable<>).MakeGenericType(child.OutputTypes[0]); + ParameterExpression param = Expression.Parameter(paramType, HpcLinqCodeGen.MakeUniqueName("s")); + MethodInfo minfo = typeof(HpcLinqHelper).GetMethod("Last"); + minfo = minfo.MakeGenericMethod(child.OutputTypes[0]); + Expression body = Expression.Call(minfo, param, windowSize); + Type funcType = typeof(Func<,>).MakeGenericType(param.Type, body.Type); + LambdaExpression procFunc = Expression.Lambda(funcType, body, param); + DryadQueryNode lastNode = new DryadApplyNode(procFunc, queryExpr, child); + lastNode = new DryadMergeNode(true, true, queryExpr, lastNode); + + // Apply node for s => Slide(s) + param = Expression.Parameter(body.Type, HpcLinqCodeGen.MakeUniqueName("s")); + minfo = typeof(HpcLinqHelper).GetMethod("Slide"); + minfo = minfo.MakeGenericMethod(child.OutputTypes[0]); + body = Expression.Call(minfo, param); + funcType = typeof(Func<,>).MakeGenericType(param.Type, body.Type); + procFunc = Expression.Lambda(funcType, body, param); + DryadQueryNode slideNode = new DryadApplyNode(procFunc, queryExpr, lastNode); + + // Hash partition to distribute from partition i to i+1 + int pcount = child.OutputPartition.Count; + param = Expression.Parameter(body.Type.GetGenericArguments()[0], "x"); + Expression keySelectBody = Expression.Property(param, "Index"); + funcType = typeof(Func<,>).MakeGenericType(param.Type, keySelectBody.Type); + LambdaExpression keySelectExpr = Expression.Lambda(funcType, keySelectBody, param); + DryadQueryNode hdistNode = new DryadHashPartitionNode(keySelectExpr, + null, + pcount, + 
queryExpr, + slideNode); + + // Apply node for (x, y) => ProcessWindows(x, y, proclambda, windowSize) + Type paramType1 = typeof(IEnumerable<>).MakeGenericType(body.Type); + ParameterExpression param1 = Expression.Parameter(paramType1, HpcLinqCodeGen.MakeUniqueName("x")); + Type paramType2 = typeof(IEnumerable<>).MakeGenericType(child.OutputTypes[0]); + ParameterExpression param2 = Expression.Parameter(paramType2, HpcLinqCodeGen.MakeUniqueName("y")); + minfo = typeof(HpcLinqHelper).GetMethod("ProcessWindows"); + minfo = minfo.MakeGenericMethod(child.OutputTypes[0], procLambda.Body.Type); + body = Expression.Call(minfo, param1, param2, procLambda, windowSize); + funcType = typeof(Func<,,>).MakeGenericType(param1.Type, param2.Type, body.Type); + procFunc = Expression.Lambda(funcType, body, param1, param2); + return new DryadApplyNode(procFunc, queryExpr, hdistNode, child); + } + + private DryadQueryNode VisitApplyWithPartitionIndex( + QueryNodeInfo source, + LambdaExpression procLambda, + Expression queryExpr) + { + // var indices = source.Apply(s => ValueZero(s)).Apply(s => AssignIndex(s)); + // indices.Apply(source, (x, y) => ProcessWithIndex(x, y, procFunc)); + DryadQueryNode child = this.Visit(source); + if (child.IsDynamic) + { + throw new DryadLinqException("ApplyWithPartitionIndex is only supported for static partition count"); + } + + child.IsForked = true; + + // Apply node for s => ValueZero(s) + Type paramType = typeof(IEnumerable<>).MakeGenericType(child.OutputTypes[0]); + ParameterExpression param = Expression.Parameter(paramType, "s"); + MethodInfo minfo = typeof(HpcLinqHelper).GetMethod("ValueZero"); + minfo = minfo.MakeGenericMethod(child.OutputTypes[0]); + Expression body = Expression.Call(minfo, param); + Type funcType = typeof(Func<,>).MakeGenericType(param.Type, body.Type); + LambdaExpression procFunc = Expression.Lambda(funcType, body, param); + DryadQueryNode valueZeroNode = new DryadApplyNode(procFunc, queryExpr, child); + valueZeroNode = new 
DryadMergeNode(true, true, queryExpr, valueZeroNode); + + // Apply node for s => AssignIndex(s) + paramType = typeof(IEnumerable<>).MakeGenericType(typeof(int)); + param = Expression.Parameter(paramType, "s"); + minfo = typeof(HpcLinqHelper).GetMethod("AssignIndex"); + body = Expression.Call(minfo, param); + funcType = typeof(Func<,>).MakeGenericType(param.Type, body.Type); + procFunc = Expression.Lambda(funcType, body, param); + DryadQueryNode assignIndexNode = new DryadApplyNode(procFunc, queryExpr, valueZeroNode); + + // HashPartition to distribute the indices -- one to each partition. + int pcount = child.OutputPartition.Count; + param = Expression.Parameter(body.Type, "x"); + funcType = typeof(Func<,>).MakeGenericType(param.Type, param.Type); + LambdaExpression keySelectExpr = Expression.Lambda(funcType, param, param); + DryadQueryNode hdistNode = new DryadHashPartitionNode(keySelectExpr, + null, + pcount, + queryExpr, + assignIndexNode); + + // Apply node for (x, y) => ProcessWithIndex(x, y, procLambda)); + Type paramType1 = typeof(IEnumerable<>).MakeGenericType(child.OutputTypes[0]); + ParameterExpression param1 = Expression.Parameter(paramType1, HpcLinqCodeGen.MakeUniqueName("x")); + Type paramType2 = typeof(IEnumerable<>).MakeGenericType(typeof(int)); + ParameterExpression param2 = Expression.Parameter(paramType2, HpcLinqCodeGen.MakeUniqueName("y")); + minfo = typeof(HpcLinqHelper).GetMethod("ProcessWithIndex"); + minfo = minfo.MakeGenericMethod(child.OutputTypes[0], procLambda.Body.Type); + body = Expression.Call(minfo, param1, param2, procLambda); + funcType = typeof(Func<,,>).MakeGenericType(param1.Type, param2.Type, body.Type); + procFunc = Expression.Lambda(funcType, body, param1, param2); + return new DryadApplyNode(procFunc, queryExpr, child, hdistNode); + } + + private DryadQueryNode VisitFirst(QueryNodeInfo source, + LambdaExpression lambda, + AggregateOpType aggType, + bool isQuery, + Expression queryExpr) + { + DryadQueryNode child = 
this.Visit(source); + DryadQueryNode resNode = this.PromoteConcat( + source, child, + x => new DryadBasicAggregateNode(lambda, aggType, true, isQuery, queryExpr, x)); + return new DryadBasicAggregateNode(null, aggType, false, isQuery, queryExpr, resNode); + } + + private DryadQueryNode VisitThenBy(QueryNodeInfo source, + LambdaExpression keySelectExpr, + bool isDescending, + Expression queryExpr) + { + // YY: This makes it hard to maintain OrderByInfo. + throw DryadLinqException.Create(HpcLinqErrorCode.OperatorNotSupported, + String.Format(SR.OperatorNotSupported, "ThenBy"), + queryExpr); + } + + private DryadQueryNode VisitDefaultIfEmpty(QueryNodeInfo source, + Expression defaultValueExpr, + MethodCallExpression queryExpr) + { + // YY: Not very useful. We could add it later. + throw DryadLinqException.Create(HpcLinqErrorCode.OperatorNotSupported, + String.Format(SR.OperatorNotSupported, "DefaultIfEmpty"), + queryExpr); + } + + private DryadQueryNode VisitElementAt(string opName, + QueryNodeInfo source, + Expression indexExpr, + Expression queryExpr) + { + // YY: Not very useful. We could add it later. + throw DryadLinqException.Create(HpcLinqErrorCode.OperatorNotSupported, + String.Format(SR.OperatorNotSupported, opName), + queryExpr); + } + + private DryadQueryNode VisitOfType(QueryNodeInfo source, + Type ofType, + Expression queryExpr) + { + // YY: Not very useful. 
+ throw DryadLinqException.Create(HpcLinqErrorCode.OperatorNotSupported, + String.Format(SR.OperatorNotSupported, "OfType"), + queryExpr); + } + + private DryadQueryNode VisitQueryOperatorCall(QueryNodeInfo nodeInfo) + { + DryadQueryNode resNode = nodeInfo.queryNode; + if (resNode != null) return resNode; + + MethodCallExpression expression = (MethodCallExpression)nodeInfo.queryExpression; + string methodName = expression.Method.Name; + + #region LINQMETHODS + switch (methodName) + { + case "Aggregate": + case "AggregateAsQuery": + { + bool isQuery = (methodName == "AggregateAsQuery"); + if (expression.Arguments.Count == 2) + { + LambdaExpression funcLambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (funcLambda != null && funcLambda.Parameters.Count == 2) + { + resNode = this.VisitAggregate(nodeInfo.children[0].child, + null, + funcLambda, + null, + isQuery, + expression); + } + } + else if (expression.Arguments.Count == 3) + { + LambdaExpression funcLambda = HpcLinqExpression.GetLambda(expression.Arguments[2]); + if (funcLambda != null && funcLambda.Parameters.Count == 2) + { + resNode = this.VisitAggregate(nodeInfo.children[0].child, + expression.Arguments[1], + funcLambda, + null, + isQuery, + expression); + } + } + else if (expression.Arguments.Count == 4) + { + LambdaExpression funcLambda = HpcLinqExpression.GetLambda(expression.Arguments[2]); + LambdaExpression resultLambda = HpcLinqExpression.GetLambda(expression.Arguments[3]); + if (funcLambda != null && funcLambda.Parameters.Count == 2 && + resultLambda != null && resultLambda.Parameters.Count == 1) + { + resNode = this.VisitAggregate(nodeInfo.children[0].child, + expression.Arguments[1], + funcLambda, + resultLambda, + isQuery, + expression); + } + } + break; + } + case "Select": + case "LongSelect": + { + if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && (lambda.Parameters.Count == 1 || 
lambda.Parameters.Count == 2)) + { + resNode = this.VisitSelect(nodeInfo.children[0].child, + QueryNodeType.Select, + lambda, + null, + expression); + } + } + break; + } + case "SelectMany": + case "LongSelectMany": + { + if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count <= 2) + { + resNode = this.VisitSelect(nodeInfo.children[0].child, + QueryNodeType.SelectMany, + lambda, + null, + expression); + } + } + else if (expression.Arguments.Count == 3) + { + LambdaExpression lambda1 = HpcLinqExpression.GetLambda(expression.Arguments[1]); + LambdaExpression lambda2 = HpcLinqExpression.GetLambda(expression.Arguments[2]); + if (lambda1 != null && (lambda1.Parameters.Count == 1 || lambda1.Parameters.Count == 2) && + lambda2 != null && lambda2.Parameters.Count == 2) + { + resNode = this.VisitSelect(nodeInfo.children[0].child, + QueryNodeType.SelectMany, + lambda1, + lambda2, + expression); + } + } + break; + } + case "Join": + case "GroupJoin": + { + QueryNodeType nodeType = (methodName == "Join") ? 
QueryNodeType.Join : QueryNodeType.GroupJoin; + if (expression.Arguments.Count == 5) + { + LambdaExpression lambda2 = HpcLinqExpression.GetLambda(expression.Arguments[2]); + LambdaExpression lambda3 = HpcLinqExpression.GetLambda(expression.Arguments[3]); + LambdaExpression lambda4 = HpcLinqExpression.GetLambda(expression.Arguments[4]); + if (lambda2 != null && lambda2.Parameters.Count == 1 && + lambda3 != null && lambda3.Parameters.Count == 1 && + lambda4 != null && lambda4.Parameters.Count == 2) + { + resNode = this.VisitJoin(nodeInfo.children[0].child, + nodeInfo.children[1].child, + nodeType, + lambda2, + lambda3, + lambda4, + null, + expression); + } + } + else if (expression.Arguments.Count == 6) + { + LambdaExpression lambda2 = HpcLinqExpression.GetLambda(expression.Arguments[2]); + LambdaExpression lambda3 = HpcLinqExpression.GetLambda(expression.Arguments[3]); + LambdaExpression lambda4 = HpcLinqExpression.GetLambda(expression.Arguments[4]); + if (lambda2 != null && lambda2.Parameters.Count == 1 && + lambda3 != null && lambda3.Parameters.Count == 1 && + lambda4 != null && lambda4.Parameters.Count == 2) + { + resNode = this.VisitJoin(nodeInfo.children[0].child, + nodeInfo.children[1].child, + nodeType, + lambda2, + lambda3, + lambda4, + expression.Arguments[5], + expression); + } + } + break; + } + case "OfType": + { + if (expression.Arguments.Count == 1) + { + Type ofType = expression.Method.GetGenericArguments()[0]; + resNode = this.VisitOfType(nodeInfo.children[0].child, ofType, expression); + } + break; + } + case "Where": + case "LongWhere": + { + if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && (lambda.Parameters.Count == 1 || lambda.Parameters.Count == 2)) + { + resNode = this.VisitWhere(nodeInfo.children[0].child, + lambda, + expression); + } + } + break; + } + case "First": + case "FirstOrDefault": + case "FirstAsQuery": + { + AggregateOpType aggType = 
(methodName == "FirstOrDefault") ? AggregateOpType.FirstOrDefault : AggregateOpType.First; + bool isQuery = (methodName == "FirstAsQuery"); + if (expression.Arguments.Count == 1) + { + resNode = this.VisitFirst(nodeInfo.children[0].child, null, aggType, isQuery, expression); + } + else if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitFirst(nodeInfo.children[0].child, + lambda, + aggType, + isQuery, + expression); + } + } + break; + } + case "Single": + case "SingleOrDefault": + case "SingleAsQuery": + { + AggregateOpType aggType = (methodName == "SingleOrDefault") ? AggregateOpType.SingleOrDefault : AggregateOpType.Single; + bool isQuery = (methodName == "SingleAsQuery"); + if (expression.Arguments.Count == 1) + { + resNode = this.VisitFirst(nodeInfo.children[0].child, null, aggType, isQuery, expression); + } + else if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitFirst(nodeInfo.children[0].child, + lambda, + aggType, + isQuery, + expression); + } + } + break; + } + case "Last": + case "LastOrDefault": + case "LastAsQuery": + { + AggregateOpType aggType = (methodName == "LastOrDefault") ? 
AggregateOpType.LastOrDefault : AggregateOpType.Last; + bool isQuery = (methodName == "LastAsQuery"); + if (expression.Arguments.Count == 1) + { + resNode = this.VisitFirst(nodeInfo.children[0].child, null, aggType, isQuery, expression); + } + else if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitFirst(nodeInfo.children[0].child, + lambda, + aggType, + isQuery, + expression); + } + } + break; + } + case "Distinct": + { + if (expression.Arguments.Count == 1) + { + resNode = this.VisitDistinct(nodeInfo.children[0].child, null, expression); + } + else if (expression.Arguments.Count == 2) + { + resNode = this.VisitDistinct(nodeInfo.children[0].child, + expression.Arguments[1], + expression); + } + break; + } + case "DefaultIfEmpty": + { + if (expression.Arguments.Count == 1) + { + resNode = this.VisitDefaultIfEmpty(nodeInfo.children[0].child, + null, + expression); + } + else if (expression.Arguments.Count == 2) + { + resNode = this.VisitDefaultIfEmpty(nodeInfo.children[0].child, + expression.Arguments[1], + expression); + } + break; + } + case "Concat": + { + if (expression.Arguments.Count == 2) + { + resNode = this.VisitConcat(nodeInfo, expression); + } + break; + } + case "Union": + { + if (expression.Arguments.Count == 2) + { + resNode = this.VisitSetOperation(nodeInfo.children[0].child, + nodeInfo.children[1].child, + QueryNodeType.Union, + null, + expression); + } + else if (expression.Arguments.Count == 3) + { + resNode = this.VisitSetOperation(nodeInfo.children[0].child, + nodeInfo.children[1].child, + QueryNodeType.Union, + expression.Arguments[2], + expression); + } + break; + } + case "Intersect": + { + if (expression.Arguments.Count == 2) + { + resNode = this.VisitSetOperation(nodeInfo.children[0].child, + nodeInfo.children[1].child, + QueryNodeType.Intersect, + null, + expression); + } + else if 
(expression.Arguments.Count == 3) + { + resNode = this.VisitSetOperation(nodeInfo.children[0].child, + nodeInfo.children[1].child, + QueryNodeType.Intersect, + expression.Arguments[2], + expression); + } + break; + } + case "Except": + { + if (expression.Arguments.Count == 2) + { + resNode = this.VisitSetOperation(nodeInfo.children[0].child, + nodeInfo.children[1].child, + QueryNodeType.Except, + null, + expression); + } + else if (expression.Arguments.Count == 3) + { + resNode = this.VisitSetOperation(nodeInfo.children[0].child, + nodeInfo.children[1].child, + QueryNodeType.Except, + expression.Arguments[2], + expression); + } + break; + } + case "Any": + case "AnyAsQuery": + { + bool isQuery = (methodName == "AnyAsQuery"); + if (expression.Arguments.Count == 1) + { + Type type = expression.Method.GetGenericArguments()[0]; + ParameterExpression param = Expression.Parameter(type, "x"); + Type delegateType = typeof(Func<,>).MakeGenericType(type, typeof(bool)); + Expression body = Expression.Constant(true); + LambdaExpression lambda = Expression.Lambda(delegateType, body, param); + resNode = this.VisitQuantifier(nodeInfo.children[0].child, + lambda, + AggregateOpType.Any, + isQuery, + expression); + } + else if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitQuantifier(nodeInfo.children[0].child, + lambda, + AggregateOpType.Any, + isQuery, + expression); + } + } + break; + } + case "All": + case "AllAsQuery": + { + bool isQuery = (methodName == "AllAsQuery"); + if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitQuantifier(nodeInfo.children[0].child, + lambda, + AggregateOpType.All, + isQuery, + expression); + } + } + break; + } + case "Count": + case "CountAsQuery": + 
{ + bool isQuery = (methodName == "CountAsQuery"); + if (expression.Arguments.Count == 1) + { + resNode = this.VisitBasicAggregate(nodeInfo.children[0].child, + null, + AggregateOpType.Count, + isQuery, + expression); + } + else if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitBasicAggregate(nodeInfo.children[0].child, + lambda, + AggregateOpType.Count, + isQuery, + expression); + } + } + break; + } + case "LongCount": + case "LongCountAsQuery": + { + bool isQuery = (methodName == "LongCountAsQuery"); + if (expression.Arguments.Count == 1) + { + resNode = this.VisitBasicAggregate(nodeInfo.children[0].child, + null, + AggregateOpType.LongCount, + isQuery, + expression); + } + else if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitBasicAggregate(nodeInfo.children[0].child, + lambda, + AggregateOpType.LongCount, + isQuery, + expression); + } + } + break; + } + case "Sum": + case "SumAsQuery": + { + bool isQuery = (methodName == "SumAsQuery"); + if (expression.Arguments.Count == 1) + { + resNode = this.VisitBasicAggregate(nodeInfo.children[0].child, + null, + AggregateOpType.Sum, + isQuery, + expression); + } + else if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitBasicAggregate(nodeInfo.children[0].child, + lambda, + AggregateOpType.Sum, + isQuery, + expression); + } + } + break; + } + case "Min": + case "MinAsQuery": + { + bool isQuery = (methodName == "MinAsQuery"); + if (expression.Arguments.Count == 1) + { + resNode = this.VisitBasicAggregate(nodeInfo.children[0].child, + null, + AggregateOpType.Min, + isQuery, 
+ expression); + } + else if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitBasicAggregate(nodeInfo.children[0].child, + lambda, + AggregateOpType.Min, + isQuery, + expression); + } + } + break; + } + case "Max": + case "MaxAsQuery": + { + bool isQuery = (methodName == "MaxAsQuery"); + if (expression.Arguments.Count == 1) + { + resNode = this.VisitBasicAggregate(nodeInfo.children[0].child, + null, + AggregateOpType.Max, + isQuery, + expression); + } + else if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitBasicAggregate(nodeInfo.children[0].child, + lambda, + AggregateOpType.Max, + isQuery, + expression); + } + } + break; + } + case "Average": + case "AverageAsQuery": + { + bool isQuery = (methodName == "AverageAsQuery"); + if (expression.Arguments.Count == 1) + { + resNode = this.VisitBasicAggregate(nodeInfo.children[0].child, + null, + AggregateOpType.Average, + isQuery, + expression); + } + else if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitBasicAggregate(nodeInfo.children[0].child, + lambda, + AggregateOpType.Average, + isQuery, + expression); + } + } + break; + } + case "GroupBy": + { + // groupby can take 2, 3, 4, or 5 arguments. 
+ if (expression.Arguments.Count == 2) + { + //Supplied arguments are as follows:(source, key selector) + LambdaExpression keySelExpr = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (keySelExpr != null && keySelExpr.Parameters.Count == 1) + { + resNode = this.VisitGroupBy(nodeInfo.children[0].child, + keySelExpr, + null, + null, + null, + expression); + } + } + else if (expression.Arguments.Count == 3) + { + //Supplied arguments are as follows:(source, key selector, element selector/result selector/comparer) + LambdaExpression keySelExpr = HpcLinqExpression.GetLambda(expression.Arguments[1]); + LambdaExpression lambda2 = HpcLinqExpression.GetLambda(expression.Arguments[2]); + if (keySelExpr != null && lambda2 == null) + { + resNode = this.VisitGroupBy(nodeInfo.children[0].child, + keySelExpr, + null, + null, + expression.Arguments[2], + expression); + } + else if (keySelExpr != null && keySelExpr.Parameters.Count == 1 && lambda2 != null) + { + LambdaExpression elemSelExpr = null; + LambdaExpression resSelExpr = null; + if (lambda2.Parameters.Count == 1) + { + elemSelExpr = lambda2; + } + else if (lambda2.Parameters.Count == 2) + { + resSelExpr = lambda2; + } + resNode = this.VisitGroupBy(nodeInfo.children[0].child, + keySelExpr, + elemSelExpr, + resSelExpr, + null, + expression); + } + } + else if (expression.Arguments.Count == 4) + { + //Argument-0 is source and Argument-1 is key selector expression + LambdaExpression keySelExpr = HpcLinqExpression.GetLambda(expression.Arguments[1]); + LambdaExpression lambda2 = HpcLinqExpression.GetLambda(expression.Arguments[2]); + LambdaExpression lambda3 = HpcLinqExpression.GetLambda(expression.Arguments[3]); + if (keySelExpr != null && keySelExpr.Parameters.Count == 1 && lambda3 == null) + { + //Argument-2 can be either result selector, element selector and argument-3 is comparer + LambdaExpression elemSelExpr = null; + LambdaExpression resSelExpr = null; + if (lambda2.Parameters.Count == 1) + { + elemSelExpr = 
lambda2; + } + else if (lambda2.Parameters.Count == 2) + { + resSelExpr = lambda2; + } + resNode = this.VisitGroupBy(nodeInfo.children[0].child, + keySelExpr, + elemSelExpr, + resSelExpr, + expression.Arguments[3], + expression); + } + else if (keySelExpr != null && keySelExpr.Parameters.Count == 1 && + lambda2 != null && lambda2.Parameters.Count == 1 && + lambda3 != null && lambda3.Parameters.Count == 2) + { + //Argument-2 is element selector and argument-3 is result selector + resNode = this.VisitGroupBy(nodeInfo.children[0].child, + keySelExpr, + lambda2, + lambda3, + null, + expression); + } + } + else if (expression.Arguments.Count == 5) + { + //Supplied arguments are as follows:(source, key selector, element selector, result selector, comparer) + LambdaExpression keySelExpr = HpcLinqExpression.GetLambda(expression.Arguments[1]); + LambdaExpression elemSelExpr = HpcLinqExpression.GetLambda(expression.Arguments[2]); + LambdaExpression resSelExpr = HpcLinqExpression.GetLambda(expression.Arguments[3]); + if (keySelExpr != null && keySelExpr.Parameters.Count == 1 && + elemSelExpr != null && elemSelExpr.Parameters.Count == 1 && + resSelExpr != null && resSelExpr.Parameters.Count == 2) + { + resNode = this.VisitGroupBy(nodeInfo.children[0].child, + keySelExpr, + elemSelExpr, + resSelExpr, + expression.Arguments[4], + expression); + } + } + break; + } + case "OrderBy": + case "OrderByDescending": + { + bool isDescending = (methodName == "OrderByDescending"); + if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitOrderBy(nodeInfo.children[0].child, + lambda, + null, + isDescending, + expression); + } + } + else if (expression.Arguments.Count == 3) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = 
this.VisitOrderBy(nodeInfo.children[0].child, + lambda, + expression.Arguments[2], + isDescending, + expression); + } + } + break; + } + case "ThenBy": + case "ThenByDescending": + { + bool isDescending = (methodName == "ThenByDescending"); + if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitThenBy(nodeInfo.children[0].child, + lambda, + isDescending, + expression); + } + } + break; + } + case "ElementAt": + case "ElementAtOrDefault": + { + if (expression.Arguments.Count == 2) + { + resNode = this.VisitElementAt(methodName, + nodeInfo.children[0].child, + expression.Arguments[1], + expression); + } + break; + } + case "Take": + { + if (expression.Arguments.Count == 2) + { + resNode = this.VisitPartitionOp(methodName, + nodeInfo.children[0].child, + QueryNodeType.Take, + expression.Arguments[1], + expression); + } + break; + } + case "TakeWhile": + case "LongTakeWhile": + { + if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && + (lambda.Parameters.Count == 1 || lambda.Parameters.Count == 2)) + { + resNode = this.VisitPartitionOp("TakeWhile", + nodeInfo.children[0].child, + QueryNodeType.TakeWhile, + lambda, + expression); + } + } + break; + } + case "Skip": + { + if (expression.Arguments.Count == 2) + { + resNode = this.VisitPartitionOp(methodName, + nodeInfo.children[0].child, + QueryNodeType.Skip, + expression.Arguments[1], + expression); + } + break; + } + case "SkipWhile": + case "LongSkipWhile": + { + if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && + (lambda.Parameters.Count == 1 || lambda.Parameters.Count == 2)) + { + resNode = this.VisitPartitionOp(methodName, + nodeInfo.children[0].child, + QueryNodeType.SkipWhile, + 
lambda, + expression); + } + } + break; + } + case "Contains": + case "ContainsAsQuery": + { + bool isQuery = (methodName == "ContainsAsQuery"); + if (expression.Arguments.Count == 2) + { + resNode = this.VisitContains(nodeInfo.children[0].child, + expression.Arguments[1], + null, + isQuery, + expression); + } + else if (expression.Arguments.Count == 3) + { + resNode = this.VisitContains(nodeInfo.children[0].child, + expression.Arguments[1], + expression.Arguments[2], + isQuery, + expression); + } + break; + } + case "Reverse": + { + resNode = this.VisitReverse(nodeInfo.children[0].child, expression); + break; + } + case "SequenceEqual": + case "SequenceEqualAsQuery": + { + if (expression.Arguments.Count == 2) + { + resNode = this.VisitSequenceEqual(nodeInfo.children[0].child, + nodeInfo.children[1].child, + null, + expression); + } + else if (expression.Arguments.Count == 3) + { + resNode = this.VisitSequenceEqual(nodeInfo.children[0].child, + nodeInfo.children[1].child, + expression.Arguments[2], + expression); + } + break; + } + case "Zip": + { + if (expression.Arguments.Count == 3) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[2]); + if (lambda != null && lambda.Parameters.Count == 2) + { + resNode = this.VisitZip(nodeInfo.children[0].child, + nodeInfo.children[1].child, + lambda, + expression); + } + } + break; + } + case "HashPartition": + { + if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitHashPartition(nodeInfo.children[0].child, + lambda, + null, + null, + expression); + } + } + else if (expression.Arguments.Count == 3) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + if (expression.Arguments[2].Type == typeof(int)) + { + resNode = 
this.VisitHashPartition(nodeInfo.children[0].child, + lambda, + null, + expression.Arguments[2], + expression); + } + else + { + resNode = this.VisitHashPartition(nodeInfo.children[0].child, + lambda, + expression.Arguments[2], + null, + expression); + } + } + } + else if (expression.Arguments.Count == 4) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitHashPartition(nodeInfo.children[0].child, + lambda, + expression.Arguments[2], + expression.Arguments[3], + expression); + } + } + break; + } + case "RangePartition": + { + //overloads: + // + // 2-param: + // (source, keySelector) + // + // 3-param: + // (source, keySelector, pcount) + // (source, keySelector, isDescending) + // (source, keySelector, rangeSeparators) + // + // 4-param: + // (source, keySelector, isDescending, pcount) + // (source, keySelector, keyComparer, isDescending) + // (source, keySelector, rangeSeparators, keyComparer) + // + // 5-param: + // (source, keySelector, keyComparer, isDescending, pcount) + // (source, keySelector, rangeSeparators, keyComparer, isDescending) + if (expression.Arguments.Count == 2) + { + // Case: (source, keySelector) + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + resNode = this.VisitRangePartition(nodeInfo.children[0].child, + lambda, + null, + null, + null, + null, + expression); + } + + else if (expression.Arguments.Count == 3) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + if (expression.Arguments[2].Type == typeof(int)) + { + // Case: (source, keySelector, pcount) + resNode = this.VisitRangePartition(nodeInfo.children[0].child, + lambda, + null, + null, + null, + expression.Arguments[2], + expression); + } + if (expression.Arguments[2].Type == typeof(bool)) + { + // Case: (source, keySelector, isDescending) + 
resNode = this.VisitRangePartition(nodeInfo.children[0].child, + lambda, + null, + null, + expression.Arguments[2], + null, + expression); + } + else if (expression.Arguments[2].Type.IsArray) + { + // Case: RangePartition(keySelector, TKey[] keys) + resNode = this.VisitRangePartition(nodeInfo.children[0].child, + lambda, + expression.Arguments[2], + null, + null, + null, + expression); + } + } + } + else if (expression.Arguments.Count == 4) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + if (expression.Arguments[2].Type == typeof(bool)) + { + //case: (source, keySelector, isDescending, pcount) + resNode = this.VisitRangePartition(nodeInfo.children[0].child, + lambda, + null, + null, + expression.Arguments[2], + expression.Arguments[3], + expression); + } + else if (expression.Arguments[3].Type == typeof(bool)) + { + //case: (source, keySelector, keyComparer, isDescending) + resNode = this.VisitRangePartition(nodeInfo.children[0].child, + lambda, + null, + expression.Arguments[2], + expression.Arguments[3], + null, + expression); + } + else if (expression.Arguments[2].Type.IsArray) + { + //case: (source, keySelector, rangeSeparators, keyComparer) + + resNode = this.VisitRangePartition(nodeInfo.children[0].child, + lambda, + expression.Arguments[2], + expression.Arguments[3], + null, + null, + expression); + } + } + } + else if (expression.Arguments.Count == 5) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + if (expression.Arguments[3].Type == typeof(bool)) + { + // case: (source, keySelector, keyComparer, isDescending, pcount) + resNode = this.VisitRangePartition(nodeInfo.children[0].child, + lambda, + null, + expression.Arguments[2], + expression.Arguments[3], + expression.Arguments[4], + expression); + } + else if (expression.Arguments[4].Type == typeof(bool)) + { + // 
case: (source, keySelector, rangeSeparators, keyComparer, isDescending) + resNode = this.VisitRangePartition(nodeInfo.children[0].child, + lambda, + expression.Arguments[2], + expression.Arguments[3], + expression.Arguments[4], + null, + expression); + } + + } + } + break; + } + case "Apply": + { + if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && + (lambda.Parameters.Count == 1 || lambda.Parameters.Count == 2)) + { + resNode = this.VisitApply(nodeInfo.children[0].child, + null, + lambda, + false, + false, + expression); + } + } + else if (expression.Arguments.Count == 3) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[2]); + if (lambda != null) + { + if (lambda.Parameters.Count == 2) + { + resNode = this.VisitApply(nodeInfo.children[0].child, + nodeInfo.children[1].child, + lambda, + false, + false, + expression); + } + else + { + // Apply with multiple sources of the same type + resNode = this.VisitMultiApply(nodeInfo, lambda, false, false, expression); + } + } + } + break; + } + case "ApplyPerPartition": + { + if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && + (lambda.Parameters.Count == 1 || lambda.Parameters.Count == 2)) + { + resNode = this.VisitApply(nodeInfo.children[0].child, + null, + lambda, + true, + false, + expression); + } + } + else if (expression.Arguments.Count == 4) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[2]); + ExpressionSimplifier evaluator = new ExpressionSimplifier(); + bool isFirstOnly = evaluator.Eval(expression.Arguments[3]); + if (lambda != null) + { + if (lambda.Parameters.Count == 2) + { + resNode = this.VisitApply(nodeInfo.children[0].child, + nodeInfo.children[1].child, + lambda, + true, + isFirstOnly, + expression); + } + else + { + // Apply with multiple sources of the 
same type + resNode = this.VisitMultiApply(nodeInfo, lambda, true, isFirstOnly, expression); + } + } + } + break; + } + case "Fork": + { + if (expression.Arguments.Count == 2) // ForkSelect and ForkApply + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitFork(nodeInfo.children[0].child, + lambda, + null, + expression); + } + } + else if (expression.Arguments.Count == 3) // ForkByKey + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitFork(nodeInfo.children[0].child, + lambda, + expression.Arguments[2], + expression); + } + } + break; + } + case "ForkChoose": + { + if (expression.Arguments.Count == 2) + { + resNode = this.VisitForkChoose(nodeInfo.children[0].child, + expression.Arguments[1], + expression); + } + break; + } + case "AssumeHashPartition": + { + if (expression.Arguments.Count == 2) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitAssumeHashPartition(nodeInfo.children[0].child, + lambda, + null, + null, + expression); + } + } + else if (expression.Arguments.Count == 3) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitAssumeHashPartition(nodeInfo.children[0].child, + lambda, + null, + expression.Arguments[2], + expression); + } + } + break; + } + case "AssumeRangePartition": + { + if (expression.Arguments.Count == 3) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + if (expression.Arguments[2].Type.IsArray) + { + resNode = this.VisitAssumeRangePartition(nodeInfo.children[0].child, + lambda, + 
expression.Arguments[2], + null, + null, + expression); + } + else + { + resNode = this.VisitAssumeRangePartition(nodeInfo.children[0].child, + lambda, + null, + null, + expression.Arguments[2], + expression); + } + } + } + else if (expression.Arguments.Count == 4) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + if (expression.Arguments[2].Type.IsArray) + { + resNode = this.VisitAssumeRangePartition(nodeInfo.children[0].child, + lambda, + expression.Arguments[2], + expression.Arguments[3], + null, + expression); + } + else + { + resNode = this.VisitAssumeRangePartition(nodeInfo.children[0].child, + lambda, + null, + expression.Arguments[2], + expression.Arguments[3], + expression); + } + } + } + break; + } + case "AssumeOrderBy": + { + if (expression.Arguments.Count == 3) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitAssumeOrderBy(nodeInfo.children[0].child, + lambda, + null, + expression.Arguments[2], + expression); + } + } + else if (expression.Arguments.Count == 4) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitAssumeOrderBy(nodeInfo.children[0].child, + lambda, + expression.Arguments[2], + expression.Arguments[3], + expression); + } + } + break; + } + case "SlidingWindow": + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && lambda.Parameters.Count == 1) + { + resNode = this.VisitSlidingWindow(nodeInfo.children[0].child, + lambda, + expression.Arguments[2], + expression); + } + break; + } + case "SelectWithPartitionIndex": + case "ApplyWithPartitionIndex": + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(expression.Arguments[1]); + if (lambda != null && 
lambda.Parameters.Count == 1) + { + resNode = this.VisitApplyWithPartitionIndex(nodeInfo.children[0].child, + lambda, + expression); + } + break; + } + case ReflectedNames.DryadLinqIQueryable_ToDscWorker: // was case "ToPartitionedTableLazy": + { + //SHOULD NOT VISIT.. The DryadQueryGen ctors should be interrogating ToDsc nodes directly. + //Later if we do allow ToDsc in the middle of query chain, then we need to either + // 1. update the source node with an outputeTableUri + // OR 2. create an actual node and handle it later on (tee, etc) + + throw DryadLinqException.Create(HpcLinqErrorCode.ToDscUsedIncorrectly, + String.Format(SR.ToDscUsedIncorrectly), + expression); + } + case ReflectedNames.DryadLinqIQueryable_ToHdfsWorker: // was case "ToPartitionedTableLazy": + { + //SHOULD NOT VISIT.. The DryadQueryGen ctors should be interrogating ToDsc nodes directly. + //Later if we do allow ToDsc in the middle of query chain, then we need to either + // 1. update the source node with an outputeTableUri + // OR 2. create an actual node and handle it later on (tee, etc) + + throw DryadLinqException.Create(HpcLinqErrorCode.ToHdfsUsedIncorrectly, String.Format(SR.ToHdfsUsedIncorrectly), expression); + } + } + #endregion + + if (resNode == null) + { + throw DryadLinqException.Create(HpcLinqErrorCode.OperatorNotSupported, + String.Format(SR.OperatorNotSupported, methodName), + expression); + } + resNode.IsForked = resNode.IsForked || nodeInfo.IsForked; + nodeInfo.queryNode = resNode; + return resNode; + } + } +} diff --git a/LinqToDryad/DryadQueryNode.cs b/LinqToDryad/DryadQueryNode.cs new file mode 100644 index 0000000..85d12bf --- /dev/null +++ b/LinqToDryad/DryadQueryNode.cs @@ -0,0 +1,4567 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//
+// © Microsoft Corporation. All rights reserved.
+//
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Text;
+using System.IO;
+using System.Reflection;
+using System.Linq;
+using System.Linq.Expressions;
+using System.CodeDom;
+using System.Diagnostics;
+using System.Xml;
+using System.Data.Linq.Mapping;
+using System.Data.Linq;
+using Microsoft.Research.DryadLinq.Internal;
+
+
+namespace Microsoft.Research.DryadLinq
+{
+    /// <summary>
+    /// Kind of a DryadQueryNode; stored per node at construction and exposed
+    /// through DryadQueryNode.NodeType.
+    /// </summary>
+    internal enum QueryNodeType
+    {
+        InputTable,
+        OutputTable,
+        Aggregate,
+        Select,
+        SelectMany,
+        Where,
+        Distinct,
+        BasicAggregate,
+        GroupBy,
+        OrderBy,
+        Skip,
+        SkipWhile,
+        Take,
+        TakeWhile,
+        Contains,
+        Join,
+        GroupJoin,
+        Union,
+        Intersect,
+        Except,
+        Concat,
+        Zip,
+        Super,
+        RangePartition,
+        HashPartition,
+        Merge,
+        Apply,
+        Fork,
+        Tee,
+        Dynamic,
+        Dummy
+    }
+
+    /// <summary>
+    /// Aggregate-style operators (counting, numeric aggregates, quantifiers and
+    /// First/Single/Last element selection with their OrDefault variants).
+    /// </summary>
+    internal enum AggregateOpType
+    {
+        Count,
+        LongCount,
+        Sum,
+        Min,
+        Max,
+        Average,
+        Any,
+        All,
+        First,
+        FirstOrDefault,
+        Single,
+        SingleOrDefault,
+        Last,
+        LastOrDefault
+    }
+
+    /// <summary>
+    /// Channel type recorded on a node (DryadQueryNode.ChannelType); the
+    /// DryadQueryNode constructor defaults it to DiskFile.
+    /// </summary>
+    internal enum ChannelType
+    {
+        MemoryFIFO,
+        TCPPipe,
+        DiskFile
+    }
+
+    /// <summary>
+    /// Connection type recorded on a node (DryadQueryNode.ConOpType); the
+    /// DryadQueryNode constructor defaults it to Pointwise.
+    /// </summary>
+    internal enum ConnectionOpType
+    {
+        Pointwise,
+        CrossProduct
+    }
+
+    // NOTE(review): no use of AffinityConstraintType is visible in this part of the
+    // file; its exact semantics should be confirmed where it is consumed.
+    internal enum AffinityConstraintType
+    {
+        UseDefault,
+        HardConstraint,
+        OptimizationConstraint,
+        Preference,
+        DontCare
+    }
+
+    /// <summary>
+    /// Abstract base class of query nodes. Each node records its QueryNodeType,
+    /// the query generator and query expression it was created with, its children
+    /// (data sources) and its parents (data consumers).
+    /// </summary>
+    internal abstract class DryadQueryNode
+    {
+        // NOTE(review): "List" and "List>" below (and "new List(1)" in the constructor)
+        // look like generic type arguments were lost when this listing was extracted;
+        // confirm against the original DryadQueryNode.cs.
+        private QueryNodeType m_nodeType;
+        internal HpcLinqQueryGen m_queryGen;
+        private Expression m_queryExpression;
+        private List m_parents;
+        private DryadQueryNode[] m_children;
+        private DryadSuperNode m_superNode;
+        private bool m_isForked;
+        protected internal ChannelType m_channelType;
+        protected internal ConnectionOpType m_conOpType;
+        protected string m_opName;
+        internal int m_uniqueId;
+        internal string m_vertexEntryMethod;
+        protected internal int m_partitionCount;
+        internal DataSetInfo m_outputDataSetInfo;
+        protected internal DynamicManager m_dynamicManager;
+        protected internal List> m_referencedQueries;
+
+        /// <summary>
+        /// Initializes the node and registers it as a parent of every child.
+        /// Defaults set here: channel type DiskFile, connection type Pointwise,
+        /// partition count -1, unique id HpcLinqQueryGen.StartPhaseId, not forked;
+        /// the super node, op name, vertex entry method, output data set info and
+        /// dynamic manager all start out null.
+        /// </summary>
+        /// <param name="nodeType">Kind of this node.</param>
+        /// <param name="queryGen">Query generator this node belongs to.</param>
+        /// <param name="queryExpr">(sub)Query expression corresponding to this node.</param>
+        /// <param name="children">Data sources of this node; each child gets this node added to its Parents.</param>
+        internal DryadQueryNode(QueryNodeType nodeType,
+                                HpcLinqQueryGen queryGen,
+                                Expression queryExpr,
+                                params DryadQueryNode[] children)
+        {
+            this.m_nodeType = nodeType;
+            this.m_queryGen = queryGen;
+            this.m_queryExpression = queryExpr;
+            this.m_parents = new List(1);
+            this.m_children = children;
+            // Wire the reverse edges: every child learns about its new consumer.
+            foreach (DryadQueryNode child in children)
+            {
+                child.Parents.Add(this);
+            }
+            this.m_superNode = null;
+            this.m_isForked = false;
+            this.m_uniqueId = HpcLinqQueryGen.StartPhaseId;
+            this.m_channelType = ChannelType.DiskFile;
+            this.m_conOpType = ConnectionOpType.Pointwise;
+            this.m_opName = null;
+            this.m_vertexEntryMethod = null;
+            this.m_outputDataSetInfo = null;
+            this.m_partitionCount = -1;
+            this.m_dynamicManager = null;
+        }
+
+        /// <summary>
+        /// Kind of this node, as passed to the constructor.
+        /// </summary>
+        internal QueryNodeType NodeType
+        {
+            get { return this.m_nodeType; }
+        }
+
+        /// <summary>
+        /// The query generator to use for this query node.
+        /// </summary>
+        internal HpcLinqQueryGen QueryGen
+        {
+            get { return this.m_queryGen; }
+        }
+
+        /// <summary>
+        /// (sub)Query expression corresponding to this node.
+        /// </summary>
+        internal Expression QueryExpression
+        {
+            get { return this.m_queryExpression; }
+        }
+
+        /// <summary>
+        /// (sub)Query expression corresponding to this node, after it has been rewritten.
+        /// The base implementation always throws NotSupportedException; being virtual,
+        /// it can be overridden by derived node types.
+        /// </summary>
+        internal virtual Expression RebuildQueryExpression(Expression inputExpr)
+        {
+            throw new NotSupportedException(SR.CannotRebuildOptimizedQueryExpression);
+        }
+
+        ///
+        /// Children of this node: data sources.
+        ///
+        internal DryadQueryNode[] Children
+        {
+            get { return this.m_children; }
+            set { this.m_children = value; }
+        }
+
+        /// <summary>
+        /// Parents of this node: data consumers.
+        /// </summary>
+        internal List Parents
+        {
+            get { return this.m_parents; }
+        }
+
+        /// <summary>
+        /// A SuperNode contains many other elementary nodes inside.
+        /// </summary>
+        internal DryadSuperNode SuperNode
+        {
+            get { return this.m_superNode; }
+            set { this.m_superNode = value; }
+        }
+
+        /// <summary>
+        /// Operation performed by node.
+        /// </summary>
+        internal string OpName
+        {
+            get { return this.m_opName; }
+        }
+
+        /// <summary>
+        /// Connection type of this node (m_conOpType).
+        /// </summary>
+        internal ConnectionOpType ConOpType
+        {
+            get { return this.m_conOpType; }
+            set { this.m_conOpType = value; }
+        }
+
+        /// <summary>
+        /// Channel type of this node (m_channelType).
+        /// </summary>
+        internal ChannelType ChannelType
+        {
+            get { return this.m_channelType; }
+            set { this.m_channelType = value; }
+        }
+
+        /// <summary>
+        /// True only when StaticConfig.UseLargeBuffer is set, the node is not
+        /// stateful, and its channel type is DiskFile; false otherwise.
+        /// </summary>
+        internal bool UseLargeWriteBuffer
+        {
+            get {
+                if (!StaticConfig.UseLargeBuffer ||
+                    this.IsStateful ||
+                    this.ChannelType != ChannelType.DiskFile)
+                {
+                    return false;
+                }
+                return true;
+            }
+        }
+
+        /// <summary>
+        /// One type for each input connection.
+        /// The default implementation takes the first output type (OutputTypes[0])
+        /// of each child, in child order.
+        /// </summary>
+        internal virtual Type[] InputTypes
+        {
+            get {
+                Type[] types = new Type[this.Children.Length];
+                for (int i = 0; i < types.Length; i++)
+                {
+                    types[i] = this.Children[i].OutputTypes[0];
+                }
+                return types;
+            }
+        }
+
+        /// <summary>
+        /// Summary of the output data (static estimate).
+        /// </summary>
+        internal DataSetInfo OutputDataSetInfo
+        {
+            get { return this.m_outputDataSetInfo; }
+            set { this.m_outputDataSetInfo = value; }
+        }
+
+        /// <summary>
+        /// The partitionInfo member of the output data set info.
+        /// Dereferences m_outputDataSetInfo without a null check.
+        /// </summary>
+        internal PartitionInfo OutputPartition
+        {
+            get { return this.m_outputDataSetInfo.partitionInfo; }
+        }
+
+        /// <summary>
+        /// Value of m_partitionCount.
+        /// </summary>
+        internal int PartitionCount
+        {
+            get { return this.m_partitionCount; }
+        }
+
+        /// <summary>
+        /// Forwards to orderByInfo.IsOrderedBy(expr, comparer) of the output data
+        /// set info. Dereferences m_outputDataSetInfo without a null check.
+        /// </summary>
+        internal bool IsOrderedBy(LambdaExpression expr, object comparer)
+        {
+            return this.m_outputDataSetInfo.orderByInfo.IsOrderedBy(expr, comparer);
+        }
+
+        ///
+        /// Dynamic manager associated with first child.
+        ///
+        internal DynamicManager DynamicManager
+        {
+            get { return this.m_dynamicManager; }
+            set { this.m_dynamicManager = value; }
+        }
+
+        /// <summary>
+        /// The node that produces this node's output: the RootNode when this is a
+        /// DryadSuperNode, otherwise this node itself.
+        /// </summary>
+        internal DryadQueryNode OutputNode
+        {
+            get {
+                DryadQueryNode node = this;
+                if (node is DryadSuperNode)
+                {
+                    node = ((DryadSuperNode)node).RootNode;
+                }
+                return node;
+            }
+        }
+
+        /// <summary>
+        /// True when the OutputNode is a RangePartition or HashPartition node.
+        /// </summary>
+        internal bool IsDistributeNode
+        {
+            get {
+                DryadQueryNode curNode = this.OutputNode;
+                return (curNode.NodeType == QueryNodeType.RangePartition ||
+                        curNode.NodeType == QueryNodeType.HashPartition);
+            }
+        }
+
+        /// <summary>
+        /// True when the forked flag has been set explicitly or the node has more
+        /// than one parent. The setter only changes the explicit flag.
+        /// </summary>
+        internal bool IsForked
+        {
+            get { return this.m_isForked || (this.Parents.Count > 1); }
+            set { this.m_isForked = value; }
+        }
+
+        /// <summary>
+        /// Number of children. For a DryadDynamicNode the children of GetRealNode(0)
+        /// are counted instead.
+        /// </summary>
+        internal virtual Int32 InputArity
+        {
+            get {
+                DryadQueryNode node = this;
+                if (node is DryadDynamicNode)
+                {
+                    node = ((DryadDynamicNode)node).GetRealNode(0);
+                }
+                return node.Children.Length;
+            }
+        }
+
+        /// <summary>
+        /// Computed on OutputNode: for a DryadForkNode the number of parents;
+        /// for any other node 1 if it has at least one parent, else 0.
+        /// </summary>
+        internal Int32 OutputArity
+        {
+            get {
+                DryadQueryNode node = this.OutputNode;
+
+                if (node is DryadForkNode)
+                {
+                    return node.Parents.Count;
+                }
+                return (node.Parents.Count == 0) ? 0 : 1;
+            }
+        }
+
+        // Base implementation returns false.
+        internal virtual bool IsHomomorphic
+        {
+            get { return false; }
+        }
+
+        // Base implementation returns false.
+        internal virtual bool CanAttachPipeline
+        {
+            get { return false; }
+        }
+
+        // Set-only. The base implementation always throws a DryadLinqException
+        // with HpcLinqErrorCode.Internal.
+        internal virtual Pipeline AttachedPipeline
+        {
+            set {
+                throw new DryadLinqException(HpcLinqErrorCode.Internal, SR.CannotAttach);
+            }
+        }
+
+        // Base implementation: true when this node's DynamicManager equals
+        // DynamicManager.PartialAggregator.
+        internal virtual bool ContainsMerge
+        {
+            get {
+                return (this.DynamicManager == DynamicManager.PartialAggregator);
+            }
+        }
+
+        // Base implementation returns false.
+        internal virtual bool KeepInputPortOrder()
+        {
+            return false;
+        }
+
+        /// <summary>
+        /// If true the node should not be pipelined with other stateful nodes.
+        /// The base implementation returns false.
+        /// </summary>
+        internal virtual bool IsStateful
+        {
+            get { return false; }
+        }
+
+        // Output types of this node, supplied by each concrete node class.
+        // InputTypes reads element [0] and CreateCodeAndMappingsForVertexTypes reads
+        // elements [0 .. OutputArity-1].
+        internal abstract Type[] OutputTypes { get; }
+
+        // not virtual yet as only SuperNode requires special handling and we don't need generality (yet)
+        // Handling for intermediate types is virtual as we anticipate more DryadQueryNodes will require this over time.
+        /// <summary>
+        /// Registers code generation for this node's types.
+        /// </summary>
+        /// <param name="intermediateTypesOnly">
+        /// When true the output types are skipped and only
+        /// CreateCodeAndMappingsForIntermediateTypes() runs.
+        /// </param>
+        internal void CreateCodeAndMappingsForVertexTypes(bool intermediateTypesOnly)
+        {
+            // process the output types for this node
+            if (!intermediateTypesOnly)
+            {
+                for (int i = 0; i < this.OutputArity; i++)
+                {
+                    this.QueryGen.CodeGen.AddDryadCodeForType(OutputTypes[i]);
+                }
+            }
+
+            // process the intermediate types for this node
+            CreateCodeAndMappingsForIntermediateTypes();
+        }
+
+        // Default behavior: nothing.
+        // However, some nodes may need to do some codegen/mapping for types that are not their outputs.
+        internal virtual void CreateCodeAndMappingsForIntermediateTypes()
+        {
+        }
+
+        /// <summary>
+        /// Adds this node's statements to the vertex method.
+        /// NOTE(review): the returned string appears to be the name of the variable
+        /// holding the node's result (overrides return sourceDecl.Name) - confirm.
+        /// </summary>
+        internal abstract string AddVertexCode(CodeMemberMethod vertexMethod,
+                                               string[] readerNames,
+                                               string[] writerNames);
+
+        /// <summary>
+        /// Returns the queries referenced by this node. The list is computed on the
+        /// first call and cached in m_referencedQueries.
+        /// </summary>
+        internal List<Pair<string, DryadQueryNode>> GetReferencedQueries()
+        {
+            if (this.m_referencedQueries == null)
+            {
+                ReferencedQuerySubst subst = new ReferencedQuerySubst(this.QueryGen.ReferencedQueryMap);
+                // The virtual overload lets each node type run the substitution over its own expressions.
+                this.GetReferencedQueries(subst);
+                this.m_referencedQueries = subst.GetReferencedQueries();
+            }
+            return this.m_referencedQueries;
+        }
+
+        // Default behavior: nothing. Nodes holding expressions override this.
+        internal virtual void GetReferencedQueries(ReferencedQuerySubst subst)
+        {
+        }
+
+        /// <summary>
+        /// For each referenced query, adds a "var" declaration to the vertex method,
+        /// named by the pair's key and initialized with MakeReader(factory).ToArray().
+        /// </summary>
+        internal void AddSideReaders(CodeMemberMethod vertexMethod)
+        {
+            foreach (Pair<string, DryadQueryNode> nq in this.GetReferencedQueries())
+            {
+                string factoryName = this.QueryGen.CodeGen.GetStaticFactoryName(nq.Value.OutputTypes[0]);
+                CodeExpression
+                    readerExpr = new CodeMethodInvokeExpression(
+                                         new CodeArgumentReferenceExpression(HpcLinqCodeGen.DryadEnvName),
+                                         "MakeReader",
+                                         new CodeArgumentReferenceExpression(factoryName));
+                readerExpr = new CodeMethodInvokeExpression(readerExpr, "ToArray");
+                // readerExpr = new CodeMethodInvokeExpression(readerExpr, "AsQueryable");
+                CodeStatement sideDecl = new CodeVariableDeclarationStatement("var", nq.Key, readerExpr);
+                vertexMethod.Statements.Add(sideDecl);
+            }
+        }
+
+        ///
+        /// The summary to show for this node.
+        ///
+        ///
+        internal abstract void BuildString(StringBuilder builder);
+
+        // Replace all occurrences of oldNode in this.Parents by newNode.
+        // If oldNode is not in this.Parents, newNode is appended to this.Parents instead.
+        // Return true iff oldNode was in this.Parents.
+        internal bool UpdateParent(DryadQueryNode oldNode, DryadQueryNode newNode)
+        {
+            bool found = false;
+            for (int i = 0; i < this.Parents.Count; i++)
+            {
+                if (Object.ReferenceEquals(oldNode, this.Parents[i]))
+                {
+                    this.Parents[i] = newNode;
+                    found = true;
+                }
+            }
+            if (!found)
+            {
+                this.Parents.Add(newNode);
+            }
+            return found;
+        }
+
+        // Replace all occurrences of oldNode in this.Children by newNode.
+        // Return true iff oldNode is in this.Children.
+        internal bool UpdateChildren(DryadQueryNode oldNode, DryadQueryNode newNode)
+        {
+            bool found = false;
+            for (int i = 0; i < this.Children.Length; i++)
+            {
+                if (Object.ReferenceEquals(oldNode, this.Children[i]))
+                {
+                    this.Children[i] = newNode;
+                    found = true;
+                }
+            }
+            return found;
+        }
+
+        // Insert a tee node between this node and its current parents.
+        // The tee reuses this node's unique id and takes over the parent list;
+        // each former parent has its child reference to this.OutputNode replaced by the tee.
+        // Throws an internal error unless this node has exactly one output.
+        internal DryadQueryNode InsertTee(bool isForked)
+        {
+            if (this.OutputArity != 1)
+            {
+                //@@TODO: this should not be reachable. could change to Assert/InvalidOpEx
+                throw new DryadLinqException(HpcLinqErrorCode.Internal, SR.CannotAddTeeToNode);
+            }
+            // Snapshot the parents before clearing, so they can be handed to the tee.
+            List<DryadQueryNode> pnodes = new List<DryadQueryNode>(this.Parents);
+            this.Parents.Clear();
+            DryadTeeNode teeNode = new DryadTeeNode(this.OutputTypes[0], isForked, this.QueryExpression, this);
+            teeNode.m_uniqueId = this.m_uniqueId;
+            teeNode.Parents.AddRange(pnodes);
+            DryadQueryNode oldNode = this.OutputNode;
+            foreach (DryadQueryNode pn in pnodes)
+            {
+                pn.UpdateChildren(oldNode, teeNode);
+            }
+            return teeNode;
+        }
+
+        // Return true if the nodes this and child can't be pipelined together.
+        private bool CanNotBePipelinedWith(DryadQueryNode child)
+        {
+            if ((child is DryadInputNode) ||
+                (child is DryadConcatNode) ||
+                (child is DryadTeeNode) || child.IsForked ||
+                ((child is DryadApplyNode) && ((DryadApplyNode)child).IsWriteToStream))
+            {
+                return true;
+            }
+            if (child.ContainsMerge && (this.Children.Length > 1))
+            {
+                // Only the last child is exempt: the loop deliberately stops before it.
+                for (int i = 0; i < this.Children.Length - 1; i++)
+                {
+                    if (child == this.Children[i]) return true;
+                }
+            }
+            return false;
+        }
+
+        // Determine if the current node can be reduced with some of its children
+        // into a supernode.
+        internal bool CanBePipelined()
+        {
+            // Node kinds that are never the root of a pipeline.
+            if ((this is DryadOutputNode) ||
+                (this is DryadInputNode) ||
+                (this is DryadMergeNode) ||
+                (this is DryadTeeNode) ||
+                (this is DryadConcatNode))
+            {
+                return false;
+            }
+            if ((this is DryadHashPartitionNode) &&
+                ((DryadHashPartitionNode)this).IsDynamicDistributor)
+            {
+                return false;
+            }
+            if ((this is DryadRangePartitionNode) &&
+                ((DryadRangePartitionNode)this).IsDynamicDistributor)
+            {
+                return false;
+            }
+            if ((this is DryadBasicAggregateNode) &&
+                ((DryadBasicAggregateNode)this).IsMergeStage)
+            {
+                return false;
+            }
+            if ((this is DryadAggregateNode) &&
+                ((DryadAggregateNode)this).IsMergeStage)
+            {
+                return false;
+            }
+            if ((this is DryadPartitionOpNode) &&
+                ((DryadPartitionOpNode)this).IsMergeStage &&
+                (this.Children[0].IsDynamic || this.Children[0].OutputPartition.Count > 1))
+            {
+                return false;
+            }
+            if ((this is DryadApplyNode) &&
+                ((DryadApplyNode)this).IsReadFromStream)
+            {
+                return false;
+            }
+
+            // At least one child must be pipelinable with this node.
+            bool canBePipelined = false;
+            foreach (DryadQueryNode child in this.Children)
+            {
+                if (!this.CanNotBePipelinedWith(child))
+                {
+                    canBePipelined = true;
+                    break;
+                }
+            }
+            if (!canBePipelined) return false;
+
+            // Not reducible if this node has a child of distribute node:
+            foreach (DryadQueryNode child in this.Children)
+            {
+                if (child.IsDistributeNode) return false;
+            }
+
+            // Not reducible if this node and one of its children are stateful
+            // (or if two of its children are):
+            bool hasState = this.IsStateful;
+            foreach (DryadQueryNode child in this.Children)
+            {
+                if (child.IsStateful)
+                {
+                    if (hasState) return false;
+                    hasState = true;
+                }
+            }
+            return true;
+        }
+
+        // Reduce this node and its pipelinable children into a new DryadSuperNode.
+        // Returns this node unchanged when CanBePipelined() is false; otherwise returns
+        // the super node, whose children are the inputs left outside the pipeline.
+        internal DryadQueryNode PipelineReduce()
+        {
+            if (!this.CanBePipelined())
+            {
+                return this;
+            }
+
+            DryadQueryNode[] nodeChildren = this.Children;
+            DryadSuperNode resNode = new DryadSuperNode(this);
+            List<DryadQueryNode> childList = new List<DryadQueryNode>();
+            for (int i = 0; i < nodeChildren.Length; i++)
+            {
+                DryadQueryNode child = nodeChildren[i];
+                if (this.CanNotBePipelinedWith(child))
+                {
+                    // The child stays outside and becomes an input of the super node.
+                    childList.Add(child);
+                    child.UpdateParent(this, resNode);
+                }
+                else
+                {
+                    if (child is DryadSuperNode)
+                    {
+                        // Absorb an existing super node: this node now reads from its
+                        // root, and the child super node is switched over to resNode.
+                        DryadSuperNode superChild = (DryadSuperNode)child;
+                        nodeChildren[i] = superChild.RootNode;
+                        superChild.SwitchTo(resNode);
+                    }
+                    else
+                    {
+                        child.SuperNode = resNode;
+                    }
+
+                    // Fix the child's children
+                    foreach (DryadQueryNode child1 in child.Children)
+                    {
+                        childList.Add(child1);
+                        child1.UpdateParent(child, resNode);
+                    }
+                }
+            }
+
+            resNode.Children = childList.ToArray();
+            resNode.OutputDataSetInfo = resNode.RootNode.OutputDataSetInfo;
+            return resNode;
+        }
+
+        // This can only be used before super nodes are formed.
+        internal bool IsDynamic
+        {
+            get {
+                // A splitter, partial-aggregator or hash-distributor manager makes the node dynamic.
+                var managerType = this.m_dynamicManager.ManagerType;
+                if (managerType == DynamicManagerType.Splitter ||
+                    managerType == DynamicManagerType.PartialAggregator ||
+                    managerType == DynamicManagerType.HashDistributor)
+                {
+                    return true;
+                }
+
+                // A merge node takes the answer from a hash/range partition node below it.
+                if (this is DryadMergeNode)
+                {
+                    DryadQueryNode input = this.Children[0];
+                    DryadHashPartitionNode hashInput = input as DryadHashPartitionNode;
+                    if (hashInput != null)
+                    {
+                        return hashInput.IsDynamicDistributor;
+                    }
+                    DryadRangePartitionNode rangeInput = input as DryadRangePartitionNode;
+                    if (rangeInput != null)
+                    {
+                        return rangeInput.IsDynamicDistributor;
+                    }
+                }
+
+                // A concat node is dynamic as soon as one of its inputs is.
+                if (this is DryadConcatNode)
+                {
+                    DryadQueryNode[] inputs = this.Children;
+                    for (int i = 0; i < inputs.Length; i++)
+                    {
+                        if (inputs[i].IsDynamic)
+                        {
+                            return true;
+                        }
+                    }
+                }
+                return false;
+            }
+        }
+
+        // Derive this node's dynamic manager from its first child.
+        // This can only be used before super nodes are formed.
+        protected DynamicManager InferDynamicManager()
+        {
+            DynamicManager noManager = DynamicManager.None;
+            DryadQueryNode firstChild = this.Children[0];
+
+            // Single dynamic input table: partial aggregation.
+            DryadInputNode inputChild = firstChild as DryadInputNode;
+            if (inputChild != null)
+            {
+                if (inputChild.Table.IsDynamic && this.Children.Length == 1)
+                {
+                    return DynamicManager.PartialAggregator;
+                }
+                return noManager;
+            }
+
+            // Splitting merge: splitter.
+            DryadMergeNode mergeChild = firstChild as DryadMergeNode;
+            if (mergeChild != null)
+            {
+                if (mergeChild.IsSplitting)
+                {
+                    return DynamicManager.Splitter;
+                }
+                return noManager;
+            }
+
+            // Concat: splitter as soon as one concat input is dynamic.
+            if (firstChild is DryadConcatNode)
+            {
+                foreach (DryadQueryNode concatInput in firstChild.Children)
+                {
+                    DryadInputNode concatTable = concatInput as DryadInputNode;
+                    if (concatTable != null)
+                    {
+                        if (concatTable.Table.IsDynamic)
+                        {
+                            return DynamicManager.Splitter;
+                        }
+                        continue;
+                    }
+
+                    var inferredType = concatInput.InferDynamicManager().ManagerType;
+                    if (inferredType == DynamicManagerType.Splitter ||
+                        inferredType == DynamicManagerType.PartialAggregator)
+                    {
+                        return DynamicManager.Splitter;
+                    }
+                }
+                return noManager;
+            }
+
+            // Any other child: follow the manager already assigned to it.
+            var childType = firstChild.DynamicManager.ManagerType;
+            if (childType == DynamicManagerType.Splitter ||
+                childType == DynamicManagerType.PartialAggregator)
+            {
+                return DynamicManager.Splitter;
+            }
+            return noManager;
+        }
+
+
// Add this node to the XML query plan and return its unique id.
+        // Referenced queries and children are added first; their ids become the
+        // <Children> of this node's <Vertex>. 'seen' holds the ids already emitted.
+        internal virtual int AddToQueryPlan(XmlDocument queryDoc,
+                                            XmlElement queryPlan,
+                                            HashSet<int> seen)
+        {
+            if (!seen.Contains(this.m_uniqueId))
+            {
+                var refQueries = this.GetReferencedQueries();
+                int[] cids = new int[refQueries.Count + this.Children.Length];
+                for (int i = 0; i < refQueries.Count; i++)
+                {
+                    cids[i] = refQueries[i].Value.AddToQueryPlan(queryDoc, queryPlan, seen);
+                }
+                for (int i = 0; i < this.Children.Length; i++)
+                {
+                    cids[i+refQueries.Count] = this.Children[i].AddToQueryPlan(queryDoc, queryPlan, seen);
+                }
+                // Checked again because the recursive calls above share 'seen'.
+                // NOTE(review): presumably a fork child walking its parents can reach this node - confirm.
+                if (!seen.Contains(this.m_uniqueId))
+                {
+                    seen.Add(this.m_uniqueId);
+                    XmlElement vertexElem = this.CreateVertexElem(queryDoc, this.m_uniqueId, this.m_vertexEntryMethod);
+                    queryPlan.AppendChild(vertexElem);
+
+                    string dllName = this.QueryGen.CodeGen.GetDryadLinqDllName();
+                    XmlElement entryElem = DryadQueryDoc.CreateVertexEntryElem(queryDoc, dllName, this.m_vertexEntryMethod);
+                    vertexElem.AppendChild(entryElem);
+
+                    XmlElement childrenElem = this.CreateVertexChildrenElem(queryDoc, cids);
+                    vertexElem.AppendChild(childrenElem);
+
+                    // A fork also adds every consumer of its outputs.
+                    if (this.OutputNode is DryadForkNode)
+                    {
+                        foreach (DryadQueryNode pnode in this.Parents)
+                        {
+                            pnode.AddToQueryPlan(queryDoc, queryPlan, seen);
+                        }
+                    }
+                }
+            }
+            return this.m_uniqueId;
+        }
+
+        // Build the <Vertex> element for this node. Child elements are appended in this order:
+        // UniqueId, Type, Name, Explain (CDATA), Partitions, ChannelType, ConnectionOperator,
+        // then the element created by the dynamic manager.
+        protected XmlElement CreateVertexElem(XmlDocument queryDoc, int uid, string name)
+        {
+            XmlElement vertexElem = queryDoc.CreateElement("Vertex");
+
+            XmlElement elem = queryDoc.CreateElement("UniqueId");
+            elem.InnerText = Convert.ToString(uid);
+            vertexElem.AppendChild(elem);
+
+            elem = queryDoc.CreateElement("Type");
+            elem.InnerText = this.NodeType.ToString();
+            vertexElem.AppendChild(elem);
+
+            elem = queryDoc.CreateElement("Name");
+            elem.InnerText = name;
+            vertexElem.AppendChild(elem);
+
+            // The explanation text is wrapped in a CDATA section.
+            elem = queryDoc.CreateElement("Explain");
+            StringBuilder plan = new StringBuilder();
+            DryadQueryExplain.ExplainNode(plan, this);
+            XmlCDataSection data = queryDoc.CreateCDataSection(plan.ToString());
+            elem.AppendChild(data);
+            vertexElem.AppendChild(elem);
+
+            elem = queryDoc.CreateElement("Partitions");
+            elem.InnerText = Convert.ToString(this.m_partitionCount);
+            vertexElem.AppendChild(elem);
+
+            elem = queryDoc.CreateElement("ChannelType");
+            elem.InnerText = this.ChannelType.ToString();
+            vertexElem.AppendChild(elem);
+
+            elem = queryDoc.CreateElement("ConnectionOperator");
+            elem.InnerText = this.m_conOpType.ToString();
+            vertexElem.AppendChild(elem);
+
+            elem = this.m_dynamicManager.CreateElem(queryDoc);
+            vertexElem.AppendChild(elem);
+
+            return vertexElem;
+        }
+
+        // Build the <Children> element: one <Child> per id, each holding a UniqueId
+        // and an AffinityConstraint set to AffinityConstraintType.UseDefault.
+        protected XmlElement CreateVertexChildrenElem(XmlDocument queryDoc, params int[] cids)
+        {
+            XmlElement childrenElem = queryDoc.CreateElement("Children");
+            for (int i = 0; i < cids.Length; i++)
+            {
+                XmlElement childElem = queryDoc.CreateElement("Child");
+                childrenElem.AppendChild(childElem);
+
+                XmlElement elem = queryDoc.CreateElement("UniqueId");
+                elem.InnerText = Convert.ToString(cids[i]);
+                childElem.AppendChild(elem);
+
+                elem = queryDoc.CreateElement("AffinityConstraint");
+                elem.InnerText = AffinityConstraintType.UseDefault.ToString();
+                childElem.AppendChild(elem);
+            }
+            return childrenElem;
+        }
+
+        // The string form of a node is whatever BuildString writes for it.
+        public override string ToString()
+        {
+            StringBuilder builder = new StringBuilder();
+            this.BuildString(builder);
+            return builder.ToString();
+        }
+    }
+
+    internal class DryadInputNode : DryadQueryNode
+    {
+        //@@TODO[P2]: rename to m_query. also look for similar places to perform this change.
+        // The input query/table, taken from the constant expression passed to the constructor.
+        private DryadLinqQuery m_table;
+
+        /// <summary>
+        /// Creates an input node from a constant expression whose value is the input.
+        /// Throws (UnknownError) when the value is not a DryadLinqQuery, and
+        /// (InputTypeCannotBeAnonymous) when TypeSystem.IsTypeOrAnyGenericParamsAnonymous
+        /// reports true for the first generic argument of the expression type.
+        /// </summary>
+        internal DryadInputNode(HpcLinqQueryGen queryGen, ConstantExpression queryExpr)
+            : base(QueryNodeType.InputTable, queryGen, queryExpr)
+        {
+            this.m_table = queryExpr.Value as DryadLinqQuery;
+            if (this.m_table == null)
+            {
+                throw DryadLinqException.Create(HpcLinqErrorCode.UnknownError, SR.InputMustBeHpcLinqSource, queryExpr);
+            }
+            if (TypeSystem.IsTypeOrAnyGenericParamsAnonymous(queryExpr.Type.GetGenericArguments()[0]))
+            {
+                throw DryadLinqException.Create(HpcLinqErrorCode.InputTypeCannotBeAnonymous,
+                                                SR.InputTypeCannotBeAnonymous,
+                                                queryExpr);
+            }
+            // Data set info and partition count come straight from the input itself.
+            this.m_outputDataSetInfo = ((DryadLinqQuery)this.m_table).DataSetInfo;
+            this.m_partitionCount = this.m_outputDataSetInfo.partitionInfo.Count;
+            this.m_dynamicManager = DynamicManager.None;
+        }
+
+        /// <summary>
+        /// The single output type: the only generic argument of the query expression's type.
+        /// </summary>
+        internal override Type[] OutputTypes
+        {
+            get {
+                Type[] typeArgs = this.QueryExpression.Type.GetGenericArguments();
+                Debug.Assert(typeArgs != null && typeArgs.Length == 1);
+                return new Type[] { typeArgs[0] };
+            }
+        }
+
+        //@@TODO[P2]: rename to Query. Also look for other places.
+        internal DryadLinqQuery Table
+        {
+            get { return this.m_table; }
+        }
+
+        /// <summary>
+        /// Adds the input vertex, with a StorageSet (Type and SourceURI), to the query
+        /// plan once per plan and returns this node's unique id.
+        /// </summary>
+        internal override int AddToQueryPlan(XmlDocument queryDoc,
+                                             XmlElement queryPlan,
+                                             HashSet<int> seen)
+        {
+            if (!seen.Contains(this.m_uniqueId))
+            {
+                // Record the id, as the base class and DryadOutputNode do, so that a
+                // second visit does not emit a duplicate Vertex element.
+                seen.Add(this.m_uniqueId);
+                XmlElement vertexElem = this.CreateVertexElem(queryDoc, this.m_uniqueId, this.m_vertexEntryMethod);
+                queryPlan.AppendChild(vertexElem);
+
+                XmlElement storageElem = queryDoc.CreateElement("StorageSet");
+                vertexElem.AppendChild(storageElem);
+
+                XmlElement elem = queryDoc.CreateElement("Type");
+                if (DataPath.IsHdfs(this.Table.DataSourceUri))
+                {
+                    elem.InnerText = DataPath.HDFS_STORAGE_SET_TYPE; // hdfs input
+                }
+                else
+                {
+                    elem.InnerText = DataPath.DSC_STORAGE_SET_TYPE; // dsc input
+                }
+                storageElem.AppendChild(elem);
+
+                elem = queryDoc.CreateElement("SourceURI");
+                elem.InnerText = this.m_table.DataSourceUri;
+                storageElem.AppendChild(elem);
+            }
+
+            return this.m_uniqueId;
+        }
+
+        // Not implemented for input nodes: always throws NotImplementedException.
+        internal override string AddVertexCode(CodeMemberMethod vertexMethod,
+                                               string[] readerNames,
+                                               string[] writerNames)
+        {
+            throw new NotImplementedException();
+        }
+
+        // The summary of an input node is the string form of its table.
+        internal override void BuildString(StringBuilder builder)
+        {
+            builder.Append(this.m_table.ToString());
+        }
+    }
+
+    /// <summary>
+    /// Query node for an output table: a single child whose records are written to m_outputUri.
+    /// </summary>
+    internal class DryadOutputNode : DryadQueryNode
+    {
+        private string m_outputUri;
+        private Type m_outputType;
+        private DscCompressionScheme m_outputCompressionScheme;
+        private bool m_isTempOutput;
+        private HpcLinqContext m_context;
+
+        /// <summary>
+        /// Creates the output node above 'child'.
+        /// Throws (OutputTypeCannotBeAnonymous) when TypeSystem.IsTypeOrAnyGenericParamsAnonymous
+        /// reports true for the child's first output type.
+        /// </summary>
+        internal DryadOutputNode(HpcLinqContext context,
+                                 string outputUri,
+                                 bool isTempOutput,
+                                 DscCompressionScheme outputScheme,
+                                 Expression queryExpr,
+                                 DryadQueryNode child)
+            : base(QueryNodeType.OutputTable, child.QueryGen, queryExpr, child)
+        {
+            if (TypeSystem.IsTypeOrAnyGenericParamsAnonymous(child.OutputTypes[0]))
+            {
+                throw DryadLinqException.Create(HpcLinqErrorCode.OutputTypeCannotBeAnonymous,
+                                                SR.OutputTypeCannotBeAnonymous,
+                                                queryExpr);
+            }
+            this.m_context = context;
+            this.m_outputUri = outputUri;
+            this.m_outputType = child.OutputTypes[0];
+            // Data set info and partition count are inherited from the child.
+            this.m_outputDataSetInfo = child.OutputDataSetInfo;
+            this.m_partitionCount = child.OutputDataSetInfo.partitionInfo.Count;
+            this.m_dynamicManager = DynamicManager.Splitter;
+            this.m_outputCompressionScheme = outputScheme;
+            this.m_isTempOutput = isTempOutput;
+        }
+
+        // Input and output type are both the record type being written.
+        internal override Type[] InputTypes
+        {
+            get { return new Type[] { this.m_outputType }; }
+        }
+
+        internal override Type[] OutputTypes
+        {
+            get { return new Type[] { this.m_outputType }; }
+        }
+
+        // Returns the output URI given to the constructor.
+        internal string MetaDataUri
+        {
+            get { return this.m_outputUri; }
+        }
+
+        internal DscCompressionScheme OutputCompressionScheme
+        {
+            get { return this.m_outputCompressionScheme; }
+        }
+
+        /// <summary>
+        /// Adds the child first, then (once per plan) the output vertex with its StorageSet:
+        /// Type, SinkURI, IsTemporary, OutputCompressionScheme and RecordType.
+        /// Returns this node's unique id.
+        /// </summary>
+        internal override int AddToQueryPlan(XmlDocument queryDoc,
+                                             XmlElement queryPlan,
+                                             HashSet<int> seen)
+        {
+            int cid = this.Children[0].AddToQueryPlan(queryDoc, queryPlan, seen);
+            if (!seen.Contains(this.m_uniqueId))
+            {
+                seen.Add(this.m_uniqueId);
+                XmlElement vertexElem = this.CreateVertexElem(queryDoc, this.m_uniqueId, this.m_vertexEntryMethod);
+                queryPlan.AppendChild(vertexElem);
+
+                XmlElement storageElem = queryDoc.CreateElement("StorageSet");
+                vertexElem.AppendChild(storageElem);
+
+                XmlElement elem = queryDoc.CreateElement("Type");
+                if (DataPath.IsHdfs(this.m_outputUri))
+                {
+                    elem.InnerText = DataPath.HDFS_STORAGE_SET_TYPE;
+                }
+                else
+                {
+                    elem.InnerText = DataPath.DSC_STORAGE_SET_TYPE;
+                }
+                storageElem.AppendChild(elem);
+
+                elem = queryDoc.CreateElement("SinkURI");
+                elem.InnerText = this.m_outputUri;
+                storageElem.AppendChild(elem);
+
+                elem = queryDoc.CreateElement("IsTemporary");
+                elem.InnerText = this.m_isTempOutput.ToString();
+                storageElem.AppendChild(elem);
+
+                // Compression scheme and record type are read back from the metadata
+                // built for this node, not from the fields directly.
+                DryadLinqMetaData metaData = DryadLinqMetaData.FromOutputNode(m_context, this);
+
+                elem = queryDoc.CreateElement("OutputCompressionScheme");
+                elem.InnerText = ((int)metaData.CompressionScheme).ToString();
+                storageElem.AppendChild(elem);
+
+                elem = queryDoc.CreateElement("RecordType");
+                elem.InnerText = metaData.ElemType.AssemblyQualifiedName;
+                storageElem.AppendChild(elem);
+
+                XmlElement childrenElem = this.CreateVertexChildrenElem(queryDoc, cid);
+                vertexElem.AppendChild(childrenElem);
+            }
+            return this.m_uniqueId;
+        }
+
+        // Not implemented for output nodes: always throws NotImplementedException.
+        internal override string AddVertexCode(CodeMemberMethod vertexMethod,
+                                               string[] readerNames,
+                                               string[] writerNames)
+        {
+            throw new NotImplementedException();
+        }
+
+        // Writes "[[Table: <uri>], <child summary>]".
+        internal override void BuildString(StringBuilder builder)
+        {
+            builder.Append("[[Table: " + this.m_outputUri + "], ");
+            this.Children[0].BuildString(builder);
+            builder.Append("]");
+        }
+    }
+
+    internal class DryadWhereNode : DryadQueryNode
+    {
+        private LambdaExpression m_whereExpression;
+        private Expression m_whereExpression1;
+
+        internal DryadWhereNode(LambdaExpression whereExpr,
+                                Expression queryExpr,
+                                DryadQueryNode child)
+            : base(QueryNodeType.Where, child.QueryGen, queryExpr, child)
+        {
+            this.m_whereExpression = whereExpr;
+
+            //If indexed version and the index is a long, we will use opName=LongWhere.
+            if (this.m_whereExpression.Parameters.Count() == 2 && this.m_whereExpression.Parameters[1].Type == typeof(long))
+            {
+                this.m_opName = "LongWhere";
+            }
+            else
+            {
+                this.m_opName = "Where";
+            }
+
+            // Filtering keeps the child's partition count and data set properties.
+            this.m_partitionCount = child.OutputPartition.Count;
+            this.m_outputDataSetInfo = new DataSetInfo(child.OutputDataSetInfo);
+
+            this.m_dynamicManager = this.InferDynamicManager();
+        }
+
+        // True when the predicate's delegate type has exactly two generic arguments.
+        // NOTE(review): presumably this singles out the non-indexed predicate - confirm.
+        internal override bool IsHomomorphic
+        {
+            get { return this.m_whereExpression.Type.GetGenericArguments().Length == 2; }
+        }
+
+        // Rebuilds this node as a call to System.Linq.Enumerable.Where over inputExpr.
+        // Uses m_whereExpression1, which is only assigned in GetReferencedQueries(subst).
+        internal override Expression RebuildQueryExpression(Expression inputExpr)
+        {
+            return Expression.Call(typeof(System.Linq.Enumerable),
+                                   "Where",
+                                   new Type[] { this.InputTypes[0] },
+                                   inputExpr, m_whereExpression1);
+        }
+
+        // A filter does not change the record type.
+        internal override Type[] OutputTypes
+        {
+            get { return this.Children[0].OutputTypes; }
+        }
+
+        // True when SelectiveOrderPreservation is configured or the output is ordered.
+        internal bool OrderPreserving()
+        {
+            return (this.m_queryGen.Context.Configuration.SelectiveOrderPreservation ||
+                    this.OutputDataSetInfo.orderByInfo.IsOrdered);
+        }
+
+        // Declares "source" as a call to the vertex operator named by OpName, with arguments
+        // (first reader, predicate, order-preserving flag), and returns the variable's name.
+        internal override string AddVertexCode(CodeMemberMethod vertexMethod,
+                                               string[] readerNames,
+                                               string[] writerNames)
+        {
+            CodeExpression whereExpr = new CodeMethodInvokeExpression(
+                                               HpcLinqCodeGen.DLVTypeExpr,
+                                               this.OpName,
+                                               new CodeVariableReferenceExpression(readerNames[0]),
+                                               this.QueryGen.CodeGen.MakeExpression(this.m_whereExpression1),
+                                               new CodePrimitiveExpression(this.OrderPreserving()));
+            CodeVariableDeclarationStatement
+                sourceDecl = this.QueryGen.CodeGen.MakeVarDeclStatement("var", "source", whereExpr);
+            vertexMethod.Statements.Add(sourceDecl);
+            return sourceDecl.Name;
+        }
+
+        // Stores the predicate as rewritten by the referenced-query substitution.
+        internal override void GetReferencedQueries(ReferencedQuerySubst subst)
+        {
+            this.m_whereExpression1 = subst.Visit(this.m_whereExpression);
+        }
+
+        // Writes "[<NodeType> <child summary>, <predicate as C#>]".
+        internal override void BuildString(StringBuilder builder)
+        {
+            builder.Append("[" + this.NodeType + " ");
+            this.Children[0].BuildString(builder);
+            builder.Append(", ");
+            builder.Append(HpcLinqExpression.ToCSharpString(this.m_whereExpression,
+                                                            this.QueryGen.CodeGen.AnonymousTypeToName));
+            builder.Append("]");
+        }
+
+        /// <summary>
+        /// The expression performing the filtering.
+        /// </summary>
+        internal LambdaExpression WhereExpression
+        {
+            get { return this.m_whereExpression; }
+        }
+    }
+
+    internal class DryadSelectNode : DryadQueryNode
+    {
+        // Selector and optional result selector as given; the "1" variants are the
+        // same expressions after the referenced-query substitution.
+        private LambdaExpression m_selectExpression;
+        private LambdaExpression m_resultSelectExpression;
+        private Expression m_selectExpression1;
+        private Expression m_resultSelectExpression1;
+
+        internal DryadSelectNode(QueryNodeType nodeType,
+                                 LambdaExpression selectExpr,
+                                 LambdaExpression resultSelectExpr,
+                                 Expression queryExpr,
+                                 DryadQueryNode child)
+            : base(nodeType, child.QueryGen, queryExpr, child)
+        {
+            Debug.Assert(nodeType == QueryNodeType.Select || nodeType == QueryNodeType.SelectMany);
+            this.m_selectExpression = selectExpr;
+            this.m_resultSelectExpression = resultSelectExpr;
+
+            //If indexed version and the index is a long, the operator name gets a "Long" prefix.
+            if (this.m_selectExpression.Parameters.Count() == 2 && this.m_selectExpression.Parameters[1].Type == typeof(long))
+            {
+                this.m_opName = "Long" + nodeType;
+            }
+            else
+            {
+                this.m_opName = nodeType.ToString();
+            }
+
+            this.m_partitionCount = child.OutputPartition.Count;
+            this.m_outputDataSetInfo = this.ComputeOutputDataSetInfo();
+
+            this.m_dynamicManager = this.InferDynamicManager();
+        }
+
+        internal LambdaExpression SelectExpression
+        {
+            get { return this.m_selectExpression; }
+        }
+
+        internal LambdaExpression ResultSelectExpression
+        {
+            get { return this.m_resultSelectExpression; }
+        }
+
+        // True when the selector's delegate type has exactly two generic arguments.
+        // NOTE(review): presumably this singles out the non-indexed selector - confirm.
+        internal override bool IsHomomorphic
+        {
+            get { return this.m_selectExpression.Type.GetGenericArguments().Length == 2; }
+        }
+
+        // Rebuilds this node as a call to System.Linq.Enumerable.Select or SelectMany.
+        // Uses the "1" expressions, which are only assigned in GetReferencedQueries(subst).
+        // NOTE(review): any OpName other than "Select" - including "LongSelect" - takes the
+        // SelectMany branch; confirm that this is intended.
+        internal override Expression RebuildQueryExpression(Expression inputExpr)
+        {
+            string methodName;
+            Type[] typeArgs;
+            Expression[] args;
+            if (this.OpName == "Select")
+            {
+                methodName = "Select";
+                typeArgs = new Type[] { this.InputTypes[0], this.OutputTypes[0] };
+                args = new Expression[] { inputExpr, this.m_selectExpression1 };
+            }
+            else
+            {
+                methodName = "SelectMany";
+                if (this.m_resultSelectExpression1 == null)
+                {
+                    typeArgs = new Type[] { this.InputTypes[0], this.OutputTypes[0] };
+                    args = new Expression[] { inputExpr, this.m_selectExpression1 };
+                }
+                else
+                {
+                    // With a result selector, its second generic argument is the collection type.
+                    Type collectionType = this.m_resultSelectExpression1.Type.GetGenericArguments()[1];
+                    typeArgs = new Type[] { this.InputTypes[0], collectionType, this.OutputTypes[0] };
+                    args = new Expression[] { inputExpr, this.m_selectExpression1, this.m_resultSelectExpression1 };
+                }
+            }
+            return Expression.Call(typeof(System.Linq.Enumerable), methodName, typeArgs, args);
+        }
+
+        // The single output type. Without a result selector it is the last generic argument of
+        // the selector's delegate type (for SelectMany/LongSelectMany, that type's first generic
+        // argument); with a result selector it is the type of the result selector's body.
+        internal override Type[] OutputTypes
+        {
+            get {
+                Type resType;
+                if (this.m_resultSelectExpression == null)
+                {
+                    Type[] argTypes = this.m_selectExpression.Type.GetGenericArguments();
+                    resType = (argTypes.Length == 3) ? argTypes[2] : argTypes[1];
+                    if (this.OpName == "SelectMany" || this.OpName == "LongSelectMany")
+                    {
+                        resType = resType.GetGenericArguments()[0];
+                    }
+                }
+                else
+                {
+                    resType = this.m_resultSelectExpression.Body.Type;
+                }
+                return new Type[] { resType };
+            }
+        }
+
+        // Derives the output data set info from the child's: partition and order info are
+        // rewritten through the selector; distinct info is set only when the selector
+        // carries a distinct attribute whose conditions hold.
+        private DataSetInfo ComputeOutputDataSetInfo()
+        {
+            DataSetInfo childInfo = this.Children[0].OutputDataSetInfo;
+
+            ParameterExpression param = this.m_selectExpression.Parameters[0];
+            PartitionInfo pinfo = childInfo.partitionInfo.Rewrite(this.m_selectExpression, param);
+            OrderByInfo oinfo = childInfo.orderByInfo.Rewrite(this.m_selectExpression, param);
+            DistinctInfo dinfo = DataSetInfo.NoDistinct;
+            DistinctAttribute attrib1 = AttributeSystem.GetDistinctAttrib(this.m_selectExpression);
+            if (attrib1 != null && (!attrib1.MustBeDistinct || childInfo.distinctInfo.IsDistinct()))
+            {
+                dinfo = DistinctInfo.Create(attrib1.Comparer, this.OutputTypes[0]);
+            }
+
+            return new DataSetInfo(pinfo, oinfo, dinfo);
+        }
+
+        // True when SelectiveOrderPreservation is configured or the output is ordered.
+        internal bool OrderPreserving()
+        {
+            return (m_queryGen.Context.Configuration.SelectiveOrderPreservation ||
+                    this.OutputDataSetInfo.orderByInfo.IsOrdered);
+        }
+
+        // Declares "source" as a call to the vertex operator named by OpName, with arguments
+        // (first reader, selector, [result selector,] order-preserving flag), and returns the
+        // variable's name.
+        internal override string AddVertexCode(CodeMemberMethod vertexMethod,
+                                               string[] readerNames,
+                                               string[] writerNames)
+        {
+            CodeExpression selectorExpr = this.QueryGen.CodeGen.MakeExpression(this.m_selectExpression1);
+            CodeExpression selectExpr;
+            if (this.m_resultSelectExpression1 == null)
+            {
+                selectExpr = new CodeMethodInvokeExpression(HpcLinqCodeGen.DLVTypeExpr,
+                                                            this.OpName,
+                                                            new CodeVariableReferenceExpression(readerNames[0]),
+                                                            selectorExpr,
+                                                            new CodePrimitiveExpression(this.OrderPreserving()));
+            }
+            else
+            {
+                selectExpr = new CodeMethodInvokeExpression(
+                                     HpcLinqCodeGen.DLVTypeExpr,
+                                     this.OpName,
+                                     new CodeVariableReferenceExpression(readerNames[0]),
+                                     selectorExpr,
+                                     this.QueryGen.CodeGen.MakeExpression(this.m_resultSelectExpression1),
+                                     new CodePrimitiveExpression(this.OrderPreserving()));
+            }
+            CodeVariableDeclarationStatement
+                sourceDecl = this.QueryGen.CodeGen.MakeVarDeclStatement("var", "source", selectExpr);
+            vertexMethod.Statements.Add(sourceDecl);
+            return sourceDecl.Name;
+        }
+
+        // Stores the selector (and result selector, if any) as rewritten by the
+        // referenced-query substitution.
+        internal override void GetReferencedQueries(ReferencedQuerySubst subst)
+        {
+            this.m_selectExpression1 = subst.Visit(this.m_selectExpression);
+            if (this.m_resultSelectExpression != null)
+            {
+                this.m_resultSelectExpression1 = subst.Visit(this.m_resultSelectExpression);
+            }
+        }
+
+        // Writes "[<NodeType> <child summary>, <selector as C#>]".
+        internal override void BuildString(StringBuilder builder)
+        {
+            builder.Append("[" + this.NodeType + " ");
+            this.Children[0].BuildString(builder);
+            builder.Append(", ");
+            builder.Append(HpcLinqExpression.ToCSharpString(this.m_selectExpression,
+                                                            this.QueryGen.CodeGen.AnonymousTypeToName));
+            builder.Append("]");
+        }
+    }
+
+    /// <summary>
+    /// Query node for Zip: two children combined through a selector expression.
+    /// </summary>
+    internal class DryadZipNode : DryadQueryNode
+    {
+        // Selector as given, and the same selector after the referenced-query substitution.
+        private LambdaExpression m_selectExpression;
+        private Expression m_selectExpression1;
+
+        // The partition count is taken from the first child only.
+        internal DryadZipNode(LambdaExpression selectExpr,
+                              Expression queryExpr,
+                              DryadQueryNode child1,
+                              DryadQueryNode child2)
+            : base(QueryNodeType.Zip, child1.QueryGen, queryExpr, child1, child2)
+        {
+            this.m_opName = "Zip";
+            this.m_selectExpression = selectExpr;
+            this.m_partitionCount = child1.OutputPartition.Count;
+            this.m_outputDataSetInfo = this.ComputeOutputDataSetInfo();
+        }
+
+        internal LambdaExpression SelectExpression
+        {
+            get { return this.m_selectExpression; }
+        }
+
+        // The single output type is the type of the selector's body.
+        internal override Type[] OutputTypes
+        {
+            get {
+                Type resType = this.m_selectExpression.Body.Type;
+                return new Type[] { resType };
+            }
+        }
+
+        // The output is described as randomly partitioned, unordered and not distinct.
+        private DataSetInfo ComputeOutputDataSetInfo()
+        {
+            PartitionInfo pinfo = new RandomPartition(this.m_partitionCount);
+            OrderByInfo oinfo = DataSetInfo.NoOrderBy;
+            DistinctInfo dinfo = DataSetInfo.NoDistinct;
+            return new DataSetInfo(pinfo, oinfo, dinfo);
+        }
+
+        // Declares "source" as a call to the vertex operator named by OpName and returns the
+        // variable's name.
+        // NOTE(review): only readerNames[0] is passed although the node has two children -
+        // confirm against the signature of the generated Zip operator.
+        internal override string AddVertexCode(CodeMemberMethod vertexMethod,
+                                               string[] readerNames,
+                                               string[] writerNames)
+        {
+            bool orderPreserving = false;
+            CodeExpression selectorExpr = this.QueryGen.CodeGen.MakeExpression(this.m_selectExpression1);
+            CodeExpression selectExpr = new CodeMethodInvokeExpression(
+                                                HpcLinqCodeGen.DLVTypeExpr,
+                                                this.OpName,
+                                                new CodeVariableReferenceExpression(readerNames[0]),
+                                                selectorExpr,
+                                                new CodePrimitiveExpression(orderPreserving));
+            CodeVariableDeclarationStatement
+                sourceDecl = this.QueryGen.CodeGen.MakeVarDeclStatement("var", "source", selectExpr);
+            vertexMethod.Statements.Add(sourceDecl);
+            return sourceDecl.Name;
+        }
+
+        // Stores the selector as rewritten by the referenced-query substitution.
+        internal override void GetReferencedQueries(ReferencedQuerySubst subst)
+        {
+            this.m_selectExpression1 = subst.Visit(this.m_selectExpression);
+        }
+
+        // Writes "[<NodeType> <child 0 summary>, <child 1 summary>, <selector as C#>]".
+        internal override void BuildString(StringBuilder builder)
+        {
+            builder.Append("[" + this.NodeType + " ");
+            this.Children[0].BuildString(builder);
+            builder.Append(", ");
+            this.Children[1].BuildString(builder);
+            builder.Append(", ");
+            builder.Append(HpcLinqExpression.ToCSharpString(this.m_selectExpression,
+                                                            this.QueryGen.CodeGen.AnonymousTypeToName));
+            builder.Append("]");
+        }
+    }
+
+    /// <summary>
+    /// Query node for OrderBy: sorts its single child by a key, optionally with a comparer.
+    /// </summary>
+    internal class DryadOrderByNode : DryadQueryNode
+    {
+        private LambdaExpression m_keySelectExpression;
+        private Expression m_comparerExpression;
+        private bool m_isDescending;
+        // Evaluated comparer object and its index in HpcLinqObjectStore
+        // (null and -1 when no comparer expression was given).
+        private object m_comparer;
+        private int m_comparerIdx;
+
+        internal DryadOrderByNode(LambdaExpression keySelectExpr,
+                                  Expression comparerExpr,
+                                  bool isDescending,
+                                  Expression queryExpr,
+                                  DryadQueryNode child)
+            : base(QueryNodeType.OrderBy, child.QueryGen, queryExpr, child)
+        {
+            this.m_keySelectExpression = keySelectExpr;
+            this.m_comparerExpression = comparerExpr;
+            this.m_isDescending = isDescending;
+            this.m_opName = "Sort";
+
+            this.m_comparer = null;
+            this.m_comparerIdx = -1;
+            if (comparerExpr != null)
+            {
+                // Evaluate the comparer now and store it; generated code fetches it by index.
+                ExpressionSimplifier evaluator = new ExpressionSimplifier();
+                this.m_comparer = evaluator.Eval(comparerExpr);
+                this.m_comparerIdx = HpcLinqObjectStore.Put(this.m_comparer);
+            }
+
+            this.m_partitionCount = child.OutputPartition.Count;
+
+            // Same data set info as the child, except that the ordering is replaced.
+            this.m_outputDataSetInfo = new DataSetInfo(child.OutputDataSetInfo);
+            Type[] typeArgs = this.KeySelectExpression.Type.GetGenericArguments();
+            this.m_outputDataSetInfo.orderByInfo = OrderByInfo.Create(this.KeySelectExpression,
+                                                                      this.m_comparer,
+                                                                      this.m_isDescending,
+                                                                      typeArgs[1]);
+
+            this.m_dynamicManager = this.InferDynamicManager();
+        }
+
+        // Sorting is always stateful.
+        internal override bool IsStateful
+        {
+            get { return true; }
+        }
+
+        // Sorting does not change the record type.
+        internal override Type[] OutputTypes
+        {
+            get { return this.Children[0].OutputTypes; }
+        }
+
+        internal LambdaExpression KeySelectExpression
+        {
+            get { return this.m_keySelectExpression; }
+        }
+
+        internal bool IsDescending
+        {
+            get { return this.m_isDescending; }
+        }
+
+        internal Expression ComparerExpression
+        {
+            get { return this.m_comparerExpression; }
+        }
+
+        internal object Comparer
+        {
+            get { return this.m_comparer; }
+        }
+
+        internal override void CreateCodeAndMappingsForIntermediateTypes()
+        {
+            // External sort uses serializers directly, so we must process this type
+            // even if it appears inside a super-node and would otherwise not be serialized.
+ this.QueryGen.CodeGen.AddDryadCodeForType(this.InputTypes[0]); + } + + internal override string AddVertexCode(CodeMemberMethod vertexMethod, + string[] readerNames, + string[] writerNames) + { + CodeExpression comparerArg = HpcLinqCodeGen.NullExpr; + if (this.m_comparerExpression != null) + { + CodeExpression getCall = new CodeMethodInvokeExpression( + new CodeTypeReferenceExpression("HpcLinqObjectStore"), + "Get", + new CodePrimitiveExpression(this.m_comparerIdx)); + Type[] typeArgs = this.m_comparerExpression.Type.GetGenericArguments(); + Type comparerType = typeof(IComparer<>).MakeGenericType(typeArgs[0]); + comparerArg = new CodeCastExpression(comparerType, getCall); + } + + bool isIdentityFunc = IdentityFunction.IsIdentity(this.KeySelectExpression); + string factoryName = this.QueryGen.CodeGen.GetStaticFactoryName(this.OutputTypes[0]); + CodeExpression orderByExpr = new CodeMethodInvokeExpression( + HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + this.QueryGen.CodeGen.MakeExpression(this.KeySelectExpression), + comparerArg, + new CodePrimitiveExpression(this.m_isDescending), + new CodePrimitiveExpression(isIdentityFunc), + new CodeArgumentReferenceExpression(factoryName)); + + CodeVariableDeclarationStatement + sourceDecl = this.QueryGen.CodeGen.MakeVarDeclStatement("var", "source", orderByExpr); + vertexMethod.Statements.Add(sourceDecl); + return sourceDecl.Name; + } + + internal override void BuildString(StringBuilder builder) + { + builder.Append("[" + this.NodeType + " "); + this.Children[0].BuildString(builder); + builder.Append(", "); + builder.Append(HpcLinqExpression.ToCSharpString(this.m_keySelectExpression, + this.QueryGen.CodeGen.AnonymousTypeToName)); + if (this.m_comparerExpression != null) + { + builder.Append(", "); + builder.Append(HpcLinqExpression.ToCSharpString(this.m_comparerExpression, + this.QueryGen.CodeGen.AnonymousTypeToName)); + } + builder.Append("]"); + } + } + + internal class 
DryadGroupByNode : DryadQueryNode + { + /* The three fields with suffix 1 hold the results of subst.Visit on the matching selector; they are assigned in GetReferencedQueries and read by AddVertexCode. */ + private LambdaExpression m_keySelectExpr; + private LambdaExpression m_elemSelectExpr; + private LambdaExpression m_resSelectExpr; + private Expression m_keySelectExpr1; + private Expression m_elemSelectExpr1; + private Expression m_resSelectExpr1; + private LambdaExpression m_seedExpr; + private LambdaExpression m_accumulatorExpr; + private LambdaExpression m_recursiveAccumulatorExpr; + private Expression m_comparerExpr; + private object m_comparer; + private int m_comparerIdx; + private bool m_isPartial; + + /* Builds a GroupBy node over a single child. opName must be "GroupBy" or "OrderedGroupBy" (checked by Debug.Assert). A non-null comparerExpr is evaluated here and the result is put into HpcLinqObjectStore; otherwise m_comparer stays null and m_comparerIdx stays -1. */ + internal DryadGroupByNode(string opName, + LambdaExpression keySelectExpr, + LambdaExpression elemSelectExpr, + LambdaExpression resSelectExpr, + LambdaExpression seedExpr, + LambdaExpression accumulateExpr, + LambdaExpression recursiveAccumulatorExpr, + Expression comparerExpr, + bool isPartial, + Expression queryExpr, + DryadQueryNode child) + : base(QueryNodeType.GroupBy, child.QueryGen, queryExpr, child) + { + Debug.Assert(opName == "GroupBy" || opName == "OrderedGroupBy"); + this.m_keySelectExpr = keySelectExpr; + this.m_elemSelectExpr = elemSelectExpr; + this.m_resSelectExpr = resSelectExpr; + this.m_seedExpr = seedExpr; + this.m_accumulatorExpr = accumulateExpr; + this.m_recursiveAccumulatorExpr = recursiveAccumulatorExpr; + this.m_comparerExpr = comparerExpr; + this.m_isPartial = isPartial; + this.m_opName = opName; + + this.m_comparer = null; + this.m_comparerIdx = -1; + if (comparerExpr != null) + { + ExpressionSimplifier evaluator = new ExpressionSimplifier(); + this.m_comparer = evaluator.Eval(comparerExpr); + this.m_comparerIdx = HpcLinqObjectStore.Put(this.m_comparer); + } + + /* The partition count comes from the child; the output data-set info depends on isPartial (see ComputeOutputDataSetInfo). */ + this.m_partitionCount = child.OutputDataSetInfo.partitionInfo.Count; + this.m_outputDataSetInfo = this.ComputeOutputDataSetInfo(isPartial); + + this.m_dynamicManager = this.InferDynamicManager(); + } + + /* Stateful only when the operator is "GroupBy" and the node is not partial. */ + internal override bool IsStateful + { + get { + return (this.m_opName == "GroupBy" && !this.m_isPartial); + } + } + + internal 
override Type[] OutputTypes + { + get { + /* Three cases: with a seed, Pair of (key type, seed body type); with no result selector, IGrouping of (key type, element type); otherwise the third generic argument of the result selector's type. The key type is the second generic argument of the key selector's type. */ + Type keyType = this.m_keySelectExpr.Type.GetGenericArguments()[1]; + if (this.m_seedExpr != null) + { + Type elemType = this.m_seedExpr.Body.Type; + return new Type[] { typeof(Pair<,>).MakeGenericType(keyType, elemType) }; + } + else if (this.m_resSelectExpr == null) + { + // Note: The output type is the IGrouping interface. + Type elemType = this.Children[0].OutputTypes[0]; + if (this.m_elemSelectExpr != null) + { + elemType = this.m_elemSelectExpr.Type.GetGenericArguments()[1]; + } + Type groupingType = typeof(IGrouping<,>).MakeGenericType(keyType, elemType); + return new Type[] { groupingType }; + } + else + { + // Get the output type from the result selector expression + Type[] typeArgs = this.m_resSelectExpr.Type.GetGenericArguments(); + return new Type[] { typeArgs[2] }; + } + } + } + + /* Computes the partition, order-by and distinct info of this node's output from the first child's info. isLocalReduce is the constructor's isPartial flag. Distinct info is always DataSetInfo.NoDistinct. */ + private DataSetInfo ComputeOutputDataSetInfo(bool isLocalReduce) + { + // TBD: could do a bit better with DistinctInfo. + DataSetInfo childInfo = this.Children[0].OutputDataSetInfo; + + if (isLocalReduce) + { + // Partial aggregation node. No need to do anything. 
+ PartitionInfo pinfo = new RandomPartition(this.m_partitionCount); + OrderByInfo oinfo = DataSetInfo.NoOrderBy; + DistinctInfo dinfo = DataSetInfo.NoDistinct; + return new DataSetInfo(pinfo, oinfo, dinfo); + } + else if (this.m_resSelectExpr == null || this.m_seedExpr != null) + { + /* The output element exposes a "Key" property (looked up by reflection on OutputTypes[0]); partition info, and order info for "OrderedGroupBy", are re-created over a lambda that selects that property. */ + // Build the new key selection expression (based on group key): + ParameterExpression param = Expression.Parameter(this.OutputTypes[0], "g"); + PropertyInfo propInfo = param.Type.GetProperty("Key"); + Expression body = Expression.Property(param, propInfo); + Type dType = typeof(Func<,>).MakeGenericType(param.Type, body.Type); + LambdaExpression keySelExpr = Expression.Lambda(dType, body, param); + + PartitionInfo pinfo = childInfo.partitionInfo.Create(keySelExpr); + OrderByInfo oinfo = DataSetInfo.NoOrderBy; + if (this.m_opName == "OrderedGroupBy") + { + oinfo = childInfo.orderByInfo.Create(keySelExpr); + } + DistinctInfo dinfo = DataSetInfo.NoDistinct; + return new DataSetInfo(pinfo, oinfo, dinfo); + } + else + { + /* Result-selector case without a seed: an identity lambda over the key type is used to create the infos, which are then rewritten through the result selector. NOTE(review): pinfo is rewritten against m_resSelectExpr.Parameters[0] while oinfo is rewritten against the local param; confirm the difference is intended. */ + ParameterExpression param = Expression.Parameter(this.m_keySelectExpr.Body.Type, "k"); + Type dType = typeof(Func<,>).MakeGenericType(param.Type, param.Type); + LambdaExpression keySelExpr = Expression.Lambda(dType, param, param); + PartitionInfo pinfo = childInfo.partitionInfo.Create(keySelExpr); + + pinfo = pinfo.Rewrite(this.m_resSelectExpr, this.m_resSelectExpr.Parameters[0]); + OrderByInfo oinfo = DataSetInfo.NoOrderBy; + if (this.m_opName == "OrderedGroupBy") + { + oinfo = childInfo.orderByInfo.Create(keySelExpr); + oinfo = oinfo.Rewrite(this.m_resSelectExpr, param); + } + DistinctInfo dinfo = DataSetInfo.NoDistinct; + return new DataSetInfo(pinfo, oinfo, dinfo); + } + } + + /* Read-only accessors for the selector expressions passed to the constructor. */ + internal LambdaExpression KeySelectExpression + { + get { return this.m_keySelectExpr; } + } + + internal LambdaExpression ElemSelectExpression + { + get { return this.m_elemSelectExpr; } + } + + internal LambdaExpression ResSelectExpression + { + get { return this.m_resSelectExpr; } + } + + 
internal LambdaExpression SeedExpression + { + get { return this.m_seedExpr; } + } + + internal LambdaExpression AccumulatorExpression + { + get { return this.m_accumulatorExpr; } + } + + internal LambdaExpression RecursiveAccumulatorExpression + { + get { return this.m_recursiveAccumulatorExpr; } + } + + internal Expression ComparerExpression + { + get { return this.m_comparerExpr; } + } + + /* Adds a "var source = ..." declaration to vertexMethod that invokes OpName on HpcLinqCodeGen.DLVTypeExpr and returns the declared variable's name. The argument list depends on which of the seed, element selector and result selector are non-null. The key, element and result selectors are taken from the fields with suffix 1, which GetReferencedQueries assigns. writerNames is not used here. */ + internal override string AddVertexCode(CodeMemberMethod vertexMethod, + string[] readerNames, + string[] writerNames) + { + if (this.m_resSelectExpr != null) + { + this.QueryGen.CodeGen.AddDryadCodeForType(this.m_resSelectExpr.Body.Type); + } + + /* NullExpr unless a comparer exists; then it is read back from HpcLinqObjectStore by index and cast to IEqualityComparer of the comparer type's first generic argument. */ + CodeExpression comparerArg = HpcLinqCodeGen.NullExpr; + if (this.m_comparerExpr != null) + { + CodeExpression getCall = new CodeMethodInvokeExpression( + new CodeTypeReferenceExpression("HpcLinqObjectStore"), + "Get", + new CodePrimitiveExpression(this.m_comparerIdx)); + Type[] typeArgs = this.m_comparerExpr.Type.GetGenericArguments(); + Type comparerType = typeof(IEqualityComparer<>).MakeGenericType(typeArgs[0]); + comparerArg = new CodeCastExpression(comparerType, getCall); + } + + CodeExpression groupByExpr; + if (this.m_seedExpr == null) + { + if (this.m_elemSelectExpr == null) + { + if (this.m_resSelectExpr == null) + { + groupByExpr = new CodeMethodInvokeExpression( + HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + this.QueryGen.CodeGen.MakeExpression(this.m_keySelectExpr1), + comparerArg); + } + else + { + groupByExpr = new CodeMethodInvokeExpression( + HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + this.QueryGen.CodeGen.MakeExpression(this.m_keySelectExpr1), + this.QueryGen.CodeGen.MakeExpression(this.m_resSelectExpr1), + comparerArg); + } + } + else + { + if (this.m_resSelectExpr == null) + { + groupByExpr = new CodeMethodInvokeExpression( + HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new 
CodeVariableReferenceExpression(readerNames[0]), + this.QueryGen.CodeGen.MakeExpression(this.m_keySelectExpr1), + this.QueryGen.CodeGen.MakeExpression(this.m_elemSelectExpr1), + comparerArg); + } + else + { + groupByExpr = new CodeMethodInvokeExpression( + HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + this.QueryGen.CodeGen.MakeExpression(this.m_keySelectExpr1), + this.QueryGen.CodeGen.MakeExpression(this.m_elemSelectExpr1), + this.QueryGen.CodeGen.MakeExpression(this.m_resSelectExpr1), + comparerArg); + } + } + } + else + { + /* Seeded forms pass the seed and accumulator lambdas directly and append the isPartial flag as the last argument; the result selector is not passed in these forms. */ + // m_seedExpr != null + if (this.m_elemSelectExpr == null) + { + groupByExpr = new CodeMethodInvokeExpression( + HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + this.QueryGen.CodeGen.MakeExpression(this.m_keySelectExpr1), + this.QueryGen.CodeGen.MakeExpression(this.m_seedExpr), + this.QueryGen.CodeGen.MakeExpression(this.m_accumulatorExpr), + comparerArg, + new CodePrimitiveExpression(this.m_isPartial)); + } + else + { + groupByExpr = new CodeMethodInvokeExpression( + HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + this.QueryGen.CodeGen.MakeExpression(this.m_keySelectExpr1), + this.QueryGen.CodeGen.MakeExpression(this.m_elemSelectExpr1), + this.QueryGen.CodeGen.MakeExpression(this.m_seedExpr), + this.QueryGen.CodeGen.MakeExpression(this.m_accumulatorExpr), + comparerArg, + new CodePrimitiveExpression(this.m_isPartial)); + } + } + CodeVariableDeclarationStatement + sourceDecl = this.QueryGen.CodeGen.MakeVarDeclStatement("var", "source", groupByExpr); + vertexMethod.Statements.Add(sourceDecl); + return sourceDecl.Name; + } + + /* Stores subst.Visit(...) of the key selector, and of the element and result selectors when they are non-null, in the fields with suffix 1. */ + internal override void GetReferencedQueries(ReferencedQuerySubst subst) + { + this.m_keySelectExpr1 = subst.Visit(this.m_keySelectExpr); + if (this.m_elemSelectExpr != null) + { + this.m_elemSelectExpr1 = subst.Visit(this.m_elemSelectExpr); + } + if 
(this.m_resSelectExpr != null) + { + this.m_resSelectExpr1 = subst.Visit(this.m_resSelectExpr); + } + } + + /* Appends a bracketed description to builder: node type, the first child's string, the key selector and, when present, the element selector and the comparer. The result selector, seed and accumulators are not included. */ + internal override void BuildString(StringBuilder builder) + { + builder.Append("[" + this.NodeType + " "); + this.Children[0].BuildString(builder); + builder.Append(", "); + builder.Append(HpcLinqExpression.ToCSharpString(this.m_keySelectExpr, + this.QueryGen.CodeGen.AnonymousTypeToName)); + if (this.m_elemSelectExpr != null) + { + builder.Append(", "); + builder.Append(HpcLinqExpression.ToCSharpString(this.m_elemSelectExpr, + this.QueryGen.CodeGen.AnonymousTypeToName)); + } + if (this.m_comparerExpr != null) + { + builder.Append(", "); + builder.Append(HpcLinqExpression.ToCSharpString(this.m_comparerExpr, + this.QueryGen.CodeGen.AnonymousTypeToName)); + } + builder.Append("]"); + } + } + + /* Query node for partition-style operators; the code below distinguishes the node types Take, Skip and TakeWhile. It has one child and is either a first stage or a merge stage (see IsMergeStage). */ + internal class DryadPartitionOpNode : DryadQueryNode + { + private Expression m_controlExpression; + private bool m_isFirstStage; + private int m_count; + + /* For Take and Skip the control expression is evaluated here and stored in m_count; for other node types m_count stays -1. A first-stage node keeps the child's partition count and data-set info; otherwise the node has a single partition and no dynamic manager. */ + internal DryadPartitionOpNode(string opName, + QueryNodeType nodeType, + Expression controlExpr, + bool isFirstStage, + Expression queryExpr, + DryadQueryNode child) + : base(nodeType, child.QueryGen, queryExpr, child) + { + this.m_controlExpression = controlExpr; + this.m_isFirstStage = isFirstStage; + this.m_opName = opName; + + this.m_count = -1; + if (nodeType == QueryNodeType.Take || nodeType == QueryNodeType.Skip) + { + ExpressionSimplifier evaluator = new ExpressionSimplifier(); + this.m_count = evaluator.Eval(controlExpr); + } + + DataSetInfo childInfo = child.OutputDataSetInfo; + if (isFirstStage) + { + this.m_partitionCount = child.OutputPartition.Count; + this.m_outputDataSetInfo = new DataSetInfo(child.OutputDataSetInfo); + this.m_dynamicManager = this.InferDynamicManager(); + } + else + { + this.m_partitionCount = 1; + this.m_outputDataSetInfo = new DataSetInfo(childInfo); + this.m_outputDataSetInfo.partitionInfo = DataSetInfo.OnePartition; + if (childInfo.partitionInfo.Count > 1 && + 
(childInfo.partitionInfo.ParType != PartitionType.Range || + childInfo.orderByInfo.IsOrdered)) + { + /* The child's ordering is not carried over to the output in this case. */ + this.m_outputDataSetInfo.orderByInfo = DataSetInfo.NoOrderBy; + } + this.m_dynamicManager = DynamicManager.None; + } + } + + /* ContainsMerge, KeepInputPortOrder and IsMergeStage all return true exactly when this is not the first stage. */ + internal override bool ContainsMerge + { + get { return !this.m_isFirstStage; } + } + + internal override bool KeepInputPortOrder() + { + return !this.m_isFirstStage; + } + + internal bool IsMergeStage + { + get { return !this.m_isFirstStage; } + } + + /* For TakeWhile the output type is the first generic argument of the control expression's type; otherwise the first child's output types are returned. */ + internal override Type[] OutputTypes + { + get { + if (this.NodeType == QueryNodeType.TakeWhile) + { + return new Type[] { this.m_controlExpression.Type.GetGenericArguments()[0] }; + } + return this.Children[0].OutputTypes; + } + } + + /* Adds a "var source = ..." declaration to vertexMethod that invokes OpName on HpcLinqCodeGen.DLVTypeExpr and returns the declared variable's name. TakeWhile passes only the first reader; Take and Skip also pass the evaluated count as a primitive; any other node type passes the control expression as code. writerNames is not used here. */ + internal override string AddVertexCode(CodeMemberMethod vertexMethod, + string[] readerNames, + string[] writerNames) + { + CodeExpression partitionExpr; + if (this.NodeType == QueryNodeType.TakeWhile) + { + partitionExpr = new CodeMethodInvokeExpression(HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0])); + } + else + { + CodeExpression controlArg; + if (this.NodeType == QueryNodeType.Take || this.NodeType == QueryNodeType.Skip) + { + controlArg = new CodePrimitiveExpression(this.m_count); + } + else + { + controlArg = this.QueryGen.CodeGen.MakeExpression(this.m_controlExpression); + } + partitionExpr = new CodeMethodInvokeExpression(HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + controlArg); + } + CodeVariableDeclarationStatement + sourceDecl = this.QueryGen.CodeGen.MakeVarDeclStatement("var", "source", partitionExpr); + vertexMethod.Statements.Add(sourceDecl); + return sourceDecl.Name; + } + + /* Appends a bracketed description to builder: node type, the first child's string, then either the count (Take and Skip) or the control expression. */ + internal override void BuildString(StringBuilder builder) + { + builder.Append("[" + this.NodeType + " "); + this.Children[0].BuildString(builder); + builder.Append(", "); + if (this.NodeType == QueryNodeType.Take || this.NodeType == QueryNodeType.Skip) + { + 
builder.Append(Convert.ToString(this.m_count)); + } + else + { + builder.Append(HpcLinqExpression.ToCSharpString(this.m_controlExpression, + this.QueryGen.CodeGen.AnonymousTypeToName)); + } + builder.Append("]"); + } + + /* The control expression exactly as passed to the constructor. */ + internal Expression ControlExpression { get { return this.m_controlExpression; } } + } + + /* Query node for Join and GroupJoin (the constructor asserts one of these node types). Children[0] is the outer input and Children[1] the inner input. */ + internal class DryadJoinNode : DryadQueryNode + { + /* The three fields with suffix 1 hold the results of subst.Visit on the matching selector; they are assigned in GetReferencedQueries and read by AddVertexCode. */ + private LambdaExpression m_outerKeySelectExpression; + private LambdaExpression m_innerKeySelectExpression; + private LambdaExpression m_resultSelectExpression; + private Expression m_outerKeySelectExpression1; + private Expression m_innerKeySelectExpression1; + private Expression m_resultSelectExpression1; + private Expression m_comparerExpression; + private object m_comparer; + private int m_comparerIdx; + private Pipeline m_attachedPipeline; + + /* A non-null comparerExpr is evaluated here and the result is put into HpcLinqObjectStore; otherwise m_comparer stays null and m_comparerIdx stays -1. When StaticConfig.UseMemoryFIFO is set, qualifying children are switched to the MemoryFIFO channel type (see the loop below). */ + internal DryadJoinNode(QueryNodeType nodeType, + string opName, + LambdaExpression outerKeySelectExpr, + LambdaExpression innerKeySelectExpr, + LambdaExpression resultSelectExpr, + Expression comparerExpr, + Expression queryExpr, + DryadQueryNode outerChild, + DryadQueryNode innerChild) + : base(nodeType, outerChild.QueryGen, queryExpr, outerChild, innerChild) + { + Debug.Assert(nodeType == QueryNodeType.Join || nodeType == QueryNodeType.GroupJoin); + this.m_outerKeySelectExpression = outerKeySelectExpr; + this.m_innerKeySelectExpression = innerKeySelectExpr; + this.m_resultSelectExpression = resultSelectExpr; + this.m_comparerExpression = comparerExpr; + this.m_opName = opName; + this.m_attachedPipeline = null; + + this.m_comparer = null; + this.m_comparerIdx = -1; + if (comparerExpr != null) + { + ExpressionSimplifier evaluator = new ExpressionSimplifier(); + this.m_comparer = evaluator.Eval(comparerExpr); + this.m_comparerIdx = HpcLinqObjectStore.Put(m_comparer); + } + + if (StaticConfig.UseMemoryFIFO) + { + bool isStateful = this.IsStateful; + foreach (DryadQueryNode child in this.Children) + { + if (!(child is DryadInputNode) && + !child.IsForked && + 
!(isStateful && child.IsStateful) && + child.PartitionCount > 1) + { + child.ChannelType = ChannelType.MemoryFIFO; + } + /* isStateful accumulates across the loop, so a later child is also compared against the statefulness of earlier children. */ + isStateful = isStateful || child.IsStateful; + } + } + + this.m_partitionCount = outerChild.OutputDataSetInfo.partitionInfo.Count; + this.m_outputDataSetInfo = this.ComputeOutputDataSetInfo(); + + this.m_dynamicManager = DynamicManager.None; + } + + /* Stateful exactly when the operator name starts with "Hash" (ordinal comparison). */ + internal override bool IsStateful + { + get { return this.m_opName.StartsWith("Hash", StringComparison.Ordinal); } + } + + /* The single output type is the third generic argument of the result selector's type. */ + internal override Type[] OutputTypes + { + get { + Type resultType = this.m_resultSelectExpression.Type.GetGenericArguments()[2]; + return new Type[] { resultType }; + } + } + + internal override bool CanAttachPipeline + { + get { return true; } + } + + /* Set-only here; the stored pipeline is read by AddVertexCode. */ + internal override Pipeline AttachedPipeline + { + set { this.m_attachedPipeline = value; } + } + + /* Derives the output's partition and order-by info by rewriting the left child's info through the result selector, falling back to the right child's when the left result is a RandomPartition (partitioning) or is not ordered (ordering). Distinct info is created only when the result selector carries a distinct attribute that either does not require distinct inputs or both children are distinct. */ + private DataSetInfo ComputeOutputDataSetInfo() + { + DataSetInfo leftChildInfo = this.Children[0].OutputDataSetInfo; + DataSetInfo rightChildInfo = this.Children[1].OutputDataSetInfo; + + ParameterExpression leftParam = this.m_resultSelectExpression.Parameters[0]; + ParameterExpression rightParam = this.m_resultSelectExpression.Parameters[1]; + PartitionInfo pinfo = leftChildInfo.partitionInfo.Rewrite(this.m_resultSelectExpression, leftParam); + if (pinfo is RandomPartition) + { + pinfo = rightChildInfo.partitionInfo.Rewrite(this.m_resultSelectExpression, rightParam); + } + OrderByInfo oinfo = leftChildInfo.orderByInfo.Rewrite(this.m_resultSelectExpression, leftParam); + if (!oinfo.IsOrdered) + { + oinfo = rightChildInfo.orderByInfo.Rewrite(this.m_resultSelectExpression, rightParam); + } + DistinctInfo dinfo = DataSetInfo.NoDistinct; + DistinctAttribute attrib1 = AttributeSystem.GetDistinctAttrib(this.m_resultSelectExpression); + if (attrib1 != null && + (!attrib1.MustBeDistinct || + (leftChildInfo.distinctInfo.IsDistinct() && rightChildInfo.distinctInfo.IsDistinct()))) + { + dinfo = DistinctInfo.Create(attrib1.Comparer, 
this.OutputTypes[0]); + } + + return new DataSetInfo(pinfo, oinfo, dinfo); + } + + /* Adds a "var source = ..." declaration to vertexMethod that invokes OpName on HpcLinqCodeGen.DLVTypeExpr with both readers, the two key selectors, the result selector and the comparer argument, and returns the declared variable's name. Non-hash joins additionally pass the outer child's IsDescending flag; when an attached pipeline of length greater than 1 exists, its built expression is passed as the last argument. writerNames is not used here. */ + internal override string AddVertexCode(CodeMemberMethod vertexMethod, + string[] readerNames, + string[] writerNames) + { + bool isHashJoin = this.OpName.StartsWith("Hash", StringComparison.Ordinal); + /* NullExpr unless a comparer exists; then it is read back from HpcLinqObjectStore and cast to IEqualityComparer (hash join) or IComparer (otherwise) of the comparer type's first generic argument. */ + CodeExpression comparerArg = HpcLinqCodeGen.NullExpr; + if (this.m_comparerExpression != null) + { + CodeExpression getCall = new CodeMethodInvokeExpression( + new CodeTypeReferenceExpression("HpcLinqObjectStore"), + "Get", + new CodePrimitiveExpression(this.m_comparerIdx)); + Type[] typeArgs = this.m_comparerExpression.Type.GetGenericArguments(); + Type comparerType; + if (isHashJoin) + { + comparerType = typeof(IEqualityComparer<>).MakeGenericType(typeArgs[0]); + } + else + { + comparerType = typeof(IComparer<>).MakeGenericType(typeArgs[0]); + } + comparerArg = new CodeCastExpression(comparerType, getCall); + } + + CodeExpression joinExpr; + if (this.m_attachedPipeline != null && this.m_attachedPipeline.Length > 1) + { + Type paramType = typeof(IEnumerable<>).MakeGenericType(this.OutputTypes[0]); + ParameterExpression param = Expression.Parameter(paramType, HpcLinqCodeGen.MakeUniqueName("x")); + CodeExpression pipelineArg = this.m_attachedPipeline.BuildExpression(1, param, param); + if (isHashJoin) + { + joinExpr = new CodeMethodInvokeExpression( + HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + new CodeVariableReferenceExpression(readerNames[1]), + this.QueryGen.CodeGen.MakeExpression(this.m_outerKeySelectExpression1), + this.QueryGen.CodeGen.MakeExpression(this.m_innerKeySelectExpression1), + this.QueryGen.CodeGen.MakeExpression(this.m_resultSelectExpression1), + comparerArg, + pipelineArg); + } + else + { + bool isDescending = this.Children[0].OutputDataSetInfo.orderByInfo.IsDescending; + joinExpr = new CodeMethodInvokeExpression( + HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), 
+ new CodeVariableReferenceExpression(readerNames[1]), + this.QueryGen.CodeGen.MakeExpression(this.m_outerKeySelectExpression1), + this.QueryGen.CodeGen.MakeExpression(this.m_innerKeySelectExpression1), + this.QueryGen.CodeGen.MakeExpression(this.m_resultSelectExpression1), + comparerArg, + new CodePrimitiveExpression(isDescending), + pipelineArg); + } + } + else + { + /* No attached pipeline (or one of length 1 or less): same argument lists as above, without the trailing pipeline argument. */ + if (isHashJoin) + { + joinExpr = new CodeMethodInvokeExpression( + HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + new CodeVariableReferenceExpression(readerNames[1]), + this.QueryGen.CodeGen.MakeExpression(this.m_outerKeySelectExpression1), + this.QueryGen.CodeGen.MakeExpression(this.m_innerKeySelectExpression1), + this.QueryGen.CodeGen.MakeExpression(this.m_resultSelectExpression1), + comparerArg); + } + else + { + bool isDescending = this.Children[0].OutputDataSetInfo.orderByInfo.IsDescending; + joinExpr = new CodeMethodInvokeExpression( + HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + new CodeVariableReferenceExpression(readerNames[1]), + this.QueryGen.CodeGen.MakeExpression(this.m_outerKeySelectExpression1), + this.QueryGen.CodeGen.MakeExpression(this.m_innerKeySelectExpression1), + this.QueryGen.CodeGen.MakeExpression(this.m_resultSelectExpression1), + comparerArg, + new CodePrimitiveExpression(isDescending)); + } + } + CodeVariableDeclarationStatement + sourceDecl = this.QueryGen.CodeGen.MakeVarDeclStatement("var", "source", joinExpr); + vertexMethod.Statements.Add(sourceDecl); + return sourceDecl.Name; + } + + /* The outer key selector as passed to the constructor. */ + internal LambdaExpression OuterKeySelectorExpression + { + get { return this.m_outerKeySelectExpression; } + } + + /* The inner key selector as passed to the constructor. */ + internal LambdaExpression InnerKeySelectorExpression + { + get { return this.m_innerKeySelectExpression; } + } + + /* The result selector as passed to the constructor. This getter previously returned the inner key selector by mistake. */ + internal LambdaExpression ResultSelectorExpression + { + get { return this.m_resultSelectExpression; } + } + + internal Expression ComparerExpression + { 
+ get { return this.m_comparerExpression; } + } + + /* Stores subst.Visit(...) of the two key selectors and the result selector in the fields with suffix 1. */ + internal override void GetReferencedQueries(ReferencedQuerySubst subst) + { + this.m_outerKeySelectExpression1 = subst.Visit(this.m_outerKeySelectExpression); + this.m_innerKeySelectExpression1 = subst.Visit(this.m_innerKeySelectExpression); + this.m_resultSelectExpression1 = subst.Visit(this.m_resultSelectExpression); + } + + /* Appends a bracketed description to builder: node type, outer child, outer key selector, inner child, inner key selector. The result selector and comparer are not included. */ + internal override void BuildString(StringBuilder builder) + { + builder.Append("[" + this.NodeType + " "); + this.Children[0].BuildString(builder); + builder.Append(", "); + builder.Append(HpcLinqExpression.ToCSharpString(this.m_outerKeySelectExpression, + this.QueryGen.CodeGen.AnonymousTypeToName)); + builder.Append(", "); + this.Children[1].BuildString(builder); + builder.Append(", "); + builder.Append(HpcLinqExpression.ToCSharpString(this.m_innerKeySelectExpression, + this.QueryGen.CodeGen.AnonymousTypeToName)); + builder.Append("]"); + } + } + + /* Query node for the "Distinct" operator over a single child, optionally partial and optionally with a comparer. */ + internal class DryadDistinctNode : DryadQueryNode + { + private bool m_isPartial; + private Expression m_comparerExpression; + private object m_comparer; + private int m_comparerIdx; + private Pipeline m_attachedPipeline; + + /* A non-null comparerExpr is evaluated here and the result is put into HpcLinqObjectStore; otherwise m_comparer stays null and m_comparerIdx stays -1. The output copies the child's data-set info and replaces its distinct info with one created from the comparer and the output type. */ + internal DryadDistinctNode(bool isPartial, + Expression comparerExpr, + Expression queryExpr, + DryadQueryNode child) + : base(QueryNodeType.Distinct, child.QueryGen, queryExpr, child) + { + this.m_isPartial = isPartial; + this.m_comparerExpression = comparerExpr; + this.m_opName = "Distinct"; + this.m_attachedPipeline = null; + + this.m_comparer = null; + this.m_comparerIdx = -1; + if (comparerExpr != null) + { + ExpressionSimplifier evaluator = new ExpressionSimplifier(); + this.m_comparer = evaluator.Eval(comparerExpr); + this.m_comparerIdx = HpcLinqObjectStore.Put(this.m_comparer); + } + + this.m_partitionCount = child.OutputDataSetInfo.partitionInfo.Count; + + this.m_outputDataSetInfo = new DataSetInfo(child.OutputDataSetInfo); + this.m_outputDataSetInfo.distinctInfo = DistinctInfo.Create(this.m_comparer, 
this.OutputTypes[0]); + + this.m_dynamicManager = this.InferDynamicManager(); + } + + internal Expression ComparerExpression + { + get { return this.m_comparerExpression; } + } + + /* Stateful exactly when the node is not partial. */ + internal override bool IsStateful + { + get { return !this.m_isPartial; } + } + + /* The output types are those of the first child. */ + internal override Type[] OutputTypes + { + get { return this.Children[0].OutputTypes; } + } + + internal override bool CanAttachPipeline + { + get { return true; } + } + + /* Set-only here; the stored pipeline is read by AddVertexCode. */ + internal override Pipeline AttachedPipeline + { + set { this.m_attachedPipeline = value; } + } + + /* Adds a "var source = ..." declaration to vertexMethod that invokes OpName on HpcLinqCodeGen.DLVTypeExpr with the first reader, the comparer argument and the isPartial flag, and returns the declared variable's name. When an attached pipeline of length greater than 1 exists, its built expression is passed before the isPartial flag. writerNames is not used here. */ + internal override string AddVertexCode(CodeMemberMethod vertexMethod, + string[] readerNames, + string[] writerNames) + { + /* NullExpr unless a comparer exists; then it is read back from HpcLinqObjectStore and cast to IEqualityComparer of the comparer type's first generic argument. */ + CodeExpression compareArg = HpcLinqCodeGen.NullExpr; + if (this.m_comparerExpression != null) + { + CodeExpression getCall = new CodeMethodInvokeExpression( + new CodeTypeReferenceExpression("HpcLinqObjectStore"), + "Get", + new CodePrimitiveExpression(this.m_comparerIdx)); + Type[] typeArgs = this.m_comparerExpression.Type.GetGenericArguments(); + Type comparerType = typeof(IEqualityComparer<>).MakeGenericType(typeArgs[0]); + compareArg = new CodeCastExpression(comparerType, getCall); + } + CodeExpression distinctExpr; + if (this.m_attachedPipeline != null && this.m_attachedPipeline.Length > 1) + { + Type paramType = typeof(IEnumerable<>).MakeGenericType(this.OutputTypes[0]); + ParameterExpression param = Expression.Parameter(paramType, HpcLinqCodeGen.MakeUniqueName("x")); + CodeExpression pipelineArg = this.m_attachedPipeline.BuildExpression(1, param, param); + distinctExpr = new CodeMethodInvokeExpression(HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + compareArg, + pipelineArg, + new CodePrimitiveExpression(this.m_isPartial)); + } + else + { + distinctExpr = new CodeMethodInvokeExpression(HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + compareArg, + new CodePrimitiveExpression(this.m_isPartial)); + } + 
CodeVariableDeclarationStatement + sourceDecl = this.QueryGen.CodeGen.MakeVarDeclStatement("var", "source", distinctExpr); + vertexMethod.Statements.Add(sourceDecl); + return sourceDecl.Name; + } + + /* Appends a bracketed description to builder: node type, the first child's string and, when present, the comparer expression. */ + internal override void BuildString(StringBuilder builder) + { + builder.Append("[" + this.NodeType + " "); + this.Children[0].BuildString(builder); + if (this.m_comparerExpression != null) + { + builder.Append(","); + builder.Append(HpcLinqExpression.ToCSharpString(this.m_comparerExpression, + this.QueryGen.CodeGen.AnonymousTypeToName)); + } + builder.Append("]"); + } + } + + /* Query node for the "Contains" operator over a single child; its output type is bool. */ + internal class DryadContainsNode : DryadQueryNode + { + private Expression m_valueExpression; + private Expression m_comparerExpression; + private int m_valueIdx; + private int m_comparerIdx; + + /* The value expression is always evaluated and stored in HpcLinqObjectStore; the comparer is evaluated and stored only when comparerExpr is non-null, otherwise m_comparerIdx stays -1. The output uses a fresh DataSetInfo with a RandomPartition of the child's partition count. */ + internal DryadContainsNode(Expression valueExpr, + Expression comparerExpr, + Expression queryExpr, + DryadQueryNode child) + : base(QueryNodeType.Contains, child.QueryGen, queryExpr, child) + { + this.m_valueExpression = valueExpr; + this.m_comparerExpression = comparerExpr; + this.m_opName = "Contains"; + + this.m_valueIdx = HpcLinqObjectStore.Put(ExpressionSimplifier.Evaluate(valueExpr)); + ExpressionSimplifier evaluator = new ExpressionSimplifier(); + this.m_comparerIdx = -1; + if (comparerExpr != null) + { + this.m_comparerIdx = HpcLinqObjectStore.Put(evaluator.Eval(comparerExpr)); + } + + this.m_partitionCount = child.OutputDataSetInfo.partitionInfo.Count; + this.m_outputDataSetInfo = new DataSetInfo(); + this.m_outputDataSetInfo.partitionInfo = new RandomPartition(this.m_partitionCount); + + this.m_dynamicManager = this.InferDynamicManager(); + } + + internal Expression ValueExpression + { + get { return this.m_valueExpression; } + } + + internal Expression ComparerExpression + { + get { return this.m_comparerExpression; } + } + + internal override Type[] OutputTypes + { + get { return new Type[] { typeof(bool) }; } + } + + /* Rebuilds this node as an expression tree over inputExpr: a call to HpcLinqVertex.<opName>(input, value, comparer) wrapped in HpcLinqVertex.AsEnumerable. The value and comparer are read from HpcLinqObjectStore by index; a missing comparer becomes a typed null constant. */ + internal override Expression RebuildQueryExpression(Expression inputExpr) + 
{ + MethodInfo minfo = typeof(HpcLinqObjectStore).GetMethod("Get"); + Expression getValueExpr = Expression.Call(minfo, Expression.Constant(this.m_valueIdx)); + Expression valueExpr = Expression.Convert(getValueExpr, this.m_valueExpression.Type); + + Expression comparerExpr; + if (this.m_comparerExpression == null) + { + Type comparerType = typeof(IEqualityComparer<>).MakeGenericType(this.InputTypes[0]); + comparerExpr = Expression.Constant(null, comparerType); + } + else + { + /* getValueExpr is reused here to hold the store lookup for the comparer. */ + getValueExpr = Expression.Call(minfo, Expression.Constant(this.m_comparerIdx)); + comparerExpr = Expression.Convert(getValueExpr, this.m_comparerExpression.Type); + } + Type[] typeArgs = new Type[] { this.InputTypes[0] }; + Expression[] args = new Expression[] { inputExpr, valueExpr, comparerExpr }; + Expression resExpr = Expression.Call(typeof(HpcLinqVertex), this.m_opName, typeArgs, args); + minfo = typeof(HpcLinqVertex).GetMethod("AsEnumerable").MakeGenericMethod(resExpr.Type); + return Expression.Call(minfo, new Expression[] { resExpr }); + } + + /* Adds a "var source = ..." declaration to vertexMethod whose initializer is AsEnumerable(OpName(reader0, value, comparer)), both invoked on HpcLinqCodeGen.DLVTypeExpr, and returns the declared variable's name. The value is read from HpcLinqObjectStore and cast to the value expression's type. writerNames is not used here. */ + internal override string AddVertexCode(CodeMemberMethod vertexMethod, + string[] readerNames, + string[] writerNames) + { + CodeExpression getValueCall = + new CodeMethodInvokeExpression(new CodeTypeReferenceExpression("HpcLinqObjectStore"), + "Get", + new CodePrimitiveExpression(this.m_valueIdx)); + CodeExpression valueArg = new CodeCastExpression(this.m_valueExpression.Type, getValueCall); + + /* NullExpr unless a comparer exists; then it is read back from HpcLinqObjectStore and cast to IEqualityComparer of the comparer type's first generic argument. */ + CodeExpression comparerArg = HpcLinqCodeGen.NullExpr; + if (this.m_comparerExpression != null) + { + CodeExpression getComparerCall = new CodeMethodInvokeExpression( + new CodeTypeReferenceExpression("HpcLinqObjectStore"), + "Get", + new CodePrimitiveExpression(this.m_comparerIdx)); + Type[] typeArgs = this.m_comparerExpression.Type.GetGenericArguments(); + Type comparerType = typeof(IEqualityComparer<>).MakeGenericType(typeArgs[0]); + comparerArg = new CodeCastExpression(comparerType, getComparerCall); + } + CodeExpression containsExpr = new 
CodeMethodInvokeExpression( + HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + valueArg, + comparerArg); + containsExpr = new CodeMethodInvokeExpression(HpcLinqCodeGen.DLVTypeExpr, + "AsEnumerable", + containsExpr); + CodeVariableDeclarationStatement + sourceDecl = this.QueryGen.CodeGen.MakeVarDeclStatement("var", "source", containsExpr); + vertexMethod.Statements.Add(sourceDecl); + return sourceDecl.Name; + } + + /* Appends a bracketed description to builder: node type, the first child's string and, when non-null, the value expression and the comparer expression. */ + internal override void BuildString(StringBuilder builder) + { + builder.Append("[" + this.NodeType + " "); + this.Children[0].BuildString(builder); + if (this.m_valueExpression != null) + { + builder.Append(", "); + builder.Append(HpcLinqExpression.ToCSharpString(this.m_valueExpression, + this.QueryGen.CodeGen.AnonymousTypeToName)); + } + if (this.m_comparerExpression != null) + { + builder.Append(", "); + builder.Append(HpcLinqExpression.ToCSharpString(this.m_comparerExpression, + this.QueryGen.CodeGen.AnonymousTypeToName)); + } + builder.Append("]"); + } + } + + /* Query node for the aggregate operators enumerated by AggregateOpType, over a single child. It is either a first stage or a merge stage (see IsMergeStage). */ + internal class DryadBasicAggregateNode : DryadQueryNode + { + private LambdaExpression m_selectExpression; + private AggregateOpType m_aggregateOpType; + private bool m_isFirstStage; + private bool m_isQuery; + + /* The operator name is the string form of aggType. A first-stage node keeps the child's partition count with a RandomPartition; otherwise the node has one partition, a default DataSetInfo and no dynamic manager. */ + internal DryadBasicAggregateNode(LambdaExpression selectExpr, + AggregateOpType aggType, + bool isFirstStage, + bool isQuery, + Expression queryExpr, + DryadQueryNode child) + : base(QueryNodeType.BasicAggregate, child.QueryGen, queryExpr, child) + { + this.m_selectExpression = selectExpr; + this.m_aggregateOpType = aggType; + this.m_isFirstStage = isFirstStage; + this.m_isQuery = isQuery; + this.m_opName = aggType.ToString(); + + if (isFirstStage) + { + this.m_partitionCount = child.OutputDataSetInfo.partitionInfo.Count; + this.m_outputDataSetInfo = new DataSetInfo(); + this.m_outputDataSetInfo.partitionInfo = new RandomPartition(this.m_partitionCount); + this.m_dynamicManager = this.InferDynamicManager(); + } + else + { + 
this.m_partitionCount = 1; + this.m_outputDataSetInfo = new DataSetInfo(); + this.m_dynamicManager = DynamicManager.None; + } + } + + internal AggregateOpType OpType + { + get { return this.m_aggregateOpType; } + } + + internal LambdaExpression SelectExpression + { + get { return this.m_selectExpression; } + } + + /* ContainsMerge and IsMergeStage return true exactly when this is not the first stage. */ + internal override bool ContainsMerge + { + get { return !this.m_isFirstStage; } + } + + /* True only for a merge stage whose operator is First, Last, FirstOrDefault or LastOrDefault. */ + internal override bool KeepInputPortOrder() + { + return (!this.m_isFirstStage && + (this.m_aggregateOpType == AggregateOpType.First || + this.m_aggregateOpType == AggregateOpType.Last || + this.m_aggregateOpType == AggregateOpType.FirstOrDefault || + this.m_aggregateOpType == AggregateOpType.LastOrDefault)); + } + + internal bool IsMergeStage + { + get { return !this.m_isFirstStage; } + } + + /* Count, LongCount, Any and All have fixed output types (Int32, Int64, bool, bool). Otherwise the type is derived from the query expression's type (its first generic argument when m_isQuery is set): a merge stage returns it directly, or wrapped in AggregateValue for the *OrDefault operators; a first stage is handled by the switch below. */ + internal override Type[] OutputTypes + { + get { + if (this.m_aggregateOpType == AggregateOpType.Count) + { + return new Type[] { typeof(Int32) }; + } + if (this.m_aggregateOpType == AggregateOpType.LongCount) + { + return new Type[] { typeof(Int64) }; + } + if (this.m_aggregateOpType == AggregateOpType.Any || + this.m_aggregateOpType == AggregateOpType.All) + { + return new Type[] { typeof(bool) }; + } + + Type qType = this.QueryExpression.Type; + if (this.m_isQuery) + { + qType = qType.GetGenericArguments()[0]; + } + + if (!this.m_isFirstStage) + { + if (this.m_aggregateOpType == AggregateOpType.FirstOrDefault || + this.m_aggregateOpType == AggregateOpType.SingleOrDefault || + this.m_aggregateOpType == AggregateOpType.LastOrDefault) + { + return new Type[] { typeof(AggregateValue<>).MakeGenericType(qType) }; + } + else + { + return new Type[] { qType }; + } + } + + switch (this.m_aggregateOpType) + { + case AggregateOpType.Sum: + { + return new Type[] { qType }; + } + case AggregateOpType.Min: + case AggregateOpType.Max: + { + if (qType == typeof(Int32?) || + qType == typeof(Int64?) || + qType == typeof(float?) || + qType == typeof(double?) 
|| + qType == typeof(decimal?)) + { + return new Type[] { qType }; + } + return new Type[] { typeof(AggregateValue<>).MakeGenericType(qType) }; + } + case AggregateOpType.First: + case AggregateOpType.Single: + case AggregateOpType.Last: + case AggregateOpType.FirstOrDefault: + case AggregateOpType.SingleOrDefault: + case AggregateOpType.LastOrDefault: + { + return new Type[] { typeof(AggregateValue<>).MakeGenericType(qType) }; + } + case AggregateOpType.Average: + { + ParameterInfo[] paramInfos = ((MethodCallExpression)this.QueryExpression).Method.GetParameters(); + Type valueType; + if (this.m_selectExpression == null) + { + valueType = paramInfos[0].ParameterType.GetGenericArguments()[0]; + } + else + { + valueType = paramInfos[1].ParameterType.GetGenericArguments()[0].GetGenericArguments()[1]; + } + + if (valueType == typeof(int)) + { + valueType = typeof(long); + } + else if (valueType == typeof(int?)) + { + valueType = typeof(long?); + } + else if (valueType == typeof(float)) + { + valueType = typeof(double); + } + else if (valueType == typeof(float?)) + { + valueType = typeof(double?); + } + return new Type[] { typeof(AggregateValue<>).MakeGenericType(valueType) }; + } + default: + { + throw new DryadLinqException(HpcLinqErrorCode.AggregateOperatorNotSupported, + String.Format(SR.AggregateOperatorNotSupported, this.m_aggregateOpType)); + } + } + } + } + + internal override string AddVertexCode(CodeMemberMethod vertexMethod, + string[] readerNames, + string[] writerNames) + { + CodeExpression aggregateExpr; + if (this.m_selectExpression == null) + { + aggregateExpr = new CodeMethodInvokeExpression(HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0])); + } + else + { + aggregateExpr = new CodeMethodInvokeExpression( + HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + this.QueryGen.CodeGen.MakeExpression(this.m_selectExpression)); + } + if (!this.m_isFirstStage && + 
this.m_aggregateOpType == AggregateOpType.Average && + ((this.OutputTypes[0] == typeof(float)) || (this.OutputTypes[0] == typeof(float?)))) + { + aggregateExpr = new CodeCastExpression(this.OutputTypes[0], aggregateExpr); + } + aggregateExpr = new CodeMethodInvokeExpression(HpcLinqCodeGen.DLVTypeExpr, + "AsEnumerable", + aggregateExpr); + CodeVariableDeclarationStatement + sourceDecl = this.QueryGen.CodeGen.MakeVarDeclStatement("var", "source", aggregateExpr); + vertexMethod.Statements.Add(sourceDecl); + return sourceDecl.Name; + } + + internal override void BuildString(StringBuilder builder) + { + builder.Append("[" + this.NodeType + " "); + this.Children[0].BuildString(builder); + if (this.m_selectExpression != null) + { + builder.Append(", "); + builder.Append(HpcLinqExpression.ToCSharpString(this.m_selectExpression, + this.QueryGen.CodeGen.AnonymousTypeToName)); + } + builder.Append("]"); + } + } + + internal class DryadAggregateNode : DryadQueryNode + { + // There are up to 3 levels of aggregations: + // Stage 1: first level is pipelined with the source computation; + // Stage 2: second level is an aggregation of first level and then a dynamic aggregation; + // Stage 3: third level is a single partition + private Expression m_seedExpression; + private LambdaExpression m_funcLambda; + private LambdaExpression m_combinerLambda; + private LambdaExpression m_resultLambda; + private int m_stage; + private bool m_isQuery; + private object m_seedValue; + private int m_seedIdx; + + internal DryadAggregateNode(string opName, + Expression seedExpr, + LambdaExpression funcLambda, + LambdaExpression combinerLambda, + LambdaExpression resultLambda, + int stage, + bool isQuery, + Expression queryExpr, + DryadQueryNode child, + bool functionIsExpensive) + : base(QueryNodeType.Aggregate, child.QueryGen, queryExpr, child) + { + this.m_seedExpression = seedExpr; + this.m_funcLambda = funcLambda; + this.m_combinerLambda = combinerLambda; + this.m_resultLambda = resultLambda; + 
this.m_stage = stage; + this.m_isQuery = isQuery; + this.m_opName = opName; + + this.m_seedValue = null; + this.m_seedIdx = -1; + if (seedExpr != null) + { + this.m_seedValue = ExpressionSimplifier.Evaluate(seedExpr); + if (!seedExpr.Type.IsPrimitive) + { + this.m_seedIdx = HpcLinqObjectStore.Put(m_seedValue); + } + } + + if (stage != 3) + { + this.m_partitionCount = child.OutputDataSetInfo.partitionInfo.Count; + this.m_outputDataSetInfo = new DataSetInfo(); + this.m_outputDataSetInfo.partitionInfo = new RandomPartition(this.m_partitionCount); + this.m_dynamicManager = this.InferDynamicManager(); + } + else + { + this.m_partitionCount = 1; + this.m_outputDataSetInfo = new DataSetInfo(); + if (functionIsExpensive) + { + DryadDynamicNode dnode = new DryadDynamicNode(DynamicManagerType.FullAggregator, child); + this.m_dynamicManager = new DynamicManager(DynamicManagerType.FullAggregator, dnode); + this.m_dynamicManager.AggregationLevels = 2; + } + else + { + this.m_dynamicManager = DynamicManager.None; + } + } + } + + // NOTE: because some stages may consume their inputs out of order, + // [Associative] may in fact also assume commutativity. 
+        // Only stage 3 (the single-partition stage) is a merge.
+        internal override bool ContainsMerge
+        {
+            get { return (this.m_stage == 3); }
+        }
+
+        // Stage 3 asks for its input ports to be kept in order.
+        internal override bool KeepInputPortOrder()
+        {
+            return (this.m_stage == 3);
+        }
+
+        // Same value as ContainsMerge.
+        internal bool IsMergeStage
+        {
+            get { return (this.m_stage == 3); }
+        }
+
+        // The single output record type of this node:
+        //   - stages 1 and 2: AggregateValue<> of the first generic argument of the
+        //     accumulator function's delegate type;
+        //   - stage 3: the second generic argument of the result selector's delegate
+        //     type when there is a selector, otherwise the first generic argument of
+        //     the accumulator function's delegate type.
+        internal override Type[] OutputTypes
+        {
+            get {
+                Type resultType;
+                if (this.m_stage != 3)
+                {
+                    resultType = this.m_funcLambda.Type.GetGenericArguments()[0];
+                    resultType = typeof(AggregateValue<>).MakeGenericType(resultType);
+                }
+                else
+                {
+                    if (this.m_resultLambda != null)
+                    {
+                        resultType = this.m_resultLambda.Type.GetGenericArguments()[1];
+                    }
+                    else
+                    {
+                        resultType = this.m_funcLambda.Type.GetGenericArguments()[0];
+                    }
+                }
+                return new Type[] { resultType };
+            }
+        }
+
+        internal override bool IsHomomorphic
+        {
+            get { return (this.m_stage == 1); }
+        }
+
+        // Builds the expression
+        //     HpcLinqVertex.AsEnumerable(HpcLinqVertex.<opName>(inputExpr[, seed], func))
+        // for this node.
+        //
+        // The seed is included whenever a seed expression was supplied. The test is
+        // on m_seedExpression, not on the evaluated m_seedValue, so that a seed which
+        // evaluates to null is still passed (as a null of the seed's type) instead
+        // of silently selecting the unseeded overload; this matches AddVertexCode.
+        internal override Expression RebuildQueryExpression(Expression inputExpr)
+        {
+            Type[] typeArgs;
+            Expression[] args;
+            if (this.m_seedExpression == null)
+            {
+                typeArgs = new Type[] { this.InputTypes[0] };
+                args = new Expression[] { inputExpr, this.m_funcLambda };
+            }
+            else
+            {
+                Expression seedExpr;
+                if (this.m_seedExpression.Type.IsPrimitive)
+                {
+                    seedExpr = Expression.Constant(this.m_seedValue, this.m_seedExpression.Type);
+                }
+                else
+                {
+                    // Non-primitive seeds were stored by the constructor; fetch them
+                    // back by index and convert to the seed's static type.
+                    MethodInfo minfo = typeof(HpcLinqObjectStore).GetMethod("Get");
+                    Expression getValueExpr = Expression.Call(minfo, Expression.Constant(this.m_seedIdx));
+                    seedExpr = Expression.Convert(getValueExpr, this.m_seedExpression.Type);
+                }
+                typeArgs = new Type[] { this.InputTypes[0], seedExpr.Type };
+                args = new Expression[] { inputExpr, seedExpr, this.m_funcLambda };
+            }
+            Expression resExpr = Expression.Call(typeof(HpcLinqVertex), this.m_opName, typeArgs, args);
+            return Expression.Call(typeof(HpcLinqVertex).GetMethod("AsEnumerable").MakeGenericMethod(resExpr.Type),
+                                   new Expression[] { resExpr });
+        }
+
+        internal LambdaExpression FuncLambda
+        {
+            get { return this.m_funcLambda; }
+        }
+
+        internal LambdaExpression CombinerLambda
+        {
+            get { return this.m_combinerLambda; }
+        }
+
+        internal LambdaExpression ResultLambda
+        {
+            get { return this.m_resultLambda; }
+        }
+
+        internal Expression SeedExpression
+        {
+            get { return this.m_seedExpression; }
+        }
+
+        // Adds "var source = AsEnumerable(<opName>(reader[, seed], func[, result]))"
+        // to vertexMethod, with both calls made on HpcLinqCodeGen.DLVTypeExpr, and
+        // returns the name of the declared variable. writerNames is not used.
+        //   func   : the combiner for "Assoc" operators after stage 1, otherwise the
+        //            accumulator function;
+        //   seed   : only for stage 1 of "AssocAggregate" and stage 3 of "Aggregate";
+        //   result : only in stage 3, when a result selector exists.
+        internal override string AddVertexCode(CodeMemberMethod vertexMethod,
+                                               string[] readerNames,
+                                               string[] writerNames)
+        {
+            CodeExpression funcArg = null;
+            if (this.m_stage != 1 && this.OpName.Contains("Assoc"))
+            {
+                funcArg = this.QueryGen.CodeGen.MakeExpression(this.m_combinerLambda);
+            }
+            else
+            {
+                funcArg = this.QueryGen.CodeGen.MakeExpression(this.m_funcLambda);
+            }
+
+            CodeExpression seedArg = null;
+            if (this.m_seedExpression != null &&
+                ((this.m_stage == 1 && this.OpName == "AssocAggregate") ||
+                 (this.m_stage == 3 && this.OpName == "Aggregate")))
+            {
+                if (this.m_seedExpression.Type.IsPrimitive)
+                {
+                    seedArg = new CodePrimitiveExpression(this.m_seedValue);
+                }
+                else
+                {
+                    CodeExpression getCall =
+                        new CodeMethodInvokeExpression(new CodeTypeReferenceExpression("HpcLinqObjectStore"),
+                                                       "Get",
+                                                       new CodePrimitiveExpression(this.m_seedIdx));
+                    seedArg = new CodeCastExpression(this.m_seedExpression.Type, getCall);
+                }
+            }
+
+            CodeExpression resultArg = null;
+            if (this.m_stage == 3 && this.m_resultLambda != null)
+            {
+                resultArg = this.QueryGen.CodeGen.MakeExpression(this.m_resultLambda);
+            }
+
+            // Argument order is: reader, [seed], func, [result].
+            List<CodeExpression> callArgs = new List<CodeExpression>();
+            callArgs.Add(new CodeVariableReferenceExpression(readerNames[0]));
+            if (seedArg != null)
+            {
+                callArgs.Add(seedArg);
+            }
+            callArgs.Add(funcArg);
+            if (resultArg != null)
+            {
+                callArgs.Add(resultArg);
+            }
+
+            CodeExpression aggregateExpr = new CodeMethodInvokeExpression(HpcLinqCodeGen.DLVTypeExpr,
+                                                                          this.m_opName,
+                                                                          callArgs.ToArray());
+            aggregateExpr = new CodeMethodInvokeExpression(HpcLinqCodeGen.DLVTypeExpr,
+                                                           "AsEnumerable",
+                                                           aggregateExpr);
+            CodeVariableDeclarationStatement
+                sourceDecl = this.QueryGen.CodeGen.MakeVarDeclStatement("var", "source", aggregateExpr);
+            vertexMethod.Statements.Add(sourceDecl);
+            return sourceDecl.Name;
+        }
+
+        // Appends "[<NodeType> <child>[, <seed>][, <func>][, <result>]]" to builder.
+        // The combiner lambda is not part of the string.
+        internal override void BuildString(StringBuilder builder)
+        {
+            builder.Append("[" + this.NodeType + " ");
+            this.Children[0].BuildString(builder);
+            if (this.m_seedExpression != null)
+            {
+                builder.Append(", ");
+                builder.Append(HpcLinqExpression.ToCSharpString(this.m_seedExpression,
+                                                                this.QueryGen.CodeGen.AnonymousTypeToName));
+            }
+            if (this.m_funcLambda != null)
+            {
+                builder.Append(", ");
+                builder.Append(HpcLinqExpression.ToCSharpString(this.m_funcLambda,
+                                                                this.QueryGen.CodeGen.AnonymousTypeToName));
+            }
+            if (this.m_resultLambda != null)
+            {
+                builder.Append(", ");
+                builder.Append(HpcLinqExpression.ToCSharpString(this.m_resultLambda,
+                                                                this.QueryGen.CodeGen.AnonymousTypeToName));
+            }
+            builder.Append("]");
+        }
+    }
+
+    // Query-plan node that concatenates the outputs of its children.
+    internal class DryadConcatNode : DryadQueryNode
+    {
+        // Builds a Concat node over children. A child that is itself an unforked
+        // DryadConcatNode is flattened: its children are re-parented to this node
+        // and added in its place. The partition count is the sum of the partition
+        // counts of the children passed in.
+        internal DryadConcatNode(Expression queryExpr, params DryadQueryNode[] children)
+            : base(QueryNodeType.Concat, children[0].QueryGen, queryExpr)
+        {
+            this.m_opName = "Concat";
+            List<DryadQueryNode> childList = new List<DryadQueryNode>();
+            this.m_partitionCount = 0;
+            foreach (DryadQueryNode child in children)
+            {
+                if ((child is DryadConcatNode) && !child.IsForked)
+                {
+                    foreach (DryadQueryNode cc in child.Children)
+                    {
+                        cc.UpdateParent(child, this);
+                        childList.Add(cc);
+                    }
+                }
+                else
+                {
+                    child.Parents.Add(this);
+                    childList.Add(child);
+                }
+                this.m_partitionCount += child.OutputDataSetInfo.partitionInfo.Count;
+            }
+
this.Children = childList.ToArray();
+
+            this.m_outputDataSetInfo = new DataSetInfo();
+            this.m_outputDataSetInfo.partitionInfo = new RandomPartition(this.m_partitionCount);
+
+            this.m_dynamicManager = DynamicManager.None;
+        }
+
+        // The output type is that of the first child.
+        internal override Type[] OutputTypes
+        {
+            get { return this.Children[0].OutputTypes; }
+        }
+
+        // Always throws InvalidOperationException: this node emits no vertex code.
+        internal override string AddVertexCode(CodeMemberMethod vertexMethod,
+                                               string[] readerNames,
+                                               string[] writerNames)
+        {
+            throw new InvalidOperationException();
+        }
+
+        // Replaces every child that is a DryadInputNode or is forked with an
+        // identity DryadApplyNode (x => x over IEnumerable<T>) wrapping that child.
+        // The wrapper takes over the child's OutputDataSetInfo and this node as its
+        // parent; the original child no longer lists this node as a parent.
+        internal void FixInputs()
+        {
+            for (int i = 0; i < this.Children.Length; i++)
+            {
+                DryadQueryNode child = this.Children[i];
+
+                if ((child is DryadInputNode) || child.IsForked)
+                {
+                    // Insert a dummy Apply
+                    Type paramType = typeof(IEnumerable<>).MakeGenericType(child.OutputTypes[0]);
+                    ParameterExpression param = Expression.Parameter(paramType, "x");
+                    Type type = typeof(Func<,>).MakeGenericType(paramType, paramType);
+                    LambdaExpression applyExpr = Expression.Lambda(type, param, param);
+                    this.Children[i] = new DryadApplyNode(applyExpr, child.QueryExpression, child);
+                    this.Children[i].OutputDataSetInfo = child.OutputDataSetInfo;
+                    this.Children[i].Parents.Add(this);
+                    child.Parents.Remove(this);
+                }
+            }
+        }
+
+        // Appends "[<NodeType> <child0>,<child1>,...]" to builder.
+        internal override void BuildString(StringBuilder builder)
+        {
+            builder.Append("[" + this.NodeType + " ");
+            for (int i = 0; i < this.Children.Length; ++i)
+            {
+                this.Children[i].BuildString(builder);
+                if (i < this.Children.Length - 1)
+                {
+                    builder.Append(",");
+                }
+            }
+            builder.Append("]");
+        }
+    }
+
+    // Query-plan node for a binary set operation over two children. The node type
+    // and operator name are supplied by the caller; an operator name starting with
+    // "Ordered" selects the ordered variant.
+    internal class DryadSetOperationNode : DryadQueryNode
+    {
+        // Inv: The children are both either ordered or unordered
+        private bool m_isOrdered;
+        private Expression m_comparerExpression;
+        private object m_comparer;      // evaluated comparer, or null
+        private int m_comparerIdx;      // HpcLinqObjectStore index of the comparer, else -1
+
+        // nodeType     : the query node type of the operation.
+        // opName       : name of the vertex operator to invoke.
+        // comparerExpr : optional comparer; evaluated here and put in the object store.
+        // queryExpr    : the query expression this node was built from.
+        // child1/2     : the two inputs; the output data-set info is copied from child1.
+        internal DryadSetOperationNode(QueryNodeType nodeType,
+                                       string opName,
+                                       Expression comparerExpr,
+                                       Expression queryExpr,
+                                       DryadQueryNode child1,
+                                       DryadQueryNode child2)
+            : base(nodeType, child1.QueryGen, queryExpr, child1, child2)
+        {
+            this.m_isOrdered = opName.StartsWith("Ordered", StringComparison.Ordinal);
+            this.m_opName = opName;
+            this.m_comparerExpression = comparerExpr;
+
+            this.m_comparer = null;
+            this.m_comparerIdx = -1;
+            if (comparerExpr != null)
+            {
+                // NOTE(review): a generic type argument on ExpressionSimplifier appears
+                // to have been lost here (Eval's result is assigned to fields of
+                // different types elsewhere in this file) - confirm against the
+                // original source.
+                ExpressionSimplifier evaluator = new ExpressionSimplifier();
+                this.m_comparer = evaluator.Eval(comparerExpr);
+                this.m_comparerIdx = HpcLinqObjectStore.Put(m_comparer);
+            }
+
+            if (StaticConfig.UseMemoryFIFO)
+            {
+                // A child gets a memory FIFO channel unless it is an input node, is
+                // forked, has a single partition, or is stateful while this node or
+                // an earlier child is stateful too.
+                bool isStateful = this.IsStateful;
+                foreach (DryadQueryNode child in this.Children)
+                {
+                    if (!(child is DryadInputNode) &&
+                        !child.IsForked &&
+                        !(isStateful && child.IsStateful) &&
+                        child.PartitionCount > 1)
+                    {
+                        child.ChannelType = ChannelType.MemoryFIFO;
+                    }
+                    isStateful = isStateful || child.IsStateful;
+                }
+            }
+
+            this.m_partitionCount = child1.OutputDataSetInfo.partitionInfo.Count;
+            this.m_outputDataSetInfo = new DataSetInfo(child1.OutputDataSetInfo);
+
+            this.m_dynamicManager = DynamicManager.None;
+        }
+
+        // The unordered variant is reported as stateful.
+        internal override bool IsStateful
+        {
+            get { return !this.m_isOrdered; }
+        }
+
+        // The output type is that of the first child.
+        internal override Type[] OutputTypes
+        {
+            get { return this.Children[0].OutputTypes; }
+        }
+
+        // Adds "var source = <OpName>(reader0, reader1, comparer[, isDescending])" to
+        // vertexMethod, the call being made on HpcLinqCodeGen.DLVTypeExpr, and returns
+        // the name of the declared variable. The comparer argument is
+        // HpcLinqCodeGen.NullExpr when no comparer was given; otherwise it is fetched
+        // from the object store and cast to IComparer<T> (ordered) or
+        // IEqualityComparer<T> (unordered). isDescending is passed only for the
+        // ordered variant and is taken from the first child's order-by info.
+        // writerNames is not used.
+        internal override string AddVertexCode(CodeMemberMethod vertexMethod,
+                                               string[] readerNames,
+                                               string[] writerNames)
+        {
+            CodeExpression comparerArg = HpcLinqCodeGen.NullExpr;
+            if (this.m_comparerExpression != null)
+            {
+                CodeExpression getCall =
+                    new CodeMethodInvokeExpression(new CodeTypeReferenceExpression("HpcLinqObjectStore"),
+                                                   "Get",
+                                                   new CodePrimitiveExpression(this.m_comparerIdx));
+                Type[] typeArgs = this.m_comparerExpression.Type.GetGenericArguments();
+                Type comparerType;
+                if (m_isOrdered)
+                {
+                    comparerType = typeof(IComparer<>).MakeGenericType(typeArgs[0]);
+                }
+                else
+                {
+                    comparerType = typeof(IEqualityComparer<>).MakeGenericType(typeArgs[0]);
+                }
+
+                comparerArg = new CodeCastExpression(comparerType, getCall);
+            }
+
+            CodeExpression setOpExpr;
+            if (this.m_isOrdered)
+            {
+                bool isDescending = this.Children[0].OutputDataSetInfo.orderByInfo.IsDescending;
+                setOpExpr = new CodeMethodInvokeExpression(
+                                    HpcLinqCodeGen.DLVTypeExpr,
+                                    this.OpName,
+                                    new CodeVariableReferenceExpression(readerNames[0]),
+                                    new CodeVariableReferenceExpression(readerNames[1]),
+                                    comparerArg,
+                                    new CodePrimitiveExpression(isDescending));
+            }
+            else
+            {
+                setOpExpr = new CodeMethodInvokeExpression(
+                                    HpcLinqCodeGen.DLVTypeExpr,
+                                    this.OpName,
+                                    new CodeVariableReferenceExpression(readerNames[0]),
+                                    new CodeVariableReferenceExpression(readerNames[1]),
+                                    comparerArg);
+            }
+            CodeVariableDeclarationStatement
+                sourceDecl = this.QueryGen.CodeGen.MakeVarDeclStatement("var", "source", setOpExpr);
+            vertexMethod.Statements.Add(sourceDecl);
+            return sourceDecl.Name;
+        }
+
+        internal Expression ComparerExpression
+        {
+            get { return this.m_comparerExpression; }
+        }
+
+        // Appends "[<NodeType> <child0>, <child1>]" to builder.
+        internal override void BuildString(StringBuilder builder)
+        {
+            builder.Append("[" + this.NodeType + " ");
+            this.Children[0].BuildString(builder);
+            builder.Append(", ");
+            this.Children[1].BuildString(builder);
+            builder.Append("]");
+        }
+    }
+
+    // Merging of multiple data channels.
+    // This can operate in two principal modes:
+    //    1. "arbitrary" merge: no port ordering and no rules about the resulting sequence.
+    //    2. "sorted" merge: port ordering and merge-sort logic. Requires that the input channels be sorted.
+    internal class DryadMergeNode : DryadQueryNode
+    {
+        private LambdaExpression m_keySelectExpression;
+        private Expression m_comparerExpression;
+        private bool m_isDescending;
+        private bool m_keepPortOrder;
+        private bool m_isTemp;
+        private bool m_isSplitting;
+        private object m_comparer;      // evaluated comparer, or null
+        private int m_comparerIdx;      // HpcLinqObjectStore index of the comparer, else -1
+
+        // Merge whose key selector, comparer and direction are taken from the
+        // order-by info of child; the caller chooses whether port order is kept.
+        internal DryadMergeNode(bool keepPortOrder, bool isTemp, Expression queryExpr, DryadQueryNode child)
+            : base(QueryNodeType.Merge, child.QueryGen, queryExpr, child)
+        {
+            OrderByInfo childInfo = child.OutputDataSetInfo.orderByInfo;
+            LambdaExpression keySelectExpr = childInfo.KeySelector;
+            Expression comparerExpr = childInfo.Comparer;
+            bool isDescending = childInfo.IsDescending;
+            this.m_keepPortOrder = keepPortOrder;
+            this.Initialize(keySelectExpr, comparerExpr, isDescending, isTemp, child, queryExpr);
+        }
+
+        // Merge with an explicitly supplied key selector, comparer and direction.
+        // Port order is always kept.
+        internal DryadMergeNode(LambdaExpression keySelectExpr,
+                                Expression comparerExpr,
+                                bool isDescending,
+                                bool isTemp,
+                                Expression queryExpr,
+                                DryadQueryNode child)
+            : base(QueryNodeType.Merge, child.QueryGen, queryExpr, child)
+        {
+            this.m_keepPortOrder = true;
+            this.Initialize(keySelectExpr, comparerExpr, isDescending, isTemp, child, queryExpr);
+        }
+
+        // This overload specifically and only supports the Left-homomorphic binary apply query plan.
+        // Unlike a "full-merge" of a dataset to one partition, this creates n nodes each collating the complete data.
+        internal DryadMergeNode(Int32 parCount, Expression queryExpr, DryadQueryNode child)
+            : base(QueryNodeType.Merge, child.QueryGen, queryExpr, child)
+        {
+            this.m_opName = "Merge";
+            this.m_keySelectExpression = null;
+            this.m_comparerExpression = null;
+            this.m_isDescending = false;
+            this.m_keepPortOrder = false;
+            this.m_isTemp = false;
+            this.m_comparer = null;
+            this.m_comparerIdx = -1;
+            this.m_dynamicManager = DynamicManager.None;
+            this.m_partitionCount = parCount;
+            PartitionInfo pinfo = new RandomPartition(parCount);
+            this.m_outputDataSetInfo = new DataSetInfo(pinfo, DataSetInfo.NoOrderBy, DataSetInfo.NoDistinct);
+            this.m_isSplitting = false;
+        }
+
+        // Shared set-up for the first two constructors.
+        //   - The operator is "MergeSort" when a key selector is given, else "Merge".
+        //   - Throws (via DryadLinqException.Create) when a key selector is given
+        //     without a comparer and the key type has no default comparer.
+        //   - A dynamic hash-distributor child is spliced out of the graph and
+        //     handed to the dynamic manager instead.
+        //   - The output has the child's partitioning when the (possibly replaced)
+        //     child connects by cross product, otherwise a single partition.
+        private void Initialize(LambdaExpression keySelectExpr,
+                                Expression comparerExpr,
+                                bool isDescending,
+                                bool isTemp,
+                                DryadQueryNode child,
+                                Expression queryExpr)
+        {
+            this.m_opName = (keySelectExpr == null) ? "Merge" : "MergeSort";
+            this.m_keySelectExpression = keySelectExpr;
+            this.m_comparerExpression = comparerExpr;
+            this.m_isDescending = isDescending;
+            this.m_isTemp = isTemp;
+
+            if (keySelectExpr != null)
+            {
+                Type keyType = keySelectExpr.Type.GetGenericArguments()[1];
+                if (comparerExpr == null && !TypeSystem.HasDefaultComparer(keyType))
+                {
+                    throw DryadLinqException.Create(HpcLinqErrorCode.ComparerMustBeSpecifiedOrKeyTypeMustBeIComparable,
+                                                    String.Format(SR.ComparerMustBeSpecifiedOrKeyTypeMustBeIComparable, keyType),
+                                                    queryExpr);
+                }
+            }
+
+            this.m_comparer = null;
+            this.m_comparerIdx = -1;
+            if (comparerExpr != null)
+            {
+                // NOTE(review): a generic type argument on ExpressionSimplifier appears
+                // to have been lost here - confirm against the original source.
+                ExpressionSimplifier evaluator = new ExpressionSimplifier();
+                this.m_comparer = evaluator.Eval(comparerExpr);
+                this.m_comparerIdx = HpcLinqObjectStore.Put(this.m_comparer);
+            }
+
+            this.m_dynamicManager = DynamicManager.None;
+            if (this.OpName == "MergeSort" && StaticConfig.UseSMBAggregation)
+            {
+                DryadDynamicNode dnode = new DryadDynamicNode(DynamicManagerType.FullAggregator, this);
+                this.m_dynamicManager = new DynamicManager(DynamicManagerType.FullAggregator, dnode);
+            }
+
+            // Evaluated once; used both for the splice below and for m_isSplitting.
+            bool childIsDynamicHash = ((child is DryadHashPartitionNode) &&
+                                       ((DryadHashPartitionNode)child).IsDynamicDistributor);
+
+            DryadQueryNode child1 = child;
+            if (childIsDynamicHash)
+            {
+                // Bypass the hash-partition node: read from its input directly and
+                // let the dynamic manager own the partition node.
+                child1 = child.Children[0];
+                this.Children[0] = child1;
+                child1.UpdateParent(child, this);
+                this.m_dynamicManager = this.m_dynamicManager.CreateManager(DynamicManagerType.HashDistributor);
+                this.m_dynamicManager.InsertVertexNode(-1, child);
+            }
+
+            DataSetInfo childInfo = child1.OutputDataSetInfo;
+            PartitionInfo pinfo;
+            if (child1.ConOpType == ConnectionOpType.CrossProduct)
+            {
+                this.m_partitionCount = childInfo.partitionInfo.Count;
+                pinfo = childInfo.partitionInfo;
+            }
+            else
+            {
+                this.m_partitionCount = 1;
+                pinfo = DataSetInfo.OnePartition;
+            }
+
+            DistinctInfo dinfo = childInfo.distinctInfo;
+            OrderByInfo oinfo = DataSetInfo.NoOrderBy;
+            if (this.OpName == "MergeSort")
+            {
+                Type[] typeArgs = this.m_keySelectExpression.Type.GetGenericArguments();
+                oinfo = OrderByInfo.Create(this.m_keySelectExpression, this.m_comparer, this.m_isDescending, typeArgs[1]);
+            }
+            this.m_outputDataSetInfo = new DataSetInfo(pinfo, oinfo, dinfo);
+
+            this.m_isSplitting = (childIsDynamicHash ||
+                                  ((child is DryadRangePartitionNode) &&
+                                   ((DryadRangePartitionNode)child).IsDynamicDistributor));
+        }
+
+        internal Expression ComparerExpression { get { return this.m_comparerExpression; } }
+
+        internal override bool KeepInputPortOrder()
+        {
+            return this.m_keepPortOrder;
+        }
+
+        internal bool IsTemp
+        {
+            get { return this.m_isTemp; }
+        }
+
+        // True when the original child was a dynamic hash or range distributor.
+        internal bool IsSplitting
+        {
+            get { return this.m_isSplitting; }
+        }
+
+        internal override bool ContainsMerge
+        {
+            get { return true; }
+        }
+
+        // The output type is that of the first child.
+        internal override Type[] OutputTypes
+        {
+            get { return this.Children[0].OutputTypes; }
+        }
+
+        // Adds node to the FullAggregator dynamic node held at vertex position 0 of
+        // this node's dynamic manager, creating that dynamic node first if the
+        // manager does not have one yet. Throws DryadLinqException for a manager
+        // type other than None, HashDistributor or FullAggregator.
+        internal void AddAggregateNode(DryadQueryNode node)
+        {
+            switch (this.m_dynamicManager.ManagerType)
+            {
+                case DynamicManagerType.None:
+                {
+                    DryadDynamicNode dnode = new DryadDynamicNode(DynamicManagerType.FullAggregator, this);
+                    this.m_dynamicManager = new DynamicManager(DynamicManagerType.FullAggregator, dnode);
+                    break;
+                }
+                case DynamicManagerType.HashDistributor:
+                {
+                    DryadQueryNode firstVertex = this.m_dynamicManager.GetVertexNode(0);
+                    DryadDynamicNode dnode = firstVertex as DryadDynamicNode;
+                    if (dnode == null || dnode.DynamicType != DynamicManagerType.FullAggregator)
+                    {
+                        dnode = new DryadDynamicNode(DynamicManagerType.FullAggregator, this);
+                        this.m_dynamicManager.InsertVertexNode(0, dnode);
+                    }
+                    break;
+                }
+                case DynamicManagerType.FullAggregator:
+                {
+                    break;
+                }
+                default:
+                {
+                    //@@TODO: this should not be reachable. could change to Assert/InvalidOpEx
+                    throw new DryadLinqException(HpcLinqErrorCode.Internal,
+                                                 String.Format(SR.DynamicManagerType,
+                                                               this.m_dynamicManager.ManagerType));
+                }
+            }
+            ((DryadDynamicNode)this.m_dynamicManager.GetVertexNode(0)).AddNode(node);
+        }
+
+        // Adds "var source = Merge(reader)" or
+        // "var source = MergeSort(reader, keySelector, comparer, isDescending)" to
+        // vertexMethod, the call being made on HpcLinqCodeGen.DLVTypeExpr, and returns
+        // the name of the declared variable. The comparer argument is
+        // HpcLinqCodeGen.NullExpr when no comparer was given. writerNames is not used.
+        internal override string AddVertexCode(CodeMemberMethod vertexMethod,
+                                               string[] readerNames,
+                                               string[] writerNames)
+        {
+            Debug.Assert(this.OpName == "Merge" || this.OpName == "MergeSort");
+
+            CodeExpression mergeExpr = null;
+            if (this.OpName == "Merge")
+            {
+                mergeExpr = new CodeMethodInvokeExpression(HpcLinqCodeGen.DLVTypeExpr,
+                                                           this.OpName,
+                                                           new CodeVariableReferenceExpression(readerNames[0]));
+            }
+            else
+            {
+                CodeExpression comparerArg = HpcLinqCodeGen.NullExpr;
+                if (this.m_comparerExpression != null)
+                {
+                    CodeExpression getCall =
+                        new CodeMethodInvokeExpression(new CodeTypeReferenceExpression("HpcLinqObjectStore"),
+                                                       "Get",
+                                                       new CodePrimitiveExpression(this.m_comparerIdx));
+                    Type[] typeArgs = this.m_comparerExpression.Type.GetGenericArguments();
+                    Type comparerType = typeof(IComparer<>).MakeGenericType(typeArgs[0]);
+                    comparerArg = new CodeCastExpression(comparerType, getCall);
+                }
+
+                mergeExpr = new CodeMethodInvokeExpression(
+                                    HpcLinqCodeGen.DLVTypeExpr,
+                                    this.OpName,
+                                    new CodeVariableReferenceExpression(readerNames[0]),
+                                    this.QueryGen.CodeGen.MakeExpression(this.m_keySelectExpression),
+                                    comparerArg,
+                                    new CodePrimitiveExpression(this.m_isDescending));
+            }
+
+            CodeVariableDeclarationStatement
+                sourceDecl = this.QueryGen.CodeGen.MakeVarDeclStatement("var", "source", mergeExpr);
+            vertexMethod.Statements.Add(sourceDecl);
+            return sourceDecl.Name;
+        }
+
+        // Appends "[<NodeType> <child>]" to builder.
+        internal override void BuildString(StringBuilder builder)
+        {
+            builder.Append("[" + this.NodeType + " ");
+            this.Children[0].BuildString(builder);
+            builder.Append("]");
+        }
+    }
+
+    // Hash partition of a dataset
+    internal class DryadHashPartitionNode : DryadQueryNode
+    {
+        private LambdaExpression m_keySelectExpression;
+        private Expression m_keySelectExpression1;      // assigned in GetReferencedQueries
+        private LambdaExpression m_resultSelectExpression;
+        private int m_parCount;
+        private Expression m_comparerExpression;
+        private object m_comparer;      // evaluated comparer, or null
+        private int m_comparerIdx;      // HpcLinqObjectStore index of the comparer, else -1
+        private bool m_isDynamic;
+
+        // Static (non-dynamic) hash partition without a result selector.
+        internal DryadHashPartitionNode(LambdaExpression keySelectExpr,
+                                        Expression comparerExpr,
+                                        int count,
+                                        Expression queryExpr,
+                                        DryadQueryNode child)
+            : this(keySelectExpr, comparerExpr, count, false, queryExpr, child)
+        {
+        }
+
+        // Hash partition without a result selector.
+        internal DryadHashPartitionNode(LambdaExpression keySelectExpr,
+                                        Expression comparerExpr,
+                                        int count,
+                                        bool isDynamic,
+                                        Expression queryExpr,
+                                        DryadQueryNode child)
+            : this(keySelectExpr, null, comparerExpr, count, isDynamic, queryExpr, child)
+        {
+        }
+
+        // keySelectExpr    : optional key selector (may be null).
+        // resultSelectExpr : optional result selector (may be null).
+        // comparerExpr     : optional comparer; evaluated here and put in the object store.
+        // count            : the number of hash partitions.
+        // isDynamic        : reported through IsDynamicDistributor.
+        internal DryadHashPartitionNode(LambdaExpression keySelectExpr,
+                                        LambdaExpression resultSelectExpr,
+                                        Expression comparerExpr,
+                                        int count,
+                                        bool isDynamic,
+                                        Expression queryExpr,
+                                        DryadQueryNode child)
+            : base(QueryNodeType.HashPartition, child.QueryGen, queryExpr, child)
+        {
+            this.m_keySelectExpression = keySelectExpr;
+            this.m_resultSelectExpression = resultSelectExpr;
+            this.m_parCount = count;
+            this.m_isDynamic = isDynamic;
+            this.m_comparerExpression = comparerExpr;
+            this.m_opName = "HashPartition";
+
this.m_conOpType = ConnectionOpType.CrossProduct;
+
+            this.m_comparer = null;
+            this.m_comparerIdx = -1;
+            if (comparerExpr != null)
+            {
+                // NOTE(review): a generic type argument on ExpressionSimplifier appears
+                // to have been lost here - confirm against the original source.
+                ExpressionSimplifier evaluator = new ExpressionSimplifier();
+                this.m_comparer = evaluator.Eval(comparerExpr);
+                this.m_comparerIdx = HpcLinqObjectStore.Put(this.m_comparer);
+            }
+
+            // The node itself keeps the partition count of its child; the number of
+            // hash partitions (m_parCount) is recorded in the output partition info.
+            this.m_partitionCount = child.OutputDataSetInfo.partitionInfo.Count;
+
+            // Without a key selector the key type is the child's record type.
+            Type keyType = child.OutputTypes[0];
+            if (this.m_keySelectExpression != null)
+            {
+                keyType = this.m_keySelectExpression.Type.GetGenericArguments()[1];
+            }
+            DataSetInfo childInfo = child.OutputDataSetInfo;
+            PartitionInfo pInfo = PartitionInfo.CreateHash(this.m_keySelectExpression,
+                                                           this.m_parCount,
+                                                           this.m_comparer,
+                                                           keyType);
+            OrderByInfo oinfo = childInfo.orderByInfo;
+            DistinctInfo dinfo = childInfo.distinctInfo;
+            this.m_outputDataSetInfo = new DataSetInfo(pInfo, oinfo, dinfo);
+
+            this.m_dynamicManager = this.InferDynamicManager();
+        }
+
+        // The body type of the result selector when there is one, otherwise the
+        // output type of the first child.
+        internal override Type[] OutputTypes
+        {
+            get {
+                if (this.m_resultSelectExpression != null)
+                {
+                    return new Type[] { this.m_resultSelectExpression.Body.Type };
+                }
+                return this.Children[0].OutputTypes;
+            }
+        }
+
+        internal bool IsDynamicDistributor
+        {
+            get { return this.m_isDynamic; }
+        }
+
+        internal LambdaExpression KeySelectExpression
+        {
+            get { return m_keySelectExpression; }
+        }
+
+        // Adds a call to the "HashPartition" operator on HpcLinqCodeGen.DLVTypeExpr to
+        // vertexMethod and returns null (no variable is declared). The arguments are,
+        // in order:
+        //     reader, [keySelector, isExpensive,] comparer, [resultSelector,] writer
+        // where the bracketed arguments are present only when the corresponding
+        // expression is non-null. The comparer argument is HpcLinqCodeGen.NullExpr
+        // when no comparer was given.
+        // NOTE(review): the key selector emitted here is m_keySelectExpression1, which
+        // is only assigned in GetReferencedQueries; this presumably relies on that
+        // method having run before code generation - confirm.
+        internal override string AddVertexCode(CodeMemberMethod vertexMethod,
+                                               string[] readerNames,
+                                               string[] writerNames)
+        {
+            CodeExpression comparerArg = HpcLinqCodeGen.NullExpr;
+            if (this.m_comparerIdx != -1)
+            {
+                CodeExpression getComparerCall =
+                    new CodeMethodInvokeExpression(new CodeTypeReferenceExpression("HpcLinqObjectStore"),
+                                                   "Get",
+                                                   new CodePrimitiveExpression(this.m_comparerIdx));
+                Type[] typeArgs = this.m_comparerExpression.Type.GetGenericArguments();
+                Type comparerType = typeof(IEqualityComparer<>).MakeGenericType(typeArgs[0]);
+                comparerArg = new CodeCastExpression(comparerType, getComparerCall);
+            }
+
+            CodeExpression distributeExpr;
+            if (this.m_keySelectExpression == null)
+            {
+                if (this.m_resultSelectExpression == null)
+                {
+                    distributeExpr = new CodeMethodInvokeExpression(
+                                             HpcLinqCodeGen.DLVTypeExpr,
+                                             this.OpName,
+                                             new CodeVariableReferenceExpression(readerNames[0]),
+                                             comparerArg,
+                                             new CodeVariableReferenceExpression(writerNames[0]));
+                }
+                else
+                {
+                    distributeExpr = new CodeMethodInvokeExpression(
+                                             HpcLinqCodeGen.DLVTypeExpr,
+                                             this.OpName,
+                                             new CodeVariableReferenceExpression(readerNames[0]),
+                                             comparerArg,
+                                             this.QueryGen.CodeGen.MakeExpression(this.m_resultSelectExpression),
+                                             new CodeVariableReferenceExpression(writerNames[0]));
+                }
+            }
+            else
+            {
+                ExpressionInfo einfo = new ExpressionInfo(this.m_keySelectExpression1);
+                if (this.m_resultSelectExpression == null)
+                {
+                    distributeExpr = new CodeMethodInvokeExpression(
+                                             HpcLinqCodeGen.DLVTypeExpr,
+                                             this.OpName,
+                                             new CodeVariableReferenceExpression(readerNames[0]),
+                                             this.QueryGen.CodeGen.MakeExpression(this.m_keySelectExpression1),
+                                             new CodePrimitiveExpression(einfo.IsExpensive),
+                                             comparerArg,
+                                             new CodeVariableReferenceExpression(writerNames[0]));
+                }
+                else
+                {
+                    distributeExpr = new CodeMethodInvokeExpression(
+                                             HpcLinqCodeGen.DLVTypeExpr,
+                                             this.OpName,
+                                             new CodeVariableReferenceExpression(readerNames[0]),
+                                             this.QueryGen.CodeGen.MakeExpression(this.m_keySelectExpression1),
+                                             new CodePrimitiveExpression(einfo.IsExpensive),
+                                             comparerArg,
+                                             this.QueryGen.CodeGen.MakeExpression(this.m_resultSelectExpression),
+                                             new CodeVariableReferenceExpression(writerNames[0]));
+                }
+            }
+            vertexMethod.Statements.Add(distributeExpr);
+            return null;
+        }
+
+        // Stores in m_keySelectExpression1 the result of running subst over the key
+        // selector. Does nothing when there is no key selector.
+        internal override void GetReferencedQueries(ReferencedQuerySubst subst)
+        {
+            if (this.m_keySelectExpression != null)
+            {
+                this.m_keySelectExpression1 = subst.Visit(this.m_keySelectExpression);
+            }
+        }
+
+        internal int NumberOfPartitions { get { return this.m_parCount; } }
+
+        // Appends "[HashPartition <child>,<keySelector>,<count>[,<comparer>]]" to
+        // builder.
+        // NOTE(review): unlike the constructor and AddVertexCode, this does not guard
+        // against a null key selector before calling ToCSharpString - confirm that
+        // ToCSharpString accepts null.
+        internal override void BuildString(StringBuilder builder)
+        {
+            builder.Append("[HashPartition ");
+            this.Children[0].BuildString(builder);
+            builder.Append(",");
+            builder.Append(HpcLinqExpression.ToCSharpString(this.m_keySelectExpression,
+                                                            this.QueryGen.CodeGen.AnonymousTypeToName));
+            builder.Append(",");
+            builder.Append(Convert.ToString(this.m_parCount));
+            if (this.m_comparerIdx != -1)
+            {
+                builder.Append(",");
+                builder.Append(HpcLinqExpression.ToCSharpString(this.m_comparerExpression,
+                                                                this.QueryGen.CodeGen.AnonymousTypeToName));
+            }
+            builder.Append("]");
+        }
+    }
+
+    // Range partition of a dataset
+    internal class DryadRangePartitionNode : DryadQueryNode
+    {
+        private LambdaExpression m_keySelectExpression;
+        private Expression m_keySelectExpression1;
+        private LambdaExpression m_resultSelectExpression;
+        private Expression m_keysExpression;
+        private Expression m_comparerExpression;
+        private Expression m_isDescendingExpression;
+        private Expression m_countExpression;
+        private object m_keys;
+        private int m_keysIdx;
+        private object m_comparer;
+        private int m_comparerIdx;
+        private bool?
m_isDescending; + private int m_count; + + //Creates a "Range distribution" Node + internal DryadRangePartitionNode(LambdaExpression keySelectExpr, + LambdaExpression resultSelectExpr, + Expression keysExpr, + Expression comparerExpr, + Expression isDescendingExpr, + Expression countExpr, + Expression queryExpr, + params DryadQueryNode[] children) + : base(QueryNodeType.RangePartition, children[0].QueryGen, queryExpr, children) + { + this.m_keySelectExpression = keySelectExpr; + this.m_resultSelectExpression = resultSelectExpr; + this.m_keysExpression = keysExpr; + this.m_comparerExpression = comparerExpr; + this.m_isDescendingExpression = isDescendingExpr; + this.m_countExpression = countExpr; + this.m_opName = "RangePartition"; + this.m_conOpType = ConnectionOpType.CrossProduct; + + this.m_isDescending = null; + if (this.m_isDescendingExpression != null) + { + ExpressionSimplifier bevaluator = new ExpressionSimplifier(); + this.m_isDescending = bevaluator.Eval(isDescendingExpr); + } + + ExpressionSimplifier evaluator = new ExpressionSimplifier(); + this.m_keys = null; + this.m_keysIdx = -1; + if (keysExpr != null) + { + this.m_keys = evaluator.Eval(keysExpr); + this.m_keysIdx = HpcLinqObjectStore.Put(m_keys); + } + + this.m_comparer = null; + this.m_comparerIdx = -1; + if (comparerExpr != null) + { + this.m_comparer = evaluator.Eval(comparerExpr); + this.m_comparerIdx = HpcLinqObjectStore.Put(m_comparer); + } + + this.m_count = 1; + if (countExpr != null) + { + ExpressionSimplifier ievaluator = new ExpressionSimplifier(); + this.m_count = ievaluator.Eval(countExpr); + } + + this.m_partitionCount = this.Children[0].OutputDataSetInfo.partitionInfo.Count; + + DataSetInfo childInfo = this.Children[0].OutputDataSetInfo; + Type keyType = this.m_keySelectExpression.Type.GetGenericArguments()[1]; + PartitionInfo pInfo = PartitionInfo.CreateRange(this.m_keySelectExpression, + this.m_keys, + this.m_comparer, + this.m_isDescending, + this.m_count, + keyType); + OrderByInfo 
oinfo = childInfo.orderByInfo; + DistinctInfo dinfo = childInfo.distinctInfo; + this.m_outputDataSetInfo = new DataSetInfo(pInfo, oinfo, dinfo); + this.m_dynamicManager = this.InferDynamicManager(); + } + + internal override Type[] OutputTypes + { + get { + if (this.m_resultSelectExpression != null) + { + return new Type[] { this.m_resultSelectExpression.Body.Type }; + } + return this.Children[0].OutputTypes; + } + } + + internal Expression KeysExpression + { + get { return this.m_keysExpression; } + } + internal Expression ComparerExpression { get { return this.m_comparerExpression; } } + internal Expression CountExpression { get { return this.m_countExpression; } } + + internal bool IsDynamicDistributor + { + get { return this.m_countExpression == null; } + } + + internal override string AddVertexCode(CodeMemberMethod vertexMethod, + string[] readerNames, + string[] writerNames) + { + CodeExpression rangeKeys; + if (this.m_keys == null) + { + rangeKeys = new CodeVariableReferenceExpression(readerNames[1]); + } + else + { + rangeKeys = new CodeMethodInvokeExpression(new CodeTypeReferenceExpression("HpcLinqObjectStore"), + "Get", + new CodePrimitiveExpression(this.m_keysIdx)); + rangeKeys = new CodeCastExpression(this.m_keysExpression.Type, rangeKeys); + } + CodeExpression sinkExpr = new CodeVariableReferenceExpression(writerNames[0]); + CodeExpression comparerArg = HpcLinqCodeGen.NullExpr; + if (this.m_comparerIdx != -1) + { + CodeExpression getComparerCall = + new CodeMethodInvokeExpression(new CodeTypeReferenceExpression("HpcLinqObjectStore"), + "Get", + new CodePrimitiveExpression(this.m_comparerIdx)); + Type[] typeArgs = this.m_comparerExpression.Type.GetGenericArguments(); + Type comparerType = typeof(IComparer<>).MakeGenericType(typeArgs[0]); + comparerArg = new CodeCastExpression(comparerType, getComparerCall); + } + CodeExpression isDescendingArg = new CodePrimitiveExpression(this.OutputPartition.IsDescending); + CodeExpression distributeExpr; + if 
(this.m_keySelectExpression == null) + { + if (this.m_resultSelectExpression == null) + { + distributeExpr = new CodeMethodInvokeExpression( + HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + rangeKeys, + comparerArg, + isDescendingArg, + sinkExpr); + } + else + { + distributeExpr = new CodeMethodInvokeExpression( + HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + rangeKeys, + comparerArg, + isDescendingArg, + this.QueryGen.CodeGen.MakeExpression(this.m_resultSelectExpression), + sinkExpr); + } + } + else + { + ExpressionInfo einfo = new ExpressionInfo(this.m_keySelectExpression1); + if (this.m_resultSelectExpression == null) + { + distributeExpr = new CodeMethodInvokeExpression( + HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + this.QueryGen.CodeGen.MakeExpression(this.m_keySelectExpression1), + new CodePrimitiveExpression(einfo.IsExpensive), + rangeKeys, + comparerArg, + isDescendingArg, + sinkExpr); + } + else + { + distributeExpr = new CodeMethodInvokeExpression( + HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + this.QueryGen.CodeGen.MakeExpression(this.m_keySelectExpression1), + new CodePrimitiveExpression(einfo.IsExpensive), + rangeKeys, + comparerArg, + isDescendingArg, + this.QueryGen.CodeGen.MakeExpression(this.m_resultSelectExpression), + sinkExpr); + } + } + vertexMethod.Statements.Add(distributeExpr); + return null; + } + + internal override void GetReferencedQueries(ReferencedQuerySubst subst) + { + if (this.m_keySelectExpression != null) + { + this.m_keySelectExpression1 = subst.Visit(this.m_keySelectExpression); + } + } + + internal override void BuildString(StringBuilder builder) + { + builder.Append("[RangePartition "); + this.Children[0].BuildString(builder); + builder.Append(","); + 
builder.Append(HpcLinqExpression.ToCSharpString(this.m_keySelectExpression, + this.QueryGen.CodeGen.AnonymousTypeToName)); + builder.Append(","); + builder.Append(HpcLinqExpression.ToCSharpString(this.m_keysExpression, + this.QueryGen.CodeGen.AnonymousTypeToName)); + if (this.m_comparerIdx != -1) + { + builder.Append(","); + builder.Append(HpcLinqExpression.ToCSharpString(this.m_comparerExpression, + this.QueryGen.CodeGen.AnonymousTypeToName)); + } + builder.Append("]"); + } + } + + // A super node encapsulates a subtree of the query tree into a single + // vertex. It could have arbitrary number of inputs and outputs. + internal class DryadSuperNode : DryadQueryNode + { + private DryadQueryNode m_rootNode; + private bool m_isStateful; + private bool m_containsMerge; + + internal DryadSuperNode(DryadQueryNode root) + : base(QueryNodeType.Super, root.QueryGen, root.QueryExpression) + { + this.ChannelType = root.ChannelType; + this.m_conOpType = root.ConOpType; + this.m_rootNode = root; + this.IsForked = root.IsForked; + root.SuperNode = this; + this.m_isStateful = root.IsStateful; + this.m_containsMerge = false; + foreach (DryadQueryNode child in root.Children) + { + if (!(child is DryadInputNode)) + { + this.m_isStateful = this.m_isStateful || child.IsStateful; + this.m_containsMerge = this.m_containsMerge || child.ContainsMerge; + } + } + + this.Parents.AddRange(root.Parents); + this.m_partitionCount = root.PartitionCount; + this.m_outputDataSetInfo = root.OutputDataSetInfo; + this.m_dynamicManager = root.Children[0].DynamicManager; + } + + internal DryadQueryNode RootNode + { + get { return this.m_rootNode; } + } + + internal void SwitchTo(DryadSuperNode node) + { + this.SwitchTo(this.m_rootNode, node); + } + + private void SwitchTo(DryadQueryNode curNode, DryadSuperNode node) + { + if (curNode.SuperNode == this) + { + curNode.SuperNode = node; + foreach (DryadQueryNode child in curNode.Children) + { + this.SwitchTo(child, node); + } + } + } + + internal override 
bool ContainsMerge + { + get { return this.m_containsMerge; } + } + + internal override bool KeepInputPortOrder() + { + return this.KeepInputPortOrder(this.m_rootNode); + } + + private bool KeepInputPortOrder(DryadQueryNode curNode) + { + if (curNode.SuperNode == this) + { + if (curNode.KeepInputPortOrder()) + { + return true; + } + foreach (DryadQueryNode child in curNode.Children) + { + if (this.KeepInputPortOrder(child)) + { + return true; + } + } + } + return false; + } + + internal override bool IsStateful + { + get { return this.m_isStateful; } + } + + internal override Type[] OutputTypes + { + get { return this.m_rootNode.OutputTypes; } + } + + internal bool Contains(DryadQueryNode node) + { + return (node.SuperNode == this); + } + + internal override void CreateCodeAndMappingsForIntermediateTypes() + { + this.CreateCodeAndMappingsForIntermediateTypes(this.m_rootNode); + } + + private void CreateCodeAndMappingsForIntermediateTypes(DryadQueryNode curNode) + { + if (curNode.SuperNode == this) + { + foreach (DryadQueryNode child in curNode.Children) + { + this.CreateCodeAndMappingsForIntermediateTypes(child); + } + } + } + + internal override string AddVertexCode(CodeMemberMethod vertexMethod, + string[] readerNames, + string[] writerNames) + { + Pipeline pipeline = new Pipeline(vertexMethod, this.QueryGen.CodeGen, writerNames); + this.MakeSuperBody(vertexMethod, this.m_rootNode, writerNames, pipeline); + return this.QueryGen.CodeGen.AddVertexCode(vertexMethod, pipeline); + } + + private void MakeSuperBody(CodeMemberMethod vertexMethod, + DryadQueryNode curNode, + string[] writerNames, + Pipeline pipeline) + { + bool isHomomorphic = curNode.IsHomomorphic; + DryadQueryNode[] curChildren = curNode.Children; + string[] curSources = new string[curChildren.Length]; + + for (int i = 0; i < curChildren.Length; i++) + { + DryadQueryNode child = curChildren[i]; + if (this.Contains(child)) + { + this.MakeSuperBody(vertexMethod, child, writerNames, pipeline); + if 
(!isHomomorphic) + { + curSources[i] = this.QueryGen.CodeGen.AddVertexCode(vertexMethod, pipeline); + } + } + else + { + Type inputType = child.OutputTypes[0]; + string factoryName = this.QueryGen.CodeGen.GetStaticFactoryName(inputType); + CodeVariableDeclarationStatement + readerDecl = this.QueryGen.CodeGen.MakeDryadReaderDecl(inputType, factoryName); + vertexMethod.Statements.Add(readerDecl); + curSources[i] = readerDecl.Name; + pipeline.Reset(new string[] { readerDecl.Name }); + } + } + + if (!isHomomorphic) + { + pipeline.Reset(curSources); + } + pipeline.Add(curNode); + } + + internal override void GetReferencedQueries(ReferencedQuerySubst subst) + { + this.GetReferencedQueries(this.m_rootNode, subst); + } + + private void GetReferencedQueries(DryadQueryNode curNode, ReferencedQuerySubst subst) + { + curNode.GetReferencedQueries(subst); + DryadQueryNode[] curChildren = curNode.Children; + for (int i = 0; i < curChildren.Length; i++) + { + DryadQueryNode child = curChildren[i]; + if (this.Contains(child)) + { + this.GetReferencedQueries(child, subst); + } + } + } + + internal override void BuildString(StringBuilder builder) + { + builder.Append("[" + this.NodeType + " "); + for (int i = 0; i < this.Children.Length; i++) + { + if (i != 0) + { + builder.Append(", "); + } + this.Children[i].BuildString(builder); + } + builder.Append("]"); + } + } + + internal class DryadApplyNode : DryadQueryNode + { + private LambdaExpression m_procLambda; + private Expression m_procLambda1; + private bool m_isMultiSources; + + internal DryadApplyNode(LambdaExpression procLambda, + bool isMultiSources, + Expression queryExpr, + params DryadQueryNode[] children) + : base(QueryNodeType.Apply, children[0].QueryGen, queryExpr, children) + { + this.m_procLambda = procLambda; + this.m_isMultiSources = isMultiSources; + this.m_opName = "Apply"; + + if (StaticConfig.UseMemoryFIFO && children.Length > 1) + { + bool isStateful = this.IsStateful; + foreach (DryadQueryNode child in 
this.Children) + { + if (!(child is DryadInputNode) && + !child.IsForked && + !(isStateful && child.IsStateful) && + child.PartitionCount > 1) + { + child.ChannelType = ChannelType.MemoryFIFO; + } + isStateful = isStateful || child.IsStateful; + } + } + + this.m_partitionCount = this.Children[0].OutputPartition.Count; + this.m_outputDataSetInfo = this.ComputeOutputDataSetInfo(); + + this.m_dynamicManager = this.InferDynamicManager(); + } + + internal DryadApplyNode(LambdaExpression procLambda, + Expression queryExpr, + params DryadQueryNode[] children) + : this(procLambda, false, queryExpr, children) + { + } + + // This operator is not stateful iff + // 1. procLambda is of form: (args) => Method(args) and + // 2. Method has a DryadStatefulAttribute specifying it not stateful + internal override bool IsStateful + { + get { + ResourceAttribute attrib = AttributeSystem.GetResourceAttrib(this.m_procLambda); + return (attrib == null || attrib.IsStateful); + } + } + + internal override Type[] OutputTypes + { + get { + Type[] procArgTypes = m_procLambda.Type.GetGenericArguments(); + Int32 idx = (this.IsWriteToStream) ? 
0 : (procArgTypes.Length - 1); + Type procReturnType = procArgTypes[idx].GetGenericArguments()[0]; + return new Type[] { procReturnType }; + } + } + + internal bool IsReadFromStream + { + get { + Type[] procArgTypes = m_procLambda.Type.GetGenericArguments(); + return typeof(Stream).IsAssignableFrom(procArgTypes[0]); + } + } + + internal bool IsWriteToStream + { + get { + Type[] procArgTypes = m_procLambda.Type.GetGenericArguments(); + return typeof(Stream).IsAssignableFrom(procArgTypes[1]); + } + } + + private DataSetInfo ComputeOutputDataSetInfo() + { + PartitionInfo pinfo = new RandomPartition(this.m_partitionCount); + OrderByInfo oinfo = DataSetInfo.NoOrderBy; + DistinctInfo dinfo = DataSetInfo.NoDistinct; + return new DataSetInfo(pinfo, oinfo, dinfo); + } + + internal override string AddVertexCode(CodeMemberMethod vertexMethod, + string[] readerNames, + string[] writerNames) + { + CodeExpression procArg = this.QueryGen.CodeGen.MakeExpression(this.m_procLambda1); + CodeExpression applyExpr = null; + if (this.IsWriteToStream) + { + applyExpr = new CodeMethodInvokeExpression(HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + procArg, + new CodeVariableReferenceExpression(writerNames[0])); + vertexMethod.Statements.Add(new CodeExpressionStatement(applyExpr)); + return null; + } + if (this.m_isMultiSources) + { + // Array of the sources. 
+ CodeExpression[] sourceExprs = new CodeExpression[readerNames.Length]; + for (int i = 0; i < readerNames.Length; ++i) + { + sourceExprs[i] = new CodeVariableReferenceExpression(readerNames[i]); + } + var arrayExpr = new CodeArrayCreateExpression(typeof(IEnumerable<>).MakeGenericType(OutputTypes[0]), + sourceExprs); + CodeVariableDeclarationStatement + arrayDecl = this.QueryGen.CodeGen.MakeVarDeclStatement("var", "sourceArray", arrayExpr); + vertexMethod.Statements.Add(arrayDecl); + applyExpr = new CodeMethodInvokeExpression(HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(arrayDecl.Name), + procArg); + + } + else if (readerNames.Length == 1) + { + applyExpr = new CodeMethodInvokeExpression(HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + procArg); + } + else + { + applyExpr = new CodeMethodInvokeExpression(HpcLinqCodeGen.DLVTypeExpr, + this.OpName, + new CodeVariableReferenceExpression(readerNames[0]), + new CodeVariableReferenceExpression(readerNames[1]), + procArg); + } + + CodeVariableDeclarationStatement + sourceDecl = this.QueryGen.CodeGen.MakeVarDeclStatement("var", "source", applyExpr); + vertexMethod.Statements.Add(sourceDecl); + return sourceDecl.Name; + } + + internal override void GetReferencedQueries(ReferencedQuerySubst subst) + { + this.m_procLambda1 = subst.Visit(this.m_procLambda); + } + + internal Expression LambdaExpression { get { return this.m_procLambda; } } + + internal override void BuildString(StringBuilder builder) + { + builder.Append("[" + this.NodeType + " "); + foreach (DryadQueryNode child in this.Children) + { + child.BuildString(builder); + builder.Append(" "); + } + builder.Append(HpcLinqExpression.ToCSharpString(this.m_procLambda, + this.QueryGen.CodeGen.AnonymousTypeToName)); + builder.Append("]"); + } + } + + internal class DryadForkNode : DryadQueryNode + { + private LambdaExpression m_forkLambda; + private Expression m_keysExpression; + private 
object m_keys; + private int m_keysIdx; + private Type[] m_outputTypes; + + internal DryadForkNode(LambdaExpression fork, + Expression keysExpr, + Expression queryExpr, + DryadQueryNode child) + : base(QueryNodeType.Fork, child.QueryGen, queryExpr, child) + { + this.m_forkLambda = fork; + this.m_keysExpression = keysExpr; + this.m_opName = "Fork"; + + ExpressionSimplifier evaluator = new ExpressionSimplifier(); + this.m_keys = null; + this.m_keysIdx = -1; + if (keysExpr != null) + { + this.m_keys = evaluator.Eval(keysExpr); + this.m_keysIdx = HpcLinqObjectStore.Put(m_keys); + } + + this.m_partitionCount = child.OutputPartition.Count; + PartitionInfo pinfo = new RandomPartition(child.OutputDataSetInfo.partitionInfo.Count); + this.m_outputDataSetInfo = new DataSetInfo(pinfo, DataSetInfo.NoOrderBy, DataSetInfo.NoDistinct); + + this.m_dynamicManager = this.InferDynamicManager(); + + // Finally, create all the children of this: + if (keysExpr == null) + { + Type forkTupleType = fork.Type.GetGenericArguments()[1]; + if (forkTupleType.GetGenericTypeDefinition() == typeof(IEnumerable<>)) + { + forkTupleType = forkTupleType.GetGenericArguments()[0]; + } + Type[] queryTypeArgs = forkTupleType.GetGenericArguments(); + this.m_outputTypes = new Type[queryTypeArgs.Length]; + for (int i = 0; i < queryTypeArgs.Length; i++) + { + this.m_outputTypes[i] = queryTypeArgs[i]; + DryadQueryNode parentNode = new DryadTeeNode(queryTypeArgs[i], true, queryExpr, this); + } + } + else + { + int forkCnt = ((Array)m_keys).Length; + Type forkType = fork.Type.GetGenericArguments()[0]; + this.m_outputTypes = new Type[forkCnt]; + for (int i = 0; i < forkCnt; i++) + { + this.m_outputTypes[i] = forkType; + DryadQueryNode parentNode = new DryadTeeNode(forkType, true, queryExpr, this); + } + } + } + + internal override bool IsStateful + { + get { + if (this.KeysExpression != null) return false; + if (m_forkLambda.Type.GetGenericArguments()[1].GetGenericTypeDefinition() == typeof(ForkTuple<,>)) + { + 
return false; + } + ResourceAttribute attrib = AttributeSystem.GetResourceAttrib(this.m_forkLambda); + return (attrib == null || attrib.IsStateful); + } + } + + internal override Type[] OutputTypes + { + get { return this.m_outputTypes; } + } + + internal Expression KeysExpression + { + get { return this.m_keysExpression; } + } + + internal Expression ForkLambda + { + get { return this.m_forkLambda; } + } + + internal override string AddVertexCode(CodeMemberMethod vertexMethod, + string[] readerNames, + string[] writerNames) + { + CodeExpression[] args; + bool orderPreserving = (m_queryGen.Context.Configuration.SelectiveOrderPreservation || + this.OutputDataSetInfo.orderByInfo.IsOrdered); + if (this.KeysExpression != null) + { + args = new CodeExpression[readerNames.Length + writerNames.Length + 3]; + args[0] = new CodeVariableReferenceExpression(readerNames[0]); + args[1] = this.QueryGen.CodeGen.MakeExpression(this.m_forkLambda); + CodeExpression rangeKeys = new CodeMethodInvokeExpression( + new CodeTypeReferenceExpression("HpcLinqObjectStore"), + "Get", + new CodePrimitiveExpression(this.m_keysIdx)); + args[2] = new CodeCastExpression(this.m_keysExpression.Type, rangeKeys); + args[3] = new CodePrimitiveExpression(orderPreserving); + for (int i = 0; i < writerNames.Length; i++) + { + args[i+4] = new CodeVariableReferenceExpression(writerNames[i]); + } + } + else + { + args = new CodeExpression[readerNames.Length + writerNames.Length + 2]; + args[0] = new CodeVariableReferenceExpression(readerNames[0]); + args[1] = this.QueryGen.CodeGen.MakeExpression(this.m_forkLambda); + args[2] = new CodePrimitiveExpression(orderPreserving); + for (int i = 0; i < writerNames.Length; i++) + { + args[i+3] = new CodeVariableReferenceExpression(writerNames[i]); + } + } + + CodeExpression forkExpr = new CodeMethodInvokeExpression(HpcLinqCodeGen.DLVTypeExpr, this.OpName, args); + vertexMethod.Statements.Add(forkExpr); + return null; + } + + internal override void 
BuildString(StringBuilder builder) + { + builder.Append("[" + this.NodeType + " "); + this.Children[0].BuildString(builder); + builder.Append(", "); + builder.Append(HpcLinqExpression.ToCSharpString(this.m_forkLambda, + this.QueryGen.CodeGen.AnonymousTypeToName)); + builder.Append("]"); + } + } + + internal class DryadTeeNode : DryadQueryNode + { + private Type m_outputType; + + internal DryadTeeNode(Type outputType, bool isForked, Expression queryExpr, DryadQueryNode child) + : base(QueryNodeType.Tee, child.QueryGen, queryExpr, child) + { + this.m_outputType = outputType; + this.m_opName = "Tee"; + this.IsForked = isForked; + + this.m_partitionCount = child.OutputPartition.Count; + PartitionInfo pinfo = new RandomPartition(child.OutputDataSetInfo.partitionInfo.Count); + this.m_outputDataSetInfo = new DataSetInfo(pinfo, DataSetInfo.NoOrderBy, DataSetInfo.NoDistinct); + this.m_dynamicManager = this.InferDynamicManager(); + } + + internal override Type[] InputTypes + { + get { return new Type[] { this.m_outputType }; } + } + + internal override Type[] OutputTypes + { + get { return new Type[] { this.m_outputType }; } + } + + internal override string AddVertexCode(CodeMemberMethod vertexMethod, + string[] readerNames, + string[] writerNames) + { + throw new InvalidOperationException(); + } + + internal override void BuildString(StringBuilder builder) + { + builder.Append("[" + this.NodeType + " "); + this.Children[0].BuildString(builder); + builder.Append("]"); + } + } + + internal class DryadDynamicNode : DryadQueryNode + { + private DynamicManagerType m_dmType; + private List m_realNodes; + + internal DryadDynamicNode(DynamicManagerType dmType, DryadQueryNode node) + : base(QueryNodeType.Dynamic, node.QueryGen, node.QueryExpression) + { + switch (dmType) + { + case DynamicManagerType.FullAggregator: + case DynamicManagerType.Broadcast: + { + this.m_dmType = dmType; + this.m_realNodes = new List(1); + this.m_realNodes.Add(node); + break; + } + default: + { + throw 
new DryadLinqException(HpcLinqErrorCode.Internal, + SR.IllegalDynamicManagerType); + } + } + } + + internal override Type[] InputTypes + { + get { return this.m_realNodes[0].InputTypes; } + } + + internal override Type[] OutputTypes + { + get { return this.m_realNodes[this.m_realNodes.Count-1].OutputTypes; } + } + + internal DynamicManagerType DynamicType + { + get { return this.m_dmType; } + } + + internal List RealNodes + { + get { return this.m_realNodes; } + } + + internal DryadQueryNode GetRealNode(int index) + { + return this.m_realNodes[index]; + } + + internal void AddNode(DryadQueryNode node) + { + this.m_realNodes.Add(node); + } + + internal override string AddVertexCode(CodeMemberMethod vertexMethod, + string[] readerNames, + string[] writerNames) + { + string source = readerNames[0]; + foreach (DryadQueryNode node in this.m_realNodes) + { + source = node.AddVertexCode(vertexMethod, new string[] { source }, null); + } + return source; + } + + internal override void BuildString(StringBuilder builder) + { + builder.Append("[" + this.NodeType + " ]"); + } + } + + internal class DryadDummyNode : DryadQueryNode + { + private Type m_outputType; // type of output channel + + /// + /// Create a dummy node with a specific code generator. + /// + /// Query generator to instantiate. + /// Type of the single output. 
+ /// The upstream nodes + internal DryadDummyNode(HpcLinqQueryGen queryGen, + Type outputType, + params DryadQueryNode[] children) + : base(QueryNodeType.Dummy, queryGen, null, children) + { + this.m_outputDataSetInfo = new DataSetInfo(); + this.DynamicManager = DynamicManager.None; + this.m_outputType = outputType; + } + + internal override Type[] OutputTypes + { + get { return new Type[] { this.m_outputType }; } + } + + internal override string AddVertexCode(CodeMemberMethod vertexMethod, + string[] readerNames, + string[] writerNames) + { + throw new InvalidOperationException(); + } + + internal override void BuildString(StringBuilder builder) + { + builder.Append("Dummy"); + } + } + + internal class Pipeline + { + private CodeMemberMethod m_vertexMethod; + private HpcLinqCodeGen m_codeGen; + private string[] m_readerNames; + private string[] m_writerNames; + private List m_nodes; + + internal Pipeline(CodeMemberMethod vertexMethod, HpcLinqCodeGen codeGen, string[] writerNames) + { + this.m_vertexMethod = vertexMethod; + this.m_codeGen = codeGen; + this.m_readerNames = null; + this.m_writerNames = writerNames; + this.m_nodes = new List(); + } + + internal string[] ReaderNames + { + get { return this.m_readerNames; } + } + + internal string[] WriterNames + { + get { return this.m_writerNames; } + } + + internal Type InputType + { + get { return this.m_nodes[0].InputTypes[0]; } + } + + internal Type OutputType + { + get { return this.m_nodes[this.Length - 1].OutputTypes[0]; } + } + + internal int Length + { + get { return this.m_nodes.Count; } + } + + internal DryadQueryNode this[int index] + { + get { return this.m_nodes[index]; } + } + + internal void Add(DryadQueryNode node) + { + this.m_nodes.Add(node); + } + + internal CodeExpression BuildExpression(int idx, + Expression inputExpr, + params ParameterExpression[] paramList) + { + Expression bodyExpr = inputExpr; + for (int i = idx; i < this.Length; i++) + { + bodyExpr = 
this.m_nodes[i].RebuildQueryExpression(bodyExpr); + } + Type type = typeof(Func<,>).MakeGenericType(inputExpr.Type, bodyExpr.Type); + return this.m_codeGen.MakeExpression(Expression.Lambda(type, bodyExpr, paramList)); + } + + internal void Reset(string[] readerNames) + { + this.m_readerNames = readerNames; + this.m_nodes.Clear(); + } + } +} diff --git a/LinqToDryad/DryadRecordReader.cs b/LinqToDryad/DryadRecordReader.cs new file mode 100644 index 0000000..2bb838b --- /dev/null +++ b/LinqToDryad/DryadRecordReader.cs @@ -0,0 +1,637 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. +// +using System; +using System.Collections; +using System.Collections.Generic; +using System.Text; +using System.Reflection; +using System.Threading; +using System.Data.SqlTypes; +using System.Diagnostics; +using Microsoft.Research.DryadLinq; + +namespace Microsoft.Research.DryadLinq.Internal +{ + // This class defines the abstraction of reading Dryad records. + public abstract class HpcRecordReader : IEnumerable + { + private const int bufferMaxSize = 1024; //@@TODO: it would be good to choose a buffersize based on record-sizes.. as done in HpcRecordWriter. 
+ private T[] m_buffer1; + private T[] m_buffer2; + private int m_count1; + private int m_count2; + private int m_index1; + private Thread m_worker; + private Exception m_workerException; + private long m_numRecordsRead; + protected bool m_isUsed; + + public HpcRecordReader() + { + this.m_isUsed = false; + this.m_numRecordsRead = 0; + this.FirstReadTime = DateTime.Now; + this.LastReadTime = this.FirstReadTime; + } + + protected abstract bool ReadRecord(ref T rec); // simple synchronous non-buffering read of a record + public abstract Int64 GetTotalLength(); + public abstract string GetChannelURI(); + + public virtual void Close() + { + if (this.m_worker != null) + { + lock (this) + { + this.m_count2 = -2; + Monitor.Pulse(this); + } + } + } + + /// + /// Time when first record was read from the channel. + /// + public DateTime FirstReadTime { get; protected set; } + /// + /// Time when last record was read from the channel. + /// + public DateTime LastReadTime { get; protected set; } + public long RecordsRead { get { return m_numRecordsRead; } } + + //Note: direct use of this method (rather than going through the enumerator) + // will miss the checks we do to prevent repeat enumeration. Either add those manually + // or go through the enumerator. + public bool ReadRecordSync(ref T rec) + { + bool isRead = this.ReadRecord(ref rec); + if (isRead) + { + this.m_numRecordsRead++; + } + else + { + this.LastReadTime = DateTime.Now; + } + return isRead; + } + + internal void StartWorker() + { + if (this.m_worker == null) + { + this.m_buffer1 = new T[bufferMaxSize]; + this.m_buffer2 = new T[bufferMaxSize]; + this.m_count1 = 0; + this.m_index1 = this.m_buffer1.Length; + this.m_count2 = -1; + this.m_worker = new Thread(this.FillBuffer); + this.m_worker.Start(); + } + } + + private void FillBuffer() + { + DryadLinqLog.Add("HpcRecordReader reader thread started. 
ThreadId=" + Thread.CurrentThread.ManagedThreadId); + lock (this) + { + while (true) + { + try + { + while (this.m_count2 > 0) + { + Monitor.Wait(this); + } + if (this.m_count2 == -2) return; + this.m_count2 = 0; + while (this.m_count2 < this.m_buffer2.Length && + this.ReadRecord(ref this.m_buffer2[this.m_count2])) + { + this.m_count2++; + } + Monitor.Pulse(this); + if (this.m_count2 < this.m_buffer2.Length) return; + } + catch (Exception e) + { + this.m_workerException = e; + Monitor.Pulse(this); + return; + } + } + } + } + + internal bool ReadRecordAsync(ref T rec) + { + if (this.m_index1 < this.m_count1) + { + rec = this.m_buffer1[this.m_index1++]; + this.m_numRecordsRead++; + return true; + } + if (this.m_index1 < this.m_buffer1.Length) + { + this.LastReadTime = DateTime.Now; + return false; + } + lock (this) + { + while (this.m_count2 == -1) + { + Monitor.Wait(this); + } + if (this.m_count2 == -2) return false; + if (this.m_workerException != null) + { + throw this.m_workerException; + } + T[] temp = this.m_buffer1; + this.m_buffer1 = this.m_buffer2; + this.m_buffer2 = temp; + this.m_count1 = this.m_count2; + this.m_index1 = 0; + this.m_count2 = -1; + Monitor.Pulse(this); + } + return this.ReadRecordAsync(ref rec); + } + + internal void AddLogEntry() + { + if (this.LastReadTime == this.FirstReadTime) + { + this.LastReadTime = DateTime.Now; + } + DryadLinqLog.Add("Read {0} records from {1} from {2} to {3} ", + this.RecordsRead, + this.ToString(), + this.FirstReadTime.ToString("MM/dd/yyyy HH:mm:ss.fff"), + this.LastReadTime.ToString("MM/dd/yyyy HH:mm:ss.fff")); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public IEnumerator GetEnumerator() + { + if (this.m_isUsed) + { + throw new DryadLinqException(HpcLinqErrorCode.ChannelCannotBeReadMoreThanOnce, + SR.ChannelCannotBeReadMoreThanOnce); + } + this.m_isUsed = true; + + return new RecordEnumerator(this); + } + + // Internal enumerator class + private class 
RecordEnumerator : IEnumerator + { + private HpcRecordReader m_reader; + private T m_current; + + public RecordEnumerator(HpcRecordReader reader) + { + this.m_reader = reader; + this.m_current = default(T); + } + + public bool MoveNext() + { + return this.m_reader.ReadRecord(ref this.m_current); + } + + object IEnumerator.Current + { + get { return this.m_current; } + } + + public T Current + { + get { return this.m_current; } + } + + public void Reset() + { + throw new InvalidOperationException(); + } + + void IDisposable.Dispose() + { + } + } + } + + public sealed class HpcRecordTextReader : HpcRecordReader + { + private HpcTextReader m_reader; + + public HpcRecordTextReader(HpcTextReader reader) + { + this.m_reader = reader; + } + + protected override bool ReadRecord(ref LineRecord rec) + { + string line = this.m_reader.ReadLine(); + if (line != null) + { + rec.Line = line; + return true; + } + return false; + } + + public override Int64 GetTotalLength() + { + return this.m_reader.GetTotalLength(); + } + + public override string GetChannelURI() + { + return this.m_reader.GetChannelURI(); + } + + public override void Close() + { + this.AddLogEntry(); + base.Close(); + this.m_reader.Close(); + } + + public override string ToString() + { + return this.m_reader.ToString(); + } + } + + public unsafe abstract class HpcRecordBinaryReader : HpcRecordReader + { + protected HpcBinaryReader m_reader; + + public HpcRecordBinaryReader(HpcBinaryReader reader) + { + this.m_reader = reader; + } + + // entry point needed for generated vertex code + public bool IsReaderAtEndOfStream() + { + return m_reader.EndOfStream(); + } + + public override Int64 GetTotalLength() + { + return this.m_reader.GetTotalLength(); + } + + public override string GetChannelURI() + { + return this.m_reader.GetChannelURI(); + } + + public override void Close() + { + this.AddLogEntry(); + base.Close(); + this.m_reader.Close(); + } + + public override string ToString() + { + return 
this.m_reader.ToString(); + } + } + + public sealed class HpcRecordByteReader : HpcRecordBinaryReader + { + public HpcRecordByteReader(HpcBinaryReader reader) + : base(reader) + { + } + + protected override bool ReadRecord(ref byte rec) + { + if (!this.m_reader.EndOfStream()) + { + rec = this.m_reader.ReadUByte(); + return true; + } + return false; + } + } + + public sealed class HpcRecordSByteReader : HpcRecordBinaryReader + { + public HpcRecordSByteReader(HpcBinaryReader reader) + : base(reader) + { + } + + protected override bool ReadRecord(ref sbyte rec) + { + if (!this.m_reader.EndOfStream()) + { + rec = this.m_reader.ReadSByte(); + return true; + } + return false; + } + } + + public sealed class HpcRecordBoolReader : HpcRecordBinaryReader + { + public HpcRecordBoolReader(HpcBinaryReader reader) + : base(reader) + { + } + + protected override bool ReadRecord(ref bool rec) + { + if (!this.m_reader.EndOfStream()) + { + rec = this.m_reader.ReadBool(); + return true; + } + return false; + } + } + + public sealed class HpcRecordCharReader : HpcRecordBinaryReader + { + public HpcRecordCharReader(HpcBinaryReader reader) + : base(reader) + { + } + + protected override bool ReadRecord(ref char rec) + { + if (!this.m_reader.EndOfStream()) + { + rec = this.m_reader.ReadChar(); + return true; + } + return false; + } + } + + public sealed class HpcRecordShortReader : HpcRecordBinaryReader + { + public HpcRecordShortReader(HpcBinaryReader reader) + : base(reader) + { + } + + protected override bool ReadRecord(ref short rec) + { + if (!this.m_reader.EndOfStream()) + { + rec = this.m_reader.ReadInt16(); + return true; + } + return false; + } + } + + public sealed class HpcRecordUShortReader : HpcRecordBinaryReader + { + public HpcRecordUShortReader(HpcBinaryReader reader) + : base(reader) + { + } + + protected override bool ReadRecord(ref ushort rec) + { + if (!this.m_reader.EndOfStream()) + { + rec = this.m_reader.ReadUInt16(); + return true; + } + return false; + } + } + + 
public sealed class HpcRecordInt32Reader : HpcRecordBinaryReader<int>
+    {
+        // Reads a channel whose records are 32-bit signed integers.
+        // NOTE: the generic argument on the base type had been stripped from this text; it is
+        // restored from the type that ReadRecord overrides (same for the readers below).
+        public HpcRecordInt32Reader(HpcBinaryReader reader)
+            : base(reader)
+        {
+        }
+
+        // Stores the next record in rec and returns true; returns false (rec untouched)
+        // once the underlying reader reports end of stream.
+        protected override bool ReadRecord(ref int rec)
+        {
+            if (this.m_reader.EndOfStream())
+            {
+                return false;
+            }
+            rec = this.m_reader.ReadInt32();
+            return true;
+        }
+    }
+
+    // Reads a channel whose records are 32-bit unsigned integers.
+    public sealed class HpcRecordUInt32Reader : HpcRecordBinaryReader<uint>
+    {
+        public HpcRecordUInt32Reader(HpcBinaryReader reader)
+            : base(reader)
+        {
+        }
+
+        // Returns false at end of stream, otherwise reads one UInt32 into rec.
+        protected override bool ReadRecord(ref uint rec)
+        {
+            if (this.m_reader.EndOfStream())
+            {
+                return false;
+            }
+            rec = this.m_reader.ReadUInt32();
+            return true;
+        }
+    }
+
+    // Reads a channel whose records are 64-bit signed integers.
+    public sealed class HpcRecordInt64Reader : HpcRecordBinaryReader<long>
+    {
+        public HpcRecordInt64Reader(HpcBinaryReader reader)
+            : base(reader)
+        {
+        }
+
+        // Returns false at end of stream, otherwise reads one Int64 into rec.
+        protected override bool ReadRecord(ref long rec)
+        {
+            if (this.m_reader.EndOfStream())
+            {
+                return false;
+            }
+            rec = this.m_reader.ReadInt64();
+            return true;
+        }
+    }
+
+    // Reads a channel whose records are 64-bit unsigned integers.
+    public sealed class HpcRecordUInt64Reader : HpcRecordBinaryReader<ulong>
+    {
+        public HpcRecordUInt64Reader(HpcBinaryReader reader)
+            : base(reader)
+        {
+        }
+
+        // Returns false at end of stream, otherwise reads one UInt64 into rec.
+        protected override bool ReadRecord(ref ulong rec)
+        {
+            if (this.m_reader.EndOfStream())
+            {
+                return false;
+            }
+            rec = this.m_reader.ReadUInt64();
+            return true;
+        }
+    }
+
+    // Reads a channel whose records are single-precision floats.
+    public sealed class HpcRecordFloatReader : HpcRecordBinaryReader<float>
+    {
+        public HpcRecordFloatReader(HpcBinaryReader reader)
+            : base(reader)
+        {
+        }
+
+        // Returns false at end of stream, otherwise reads one Single into rec.
+        protected override bool ReadRecord(ref float rec)
+        {
+            if (this.m_reader.EndOfStream())
+            {
+                return false;
+            }
+            rec = this.m_reader.ReadSingle();
+            return true;
+        }
+    }
+
+    // Reads a channel whose records are decimals.
+    public sealed class HpcRecordDecimalReader : HpcRecordBinaryReader<decimal>
+    {
+        public HpcRecordDecimalReader(HpcBinaryReader reader)
+            : base(reader)
+        {
+        }
+
+        // Returns false at end of stream, otherwise reads one Decimal into rec.
+        protected override bool ReadRecord(ref decimal rec)
+        {
+            if (this.m_reader.EndOfStream())
+            {
+                return false;
+            }
+            rec = this.m_reader.ReadDecimal();
+            return true;
+        }
+    }
+
+    // Reads a channel whose records are double-precision floats.
+    public sealed class HpcRecordDoubleReader : HpcRecordBinaryReader<double>
+    {
+        public HpcRecordDoubleReader(HpcBinaryReader reader)
+            : base(reader)
+        {
+        }
+
+        // Returns false at end of stream, otherwise reads one Double into rec.
+        protected override bool ReadRecord(ref double rec)
+        {
+            if (this.m_reader.EndOfStream())
+            {
+                return false;
+            }
+            rec = this.m_reader.ReadDouble();
+            return true;
+        }
+    }
+
+    // Reads a channel whose records are DateTime values.
+    public sealed class HpcRecordDateTimeReader : HpcRecordBinaryReader<DateTime>
+    {
+        public HpcRecordDateTimeReader(HpcBinaryReader reader)
+            : base(reader)
+        {
+        }
+
+        // Returns false at end of stream, otherwise reads one DateTime into rec.
+        protected override bool ReadRecord(ref DateTime rec)
+        {
+            if (this.m_reader.EndOfStream())
+            {
+                return false;
+            }
+            rec = this.m_reader.ReadDateTime();
+            return true;
+        }
+    }
+
+    // Reads a channel whose records are strings.
+    public sealed class HpcRecordStringReader : HpcRecordBinaryReader<string>
+    {
+        public HpcRecordStringReader(HpcBinaryReader reader)
+            : base(reader)
+        {
+        }
+
+        // Returns false at end of stream, otherwise reads one string into rec.
+        protected override bool ReadRecord(ref string rec)
+        {
+            if (this.m_reader.EndOfStream())
+            {
+                return false;
+            }
+            rec = this.m_reader.ReadString();
+            return true;
+        }
+    }
+
+    // Reads a channel whose records are SqlDateTime values.
+    public sealed class HpcRecordSqlDateTimeReader : HpcRecordBinaryReader<SqlDateTime>
+    {
+        public HpcRecordSqlDateTimeReader(HpcBinaryReader reader)
+            : base(reader)
+        {
+        }
+
+        // Returns false at end of stream, otherwise reads one SqlDateTime into rec.
+        protected override bool ReadRecord(ref SqlDateTime rec)
+        {
+            if (this.m_reader.EndOfStream())
+            {
+                return false;
+            }
+            rec = this.m_reader.ReadSqlDateTime();
+            return true;
+        }
+    }
+
+    // Reads a channel whose records are Guid values.
+    public sealed class HpcRecordGuidReader : HpcRecordBinaryReader<Guid>
+    {
+        public HpcRecordGuidReader(HpcBinaryReader reader)
+            : base(reader)
+        {
+        }
+
+        // Returns false at end of stream, otherwise reads one Guid into rec.
+        protected override bool ReadRecord(ref Guid rec)
+        {
+            if (this.m_reader.EndOfStream())
+            {
+                return false;
+            }
+            rec = this.m_reader.ReadGuid();
+            return true;
+        }
+    }
+}
diff --git a/LinqToDryad/DryadRecordWriter.cs b/LinqToDryad/DryadRecordWriter.cs
new file mode 100644
index 0000000..01f2da5
--- /dev/null
+++ b/LinqToDryad/DryadRecordWriter.cs
@@ -0,0 +1,564 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// (c) Microsoft Corporation. All rights reserved. +// +using System; +using System.Collections; +using System.Collections.Generic; +using System.Text; +using System.Reflection; +using System.Threading; +using System.Data.SqlTypes; +using System.Diagnostics; +using Microsoft.Research.DryadLinq; + +namespace Microsoft.Research.DryadLinq.Internal +{ + // This class defines the abstraction of writing HpcLinq records. 
+    // Base class for record writers. Records can be written synchronously, or through
+    // WriteRecordAsync, which after the first InitRecords records may hand batches of records
+    // to a background thread using two swap buffers guarded by the monitor on this object.
+    // NOTE: the generic parameter list had been stripped from this text; the class uses T
+    // throughout, so it is restored here and on the derived writers below.
+    public unsafe abstract class HpcRecordWriter<T>
+    {
+        private const int BufferMaxSize = 1024;   // upper bound on records per async buffer
+        private const int InitRecords = 100;      // records written synchronously before sizing buffers
+
+        private long m_numRecordsWritten;
+        private T[] m_buffer1;       // buffer being filled by the caller of WriteRecordAsync
+        private T[] m_buffer2;       // buffer being drained by the worker thread
+        private int m_index1;        // next free slot in m_buffer1
+        private int m_count2;        // records pending in m_buffer2; -1 means the worker is idle
+        private bool m_isClosed;     // set by Flush(true) to make the worker exit after its batch
+        private Thread m_worker;     // null while writing synchronously
+
+        public HpcRecordWriter()
+        {
+            this.m_numRecordsWritten = 0;
+            this.m_buffer1 = null;
+            this.m_buffer2 = null;
+            this.m_index1 = 0;
+            this.m_count2 = -1;
+            this.m_isClosed = false;
+            this.m_worker = null;
+        }
+
+        protected abstract void WriteRecord(T rec);
+        protected abstract void FlushInternal();
+        protected abstract void CloseInternal();
+
+        public abstract Int64 Length { get; }
+        public abstract string GetChannelURI();
+        public abstract Int64 GetTotalLength();
+        public abstract UInt64 GetFingerPrint();
+        public abstract bool CalcFP { get; set; }
+        public abstract Int32 BufferSizeHint { get; }
+
+        // Writes one record on the calling thread.
+        public void WriteRecordSync(T rec)
+        {
+            this.WriteRecord(rec);
+            this.m_numRecordsWritten++;
+        }
+
+        //Called by HpcVertexWrite.WriteItemSequence, DataProvider.IngressDirectlyToDsc etc.
+        //Note: the async writer thread is only considered once InitRecords (default=100)
+        //records have been written synchronously.
+        public void WriteRecordAsync(T rec)
+        {
+            if (this.m_worker == null)
+            {
+                this.WriteRecord(rec);
+                this.m_numRecordsWritten++;
+                if (this.m_numRecordsWritten == InitRecords)
+                {
+                    this.TryStartAsyncWriter();
+                }
+            }
+            else
+            {
+                if (this.m_index1 == this.m_buffer1.Length)
+                {
+                    // The fill buffer is full: wait for the worker to go idle, then swap buffers
+                    // and hand the full one over.
+                    lock (this)
+                    {
+                        while (this.m_count2 != -1)
+                        {
+                            Monitor.Wait(this);
+                        }
+                        T[] temp = this.m_buffer1;
+                        this.m_buffer1 = this.m_buffer2;
+                        this.m_buffer2 = temp;
+                        this.m_count2 = this.m_index1;
+                        this.m_index1 = 0;
+                        Monitor.Pulse(this);
+                    }
+                }
+                this.m_buffer1[this.m_index1++] = rec;
+            }
+        }
+
+        // Decides, from the bytes produced by the first InitRecords records, whether to switch
+        // to asynchronous writing and how many records each buffer holds.
+        private void TryStartAsyncWriter()
+        {
+            long bytesSoFar = this.Length;
+            if (bytesSoFar <= 0)
+            {
+                // No usable size estimate; stay synchronous rather than divide by zero.
+                return;
+            }
+
+            // 64-bit arithmetic: 4 * Length no longer overflows Int32 for large lengths.
+            long bsize = (this.BufferSizeHint / (4 * bytesSoFar)) * InitRecords;
+            if (this.BufferSizeHint > (64 * BufferMaxSize) && bsize > 1)
+            {
+                int recordsPerBuffer = (int)Math.Min(bsize, (long)BufferMaxSize);
+                this.m_buffer1 = new T[recordsPerBuffer];
+                this.m_buffer2 = new T[recordsPerBuffer];
+                this.m_index1 = 0;
+                this.m_count2 = -1;
+                this.m_isClosed = false;
+                this.m_worker = new Thread(this.WriteBuffer);
+                this.m_worker.Start();
+                DryadLinqLog.Add("Async writer with buffer size {0}", recordsPerBuffer);
+            }
+        }
+
+        // Worker thread body: waits for a batch in m_buffer2, writes it, marks itself idle,
+        // and exits after the batch that was handed over with m_isClosed set.
+        private void WriteBuffer()
+        {
+            try
+            {
+                while (true)
+                {
+                    lock (this)
+                    {
+                        while (this.m_count2 == -1)
+                        {
+                            Monitor.Wait(this);
+                        }
+                    }
+
+                    // Write the records. The producer does not touch m_buffer2/m_count2
+                    // until m_count2 is reset to -1 below.
+                    for (int i = 0; i < this.m_count2; i++)
+                    {
+                        this.WriteRecord(this.m_buffer2[i]);
+                    }
+                    this.m_numRecordsWritten += this.m_count2;
+
+                    lock (this)
+                    {
+                        this.m_count2 = -1;
+                        Monitor.Pulse(this);
+
+                        if (this.m_isClosed) break;
+                    }
+                }
+            }
+            catch (Exception e)
+            {
+                DryadLinqLog.Add(true, e.ToString());
+                throw;
+            }
+        }
+
+        // Hands any buffered records to the worker, waits until they are written, then flushes
+        // the underlying writer. When closeIt is true the worker exits after this batch.
+        private void Flush(bool closeIt)
+        {
+            if (this.m_worker != null)
+            {
+                lock (this)
+                {
+                    while (this.m_count2 != -1)
+                    {
+                        Monitor.Wait(this);
+                    }
+                    T[] temp = this.m_buffer1;
+                    this.m_buffer1 = this.m_buffer2;
+                    this.m_buffer2 = temp;
+                    this.m_count2 = this.m_index1;
+                    this.m_index1 = 0;
+                    this.m_isClosed = closeIt;
+                    Monitor.Pulse(this);
+
+                    // Again, wait for the worker to complete
+                    while (this.m_count2 != -1)
+                    {
+                        Monitor.Wait(this);
+                    }
+
+                    if (closeIt)
+                    {
+                        // The worker has left its loop. Forget it so that a later Flush/Close
+                        // does not wait forever on a thread that will never answer.
+                        this.m_worker = null;
+                    }
+                }
+            }
+            this.FlushInternal();
+        }
+
+        public void Flush()
+        {
+            this.Flush(false);
+        }
+
+        public void Close()
+        {
+            this.Flush(true);
+            this.CloseInternal();
+            DryadLinqLog.Add("Wrote {0} records to {1}", this.m_numRecordsWritten, this.ToString());
+        }
+    }
+
+    // Record writer for text channels: each LineRecord is written as one line.
+    public sealed class HpcRecordTextWriter : HpcRecordWriter<LineRecord>
+    {
+        private HpcTextWriter m_writer;
+
+        public HpcRecordTextWriter(HpcTextWriter writer)
+        {
+            this.m_writer = writer;
+        }
+
+        protected override void WriteRecord(LineRecord rec)
+        {
+            this.m_writer.WriteLine(rec.Line);
+        }
+
+        public override Int64 Length
+        {
+            get { return this.m_writer.Length; }
+        }
+
+        public override string GetChannelURI()
+        {
+            return this.m_writer.GetChannelURI();
+        }
+
+        public override long GetTotalLength()
+        {
+            return this.m_writer.GetTotalLength();
+        }
+
+        public override UInt64 GetFingerPrint()
+        {
+            return this.m_writer.GetFingerPrint();
+        }
+
+        public override bool CalcFP
+        {
+            get { return this.m_writer.CalcFP; }
+            set { this.m_writer.CalcFP = value; }
+        }
+
+        public override Int32 BufferSizeHint
+        {
+            get { return this.m_writer.BufferSizeHint; }
+        }
+
+        protected override void FlushInternal()
+        {
+            this.m_writer.Flush();
+        }
+
+        protected override void CloseInternal()
+        {
+            this.m_writer.Close();
+        }
+
+        public override string ToString()
+        {
+            return this.m_writer.ToString();
+        }
+    }
+
+    // Base class for record writers over an HpcBinaryWriter; everything except WriteRecord
+    // is forwarded to the wrapped writer.
+    public unsafe abstract class HpcRecordBinaryWriter<T> : HpcRecordWriter<T>
+    {
+        protected HpcBinaryWriter m_writer;
+
+        public HpcRecordBinaryWriter(HpcBinaryWriter writer)
+        {
+            this.m_writer = writer;
+        }
+
+        public override Int64 Length
+        {
+            get { return this.m_writer.Length; }
+        }
+
+        public override string GetChannelURI()
+        {
+            return this.m_writer.GetChannelURI();
+        }
+
+        public override long GetTotalLength()
+        {
+            return this.m_writer.GetTotalLength();
+        }
+
+        public override UInt64 GetFingerPrint()
+        {
+            return this.m_writer.GetFingerPrint();
+        }
+
+        public override bool CalcFP
+        {
+            get { return this.m_writer.CalcFP; }
+            set { this.m_writer.CalcFP = value; }
+        }
+
+        public override Int32 BufferSizeHint
+        {
+            get { return this.m_writer.BufferSizeHint; }
+        }
+
+        // helper for generated vertex code to call m_writer.CompleteWriteRecord()
+        public void CompleteWriteRecord()
+        {
+            m_writer.CompleteWriteRecord();
+        }
+
+        protected override void FlushInternal()
+        {
+            this.m_writer.Flush();
+        }
+
+        protected override void CloseInternal()
+        {
+            this.m_writer.Close();
+        }
+
+        public override string ToString()
+        {
+            return this.m_writer.ToString();
+        }
+    }
+
+    // The writers below each serialize one primitive value per record: write the value,
+    // then mark the record complete on the binary writer.
+
+    public sealed class HpcRecordByteWriter : HpcRecordBinaryWriter<byte>
+    {
+        public HpcRecordByteWriter(HpcBinaryWriter writer)
+            : base(writer)
+        {
+        }
+
+        protected override void WriteRecord(byte rec)
+        {
+            this.m_writer.Write(rec);
+            this.m_writer.CompleteWriteRecord();
+        }
+    }
+
+    public sealed class HpcRecordSByteWriter : HpcRecordBinaryWriter<sbyte>
+    {
+        public HpcRecordSByteWriter(HpcBinaryWriter writer)
+            : base(writer)
+        {
+        }
+
+        protected override void WriteRecord(sbyte rec)
+        {
+            this.m_writer.Write(rec);
+            this.m_writer.CompleteWriteRecord();
+        }
+    }
+
+    public sealed class HpcRecordBoolWriter : HpcRecordBinaryWriter<bool>
+    {
+        public HpcRecordBoolWriter(HpcBinaryWriter writer)
+            : base(writer)
+        {
+        }
+
+        protected override void WriteRecord(bool rec)
+        {
+            this.m_writer.Write(rec);
+            this.m_writer.CompleteWriteRecord();
+        }
+    }
+
+    public sealed class HpcRecordCharWriter : HpcRecordBinaryWriter<char>
+    {
+        public HpcRecordCharWriter(HpcBinaryWriter writer)
+            : base(writer)
+        {
+        }
+
+        protected override void WriteRecord(char rec)
+        {
+            this.m_writer.Write(rec);
+            this.m_writer.CompleteWriteRecord();
+        }
+    }
+
+    public sealed class HpcRecordShortWriter : HpcRecordBinaryWriter<short>
+    {
+        public HpcRecordShortWriter(HpcBinaryWriter writer)
+            : base(writer)
+        {
+        }
+
+        protected override void WriteRecord(short rec)
+        {
+            this.m_writer.Write(rec);
+            this.m_writer.CompleteWriteRecord();
+        }
+    }
+
+    public sealed class HpcRecordUShortWriter : HpcRecordBinaryWriter<ushort>
+    {
+        public HpcRecordUShortWriter(HpcBinaryWriter writer)
+            : base(writer)
+        {
+        }
+
+        protected override void WriteRecord(ushort rec)
+        {
+            this.m_writer.Write(rec);
+            this.m_writer.CompleteWriteRecord();
+        }
+    }
+
+    public sealed class HpcRecordInt32Writer : HpcRecordBinaryWriter<int>
+    {
+        public HpcRecordInt32Writer(HpcBinaryWriter writer)
+            : base(writer)
+        {
+        }
+
+        protected override void WriteRecord(int rec)
+        {
+            this.m_writer.Write(rec);
+            this.m_writer.CompleteWriteRecord();
+        }
+    }
+
+    public sealed class HpcRecordUInt32Writer : HpcRecordBinaryWriter<uint>
+    {
+        public HpcRecordUInt32Writer(HpcBinaryWriter writer)
+            : base(writer)
+        {
+        }
+
+        protected override void WriteRecord(uint rec)
+        {
+            this.m_writer.Write(rec);
+            this.m_writer.CompleteWriteRecord();
+        }
+    }
+
+    public sealed class HpcRecordInt64Writer : HpcRecordBinaryWriter<long>
+    {
+        public HpcRecordInt64Writer(HpcBinaryWriter writer)
+            : base(writer)
+        {
+        }
+
+        protected override void WriteRecord(long rec)
+        {
+            this.m_writer.Write(rec);
+            this.m_writer.CompleteWriteRecord();
+        }
+    }
+
+    public sealed class HpcRecordUInt64Writer : HpcRecordBinaryWriter<ulong>
+    {
+        public HpcRecordUInt64Writer(HpcBinaryWriter writer)
+            : base(writer)
+        {
+        }
+
+        protected override void WriteRecord(ulong rec)
+        {
+            this.m_writer.Write(rec);
+            this.m_writer.CompleteWriteRecord();
+        }
+    }
+
+    public sealed class HpcRecordDecimalWriter : HpcRecordBinaryWriter<decimal>
+    {
+        public HpcRecordDecimalWriter(HpcBinaryWriter writer)
+            : base(writer)
+        {
+        }
+
+        protected override void WriteRecord(decimal rec)
+        {
+            this.m_writer.Write(rec);
+            this.m_writer.CompleteWriteRecord();
+        }
+    }
+
+    public sealed class HpcRecordFloatWriter : HpcRecordBinaryWriter<float>
+    {
+        public HpcRecordFloatWriter(HpcBinaryWriter writer)
+            : base(writer)
+        {
+        }
+
+        protected override void WriteRecord(float rec)
+        {
+            this.m_writer.Write(rec);
+            this.m_writer.CompleteWriteRecord();
+        }
+    }
+
+    public sealed class HpcRecordDoubleWriter : HpcRecordBinaryWriter<double>
+    {
+        public HpcRecordDoubleWriter(HpcBinaryWriter writer)
+            : base(writer)
+        {
+        }
+
+        protected override void WriteRecord(double rec)
+        {
+            this.m_writer.Write(rec);
+            this.m_writer.CompleteWriteRecord();
+        }
+    }
+
+    public sealed class HpcRecordDateTimeWriter : HpcRecordBinaryWriter<DateTime>
+    {
+        public HpcRecordDateTimeWriter(HpcBinaryWriter writer)
+            : base(writer)
+        {
+        }
+
+        protected override void WriteRecord(DateTime rec)
+        {
+            this.m_writer.Write(rec);
+            this.m_writer.CompleteWriteRecord();
+        }
+    }
+
+    public sealed class HpcRecordStringWriter : HpcRecordBinaryWriter<string>
+    {
+        public HpcRecordStringWriter(HpcBinaryWriter writer)
+            : base(writer)
+        {
+        }
+
+        protected override void WriteRecord(string rec)
+        {
+            this.m_writer.Write(rec);
+            this.m_writer.CompleteWriteRecord();
+        }
+    }
+
+    public sealed class HpcRecordSqlDateTimeWriter : HpcRecordBinaryWriter<SqlDateTime>
+    {
+        public HpcRecordSqlDateTimeWriter(HpcBinaryWriter writer)
+            : base(writer)
+        {
+        }
+
+        protected override void WriteRecord(SqlDateTime rec)
+        {
+            this.m_writer.Write(rec);
+            this.m_writer.CompleteWriteRecord();
+        }
+    }
+
+    public sealed class HpcRecordGuidWriter : HpcRecordBinaryWriter<Guid>
+    {
+        public HpcRecordGuidWriter(HpcBinaryWriter writer)
+            : base(writer)
+        {
+        }
+
+        protected override void WriteRecord(Guid rec)
+        {
+            this.m_writer.Write(rec);
+            this.m_writer.CompleteWriteRecord();
+        }
+    }
+}
diff --git a/LinqToDryad/DryadRuntime.cs b/LinqToDryad/DryadRuntime.cs
new file mode 100644
index 0000000..dc3ce82
--- /dev/null
+++ b/LinqToDryad/DryadRuntime.cs
@@ -0,0 +1,110 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Microsoft.Research.DryadLinq
+{
+    // Describes a submitted job: its id, the head node it was submitted to, and (optionally)
+    // the executor that can be waited on for completion.
+    public sealed class HpcLinqJobInfo
+    {
+        internal const int JOBID_LOCALDEBUG = -1;
+        private int _jobId;
+        private string _headNode;
+        private string[] _targetUris; // Test-support
+        private JobExecutor _jobExecutor;
+
+        public int JobId
+        {
+            get { return _jobId; }
+        }
+
+        public string HeadNode
+        {
+            get { return _headNode; }
+        }
+
+        // Test-support
+        internal string[] TargetUris
+        {
+            get { return _targetUris; }
+        }
+
+        internal HpcLinqJobInfo(int jobId,
+                                string headNode,
+                                JobExecutor jobExecutor,
+                                string[] targetUris)
+        {
+            _jobId = jobId;
+            _headNode = headNode;
+            _jobExecutor = jobExecutor;
+            _targetUris = targetUris;
+        }
+
+        // Blocks until the job executor reports completion. Does nothing when no executor was
+        // supplied; throws DryadLinqException when the final status is not Success.
+        public void Wait()
+        {
+            if (_jobExecutor != null)
+            {
+                JobStatus finalStatus = _jobExecutor.WaitForCompletion();
+                if (finalStatus != JobStatus.Success)
+                {
+                    throw new DryadLinqException(HpcLinqErrorCode.DidNotCompleteSuccessfully,
+                                                 SR.DidNotCompleteSuccessfully);
+                }
+            }
+        }
+    }
+
+    /// <summary>
+    /// Represents a connection to a HPC Server that can execute HpcLinq jobs.
+    /// </summary>
+    /// <remarks>
+    /// A HpcQueryRuntime instance holds an open Microsoft.Hpc.Scheduler.Scheduler connection.
+    /// This connection can be closed by calling Dispose()
+    /// When a HpcQueryRuntime instance is passed to HpcLinqQuery.Submit(), HpcLinq will use
+    /// the open connection to submit the job.
+    /// </remarks>
+    internal sealed class HpcQueryRuntime : IDisposable
+    {
+        private string m_headNode;
+        private IScheduler m_scheduler;
+        private bool m_disposed;
+
+        public string HostName { get { return m_headNode; } }
+
+        // Creates a scheduler and connects it to headNode. If connecting fails the scheduler
+        // is disposed before the exception propagates.
+        public HpcQueryRuntime(string headNode)
+        {
+            m_headNode = headNode;
+            m_scheduler = new YarnScheduler();
+            try
+            {
+                m_scheduler.Connect(m_headNode);
+            }
+            catch
+            {
+                // The caller never receives this instance when the constructor throws, so
+                // nobody else could dispose the scheduler.
+                m_disposed = true;
+                m_scheduler.Dispose();
+                throw;
+            }
+        }
+
+        // Closes the scheduler connection. Safe to call more than once.
+        public void Dispose()
+        {
+            if (m_disposed)
+            {
+                return;
+            }
+            m_disposed = true;
+            m_scheduler.Dispose();
+        }
+
+        // Return IScheduler reference for internal use
+        internal IScheduler GetIScheduler()
+        {
+            return m_scheduler;
+        }
+    }
+}
diff --git a/LinqToDryad/DryadTextReader.cs b/LinqToDryad/DryadTextReader.cs
new file mode 100644
index 0000000..5d3a87e
--- /dev/null
+++ b/LinqToDryad/DryadTextReader.cs
@@ -0,0 +1,233 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.  You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//
+// (c) Microsoft Corporation.  All rights reserved.
+//
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Text;
+using System.Reflection;
+using System.Diagnostics;
+using Microsoft.Research.DryadLinq;
+
+namespace Microsoft.Research.DryadLinq.Internal
+{
+    // Reads text lines from a NativeBlockStream. Bytes are pulled from the stream one data
+    // block at a time, decoded into a char buffer in units of roughly DecodeUnitByteSize
+    // bytes, and split into lines on '\r', '\n' or "\r\n".
+    public unsafe sealed class HpcTextReader
+    {
+        // The number of bytes we attempt to decode each time
+        private const int DecodeUnitByteSize = 8192 * 16;
+
+        private NativeBlockStream m_nativeStream;      // source stream
+        private Encoding m_encoding;                  // character encoding
+        private Decoder m_decoder;                    // stateful decoder turning bytes into chars
+        private DataBlockInfo m_curDataBlockInfo;     // the data block currently being decoded (pointer, size, handle)
+        private Int32 m_curBlockPos;                  // offset in the current data block of the next byte to decode
+        private Int32 m_decodeUnitCharSize;           // max chars one decode unit can produce
+        private char[] m_charBuff;                    // buffer holding decoded characters; grown dynamically
+        private Int32 m_charBuffEnd;                  // offset just past the last decoded character in charBuff
+        private Int32 m_curLineStart;                 // offset of line start in charBuff
+        private Int32 m_curLineEnd;                   // offset of the scan position / line end in charBuff
+        private bool m_isClosed;
+
+        // Reads UTF-8 text from the given stream.
+        public HpcTextReader(NativeBlockStream stream)
+            : this(stream, Encoding.UTF8)
+        {
+        }
+
+        // Reads text in the given encoding from the given stream.
+        public HpcTextReader(NativeBlockStream stream, Encoding encoding)
+        {
+            this.m_nativeStream = stream;
+            this.m_encoding = encoding;
+            this.m_decoder = encoding.GetDecoder();
+            this.m_curDataBlockInfo.dataBlock = null;
+            // -1 marks "no block fetched yet"; FillCharBuffer fetches the first block on demand.
+            this.m_curDataBlockInfo.blockSize = -1;
+            this.m_curDataBlockInfo.itemHandle = IntPtr.Zero;
+            this.m_curBlockPos = 0;
+            this.m_decodeUnitCharSize = this.m_encoding.GetMaxCharCount(DecodeUnitByteSize);
+            this.m_charBuff = new char[this.m_decodeUnitCharSize + 2]; // 2 spare chars for a trailing newline
+            this.m_charBuffEnd = 0;
+            this.m_curLineStart = 0;
+            this.m_curLineEnd = 0;
+            this.m_isClosed = false;
+        }
+
+        // Reads UTF-8 text from an HpcLinqChannel opened on the given vertex port.
+        public HpcTextReader(IntPtr vertexInfo, UInt32 portNum)
+            : this(new HpcLinqChannel(vertexInfo, portNum, true), Encoding.UTF8)
+        {
+        }
+
+        // Reads text in the given encoding from an HpcLinqChannel opened on the given vertex port.
+        public HpcTextReader(IntPtr vertexInfo, UInt32 portNum, Encoding encoding)
+            : this(new HpcLinqChannel(vertexInfo, portNum, true), encoding)
+        {
+        }
+
+        // Closes the reader if the owner never did.
+        ~HpcTextReader()
+        {
+            this.Close();
+        }
+
+        public Int64 GetTotalLength()
+        {
+            return this.m_nativeStream.GetTotalLength();
+        }
+
+        // Fill this.charBuff by decoding more from data block.
+        // The not-yet-returned part of the current line is first moved to the front of the
+        // buffer (growing the buffer when needed), so line offsets change across this call.
+        // Returns the number of chars decoded; 0 when the current block has size 0.
+        // NOTE(review): presumably ReadDataBlock reports end of stream as blockSize 0 - confirm.
+        private int FillCharBuffer()
+        {
+            if (this.m_curDataBlockInfo.blockSize <= 0)
+            {
+                if (this.m_curDataBlockInfo.blockSize == -1)
+                {
+                    this.GetNextDataBlock();
+                }
+                if (this.m_curDataBlockInfo.blockSize == 0) return 0;
+            }
+
+            Int32 curLineLen = this.m_curLineEnd - this.m_curLineStart;
+            if (curLineLen + this.m_decodeUnitCharSize + 2 > this.m_charBuff.Length) // 2 spare chars for a trailing newline
+            {
+                // The current charBuff is too small, augment
+                char[] newCharBuff = new char[curLineLen + this.m_decodeUnitCharSize + 2]; // 2 spare chars for a trailing newline
+                Array.Copy(this.m_charBuff, this.m_curLineStart, newCharBuff, 0, curLineLen);
+                this.m_charBuff = newCharBuff;
+            }
+            else
+            {
+                // Shift the current line to the beginning
+                Array.Copy(this.m_charBuff, this.m_curLineStart, this.m_charBuff, 0, curLineLen);
+            }
+            this.m_curLineStart = 0;
+            this.m_curLineEnd = curLineLen;
+            this.m_charBuffEnd = curLineLen;
+
+            // Decode up to DecodeUnitByteSize - 2 bytes, crossing data block boundaries,
+            // unless the stream runs out first
+            Int32 numChars = 0;
+            Int32 numBytesDesired = DecodeUnitByteSize - 2;
+            while (numBytesDesired > 0)
+            {
+                Int32 numBytesRemainingInBlock = this.m_curDataBlockInfo.blockSize - this.m_curBlockPos;
+                if (numBytesRemainingInBlock > numBytesDesired)
+                {
+                    // The current block holds more than we want: take what we need and stop.
+                    fixed (char* pChars = this.m_charBuff)
+                    {
+                        numChars += this.m_decoder.GetChars(this.m_curDataBlockInfo.dataBlock + this.m_curBlockPos,
+                                                            numBytesDesired,
+                                                            pChars + curLineLen + numChars,
+                                                            this.m_decodeUnitCharSize - numChars,
+                                                            false);
+                    }
+                    this.m_curBlockPos += numBytesDesired;
+                    break;
+                }
+                else
+                {
+                    // Consume the rest of the current block, then move on to the next one.
+                    // flush is false, so the decoder keeps a partial character for the next call.
+                    fixed (char* pChars = this.m_charBuff)
+                    {
+                        numChars += this.m_decoder.GetChars(this.m_curDataBlockInfo.dataBlock + this.m_curBlockPos,
+                                                            numBytesRemainingInBlock,
+                                                            pChars + curLineLen + numChars,
+                                                            this.m_decodeUnitCharSize - numChars,
+                                                            false);
+                    }
+                    numBytesDesired -= numBytesRemainingInBlock;
+
+                    this.GetNextDataBlock();
+                    if (this.m_curDataBlockInfo.blockSize <= 0) break;
+                }
+            }
+
+            this.m_charBuffEnd += numChars;
+            return numChars;
+        }
+
+        // Get the next data block: releases the current block, reads the next one and
+        // restarts decoding at its first byte.
+        private unsafe void GetNextDataBlock()
+        {
+            this.m_nativeStream.ReleaseDataBlock(this.m_curDataBlockInfo.itemHandle);
+            this.m_curDataBlockInfo.itemHandle = IntPtr.Zero;
+            this.m_curDataBlockInfo = this.m_nativeStream.ReadDataBlock();
+            this.m_curBlockPos = 0;
+        }
+
+        // Returns true when at least one more character is available, decoding more input if needed.
+        public bool MoveNext()
+        {
+            return (this.m_curLineEnd < this.m_charBuffEnd || this.FillCharBuffer() > 0);
+        }
+
+        // Reads a line of characters and returns as a string.  Returns null if EOF.
+        // The line terminator ('\r', '\n' or "\r\n") is consumed but not returned.
+        public string ReadLine()
+        {
+            Debug.Assert(this.m_curLineStart == this.m_curLineEnd);
+            while (this.m_curLineEnd < this.m_charBuffEnd || this.FillCharBuffer() > 0)
+            {
+                char ch = this.m_charBuff[m_curLineEnd];
+                if (ch == '\r' || ch == '\n')
+                {
+                    Int32 lineLen = this.m_curLineEnd - this.m_curLineStart;
+                    this.m_curLineEnd++;
+                    // After '\r', look one char ahead (refilling if necessary) to swallow a following '\n'.
+                    if (ch == '\r' && (this.m_curLineEnd < this.m_charBuffEnd || this.FillCharBuffer() > 0))
+                    {
+                        if (this.m_charBuff[this.m_curLineEnd] == '\n')
+                        {
+                            this.m_curLineEnd++;
+                        }
+                    }
+                    // Read the start offset only now: the refill above may have shifted the line.
+                    Int32 lineStart = this.m_curLineStart;
+                    this.m_curLineStart = this.m_curLineEnd;
+                    return new String(this.m_charBuff, lineStart, lineLen);
+                }
+                else
+                {
+                    this.m_curLineEnd++;
+                }
+            }
+
+            // This is for the last line:
+            Int32 lastLineLen = this.m_curLineEnd - this.m_curLineStart;
+            if (lastLineLen == 0) return null;
+            String lastLine = new String(this.m_charBuff, this.m_curLineStart, lastLineLen);
+            this.m_curLineStart = this.m_curLineEnd;
+            return lastLine;
+        }
+
+        // Releases the current data block and closes the stream; later calls only
+        // suppress finalization.
+        public void Close()
+        {
+            if (!this.m_isClosed)
+            {
+                this.m_isClosed = true;
+                this.m_nativeStream.ReleaseDataBlock(this.m_curDataBlockInfo.itemHandle);
+                this.m_nativeStream.Close();
+            }
+            GC.SuppressFinalize(this);
+        }
+
+        internal string GetChannelURI()
+        {
+            return this.m_nativeStream.GetURI();
+        }
+
+        public override string ToString()
+        {
+            return this.m_nativeStream.ToString();
+        }
+    }
+}
diff --git a/LinqToDryad/DryadTextWriter.cs b/LinqToDryad/DryadTextWriter.cs
new file mode 100644
index 0000000..ea17609
--- /dev/null
+++ b/LinqToDryad/DryadTextWriter.cs
@@ -0,0 +1,247 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.  You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//
+// (c) Microsoft Corporation.  All rights reserved.
+//
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Text;
+using System.Reflection;
+using System.Diagnostics;
+using Microsoft.Research.DryadLinq;
+
+namespace Microsoft.Research.DryadLinq.Internal
+{
+    // Writes text lines to a NativeBlockStream. Lines are encoded directly into a data block
+    // obtained from the stream; a block is handed to the stream when the next line does not
+    // fit, and on Flush/Close. Every line is terminated with "\r\n".
+    public unsafe sealed class HpcTextWriter
+    {
+        private const int DefaultBlockSize = 256 * 1024;
+        private const int MaxBlockSize = 0x7FFFFFF8;   // cap applied when doubling the block size
+        private const string NewLine = "\r\n";
+
+        private NativeBlockStream m_nativeStream;
+        private Encoding m_encoding;
+        private Int32 m_nextBlockSize;              // size requested when the current block must grow
+        private Int32 m_bufferSizeHint;
+        private DataBlockInfo m_curDataBlockInfo;   // block currently being filled
+        private byte* m_curDataBlock;
+        private Int32 m_curBlockSize;
+        private Int32 m_curLineStart;               // bytes of complete lines in the current block
+        private Int32 m_curLineEnd;                 // write position in the current block
+        private Int64 m_numBytesWritten;            // bytes already handed to the stream
+        private bool m_calcFP;
+        private bool m_isClosed;
+        private bool m_isASCIIOrUTF8;               // lets WriteLine emit the newline as raw bytes 0x0d 0x0a
+
+        public HpcTextWriter(NativeBlockStream stream)
+            : this(stream, Encoding.UTF8)
+        {
+        }
+
+        public HpcTextWriter(NativeBlockStream stream, Encoding encoding)
+            : this(stream, encoding, DefaultBlockSize)
+        {
+        }
+
+        // No block is allocated here; the first WriteLine allocates one.
+        public HpcTextWriter(NativeBlockStream stream, Encoding encoding, Int32 buffSize)
+        {
+            this.m_nativeStream = stream;
+            this.m_encoding = encoding;
+            this.m_nextBlockSize = Math.Max(DefaultBlockSize, buffSize/2);
+            this.m_bufferSizeHint = buffSize;
+            this.m_curDataBlockInfo.dataBlock = null;
+            this.m_curDataBlockInfo.blockSize = 0;
+            this.m_curDataBlockInfo.itemHandle = IntPtr.Zero;
+            this.m_curDataBlock = this.m_curDataBlockInfo.dataBlock;
+            this.m_curBlockSize = this.m_curDataBlockInfo.blockSize;
+            this.m_curLineStart = 0;
+            this.m_curLineEnd = 0;
+            this.m_numBytesWritten = 0;
+            this.m_calcFP = false;
+            this.m_isClosed = false;
+            this.m_isASCIIOrUTF8 = (encoding == Encoding.UTF8 || encoding == Encoding.ASCII);
+        }
+
+        public HpcTextWriter(IntPtr vertexInfo, UInt32 portNum, Int32 buffSize)
+            : this(new HpcLinqChannel(vertexInfo, portNum, false), Encoding.UTF8, buffSize)
+        {
+        }
+
+        public HpcTextWriter(IntPtr vertexInfo, UInt32 portNum, Encoding encoding, Int32 buffSize)
+            : this(new HpcLinqChannel(vertexInfo, portNum, false), encoding, buffSize)
+        {
+        }
+
+        // Closes the writer if the owner never did.
+        ~HpcTextWriter()
+        {
+            this.Close();
+        }
+
+        public Int32 BufferSizeHint
+        {
+            get { return this.m_bufferSizeHint; }
+        }
+
+        internal string GetChannelURI()
+        {
+            return this.m_nativeStream.GetURI();
+        }
+
+        internal Int64 GetTotalLength()
+        {
+            return this.m_nativeStream.GetTotalLength();
+        }
+
+        // Returns the stream's fingerprint; throws DryadLinqException unless CalcFP is set.
+        internal UInt64 GetFingerPrint()
+        {
+            if (!this.m_calcFP)
+            {
+                throw new DryadLinqException(HpcLinqErrorCode.FingerprintDisabled, SR.FingerprintDisabled);
+            }
+            return this.m_nativeStream.GetFingerPrint();
+        }
+
+        public bool CalcFP
+        {
+            get { return this.m_calcFP; }
+            set { this.m_calcFP = value; }
+        }
+
+        // Encodes line followed by "\r\n" into the current block, first making room for the
+        // worst-case encoded size. Returns the number of bytes written, newline included.
+        public unsafe int WriteLine(string line)
+        {
+            Int32 strLen = line.Length;
+            Int32 maxByteCount = this.m_encoding.GetMaxByteCount(strLen + 2);
+
+            // Each pass either hands the completed lines to the stream or grows the block.
+            while (this.m_curBlockSize - this.m_curLineEnd < maxByteCount)
+            {
+                this.FlushDataBlock();
+            }
+
+            Int32 numBytes;
+            fixed (char* pLine = line)
+            {
+                numBytes = this.m_encoding.GetBytes(pLine,
+                                                    strLen,
+                                                    this.m_curDataBlock + this.m_curLineEnd,
+                                                    this.m_curBlockSize - this.m_curLineEnd);
+            }
+            this.m_curLineEnd += numBytes;
+
+            Int32 numBytes1 = 2;
+            if (this.m_isASCIIOrUTF8)
+            {
+                // CR LF are single bytes in these encodings; skip the encoder.
+                this.m_curDataBlock[this.m_curLineEnd] = 0x0d;
+                this.m_curDataBlock[this.m_curLineEnd+1] = 0x0a;
+                this.m_curLineEnd += numBytes1;
+            }
+            else
+            {
+                fixed (char* pNewLine = NewLine)
+                {
+                    numBytes1 = this.m_encoding.GetBytes(pNewLine,
+                                                         NewLine.Length,
+                                                         this.m_curDataBlock + this.m_curLineEnd,
+                                                         this.m_curBlockSize - this.m_curLineEnd);
+                }
+                this.m_curLineEnd += numBytes1;
+            }
+            this.m_curLineStart = this.m_curLineEnd;
+            return numBytes + numBytes1;
+        }
+
+        // Hands all complete lines to the stream, starts a fresh block of the same size,
+        // and flushes the stream.
+        public void Flush()
+        {
+            Debug.Assert(this.m_curLineStart == this.m_curLineEnd);
+            if (this.m_curLineStart > 0)
+            {
+                this.m_nativeStream.WriteDataBlock(this.m_curDataBlockInfo.itemHandle, this.m_curLineStart);
+                this.m_numBytesWritten += this.m_curLineStart;
+                this.m_nativeStream.ReleaseDataBlock(this.m_curDataBlockInfo.itemHandle);
+                this.m_curDataBlockInfo.itemHandle = IntPtr.Zero;
+                this.m_curDataBlockInfo = this.m_nativeStream.AllocateDataBlock(this.m_curBlockSize);
+                this.m_curDataBlock = this.m_curDataBlockInfo.dataBlock;
+                this.m_curBlockSize = this.m_curDataBlockInfo.blockSize;
+                this.m_curLineStart = 0;
+                this.m_curLineEnd = 0;
+            }
+
+            this.m_nativeStream.Flush();
+        }
+
+        // Flushes pending lines, releases the block still held, and closes the stream.
+        // Later calls only suppress finalization.
+        public void Close()
+        {
+            if (!this.m_isClosed)
+            {
+                this.m_isClosed = true;
+                this.Flush();
+
+                // Flush leaves a block allocated for further writes (or an untouched one in
+                // place); hand it back before the stream goes away, as the text reader does.
+                this.m_nativeStream.ReleaseDataBlock(this.m_curDataBlockInfo.itemHandle);
+                this.m_curDataBlockInfo.itemHandle = IntPtr.Zero;
+                this.m_curDataBlockInfo.dataBlock = null;
+                this.m_curDataBlockInfo.blockSize = 0;
+                this.m_curDataBlock = null;
+                this.m_curBlockSize = 0;
+
+                this.m_nativeStream.Close();
+            }
+            GC.SuppressFinalize(this);
+        }
+
+        // Makes room in the current block: writes out the complete lines if there are any,
+        // otherwise replaces the (empty) block with a larger one.
+        private void FlushDataBlock()
+        {
+            DataBlockInfo newDataBlockInfo;
+            if (this.m_curLineStart == 0)
+            {
+                // The current block is too small for a single record, augment it
+                if (this.m_curBlockSize == this.m_nextBlockSize)
+                {
+                    throw new DryadLinqException(HpcLinqErrorCode.RecordSizeMax2GB, SR.RecordSizeMax2GB);
+                }
+                newDataBlockInfo = this.m_nativeStream.AllocateDataBlock(this.m_nextBlockSize);
+                // Double the next request, capping at MaxBlockSize instead of relying on
+                // integer wraparound (which would throw in a checked build).
+                this.m_nextBlockSize = (this.m_nextBlockSize > Int32.MaxValue / 2)
+                                           ? MaxBlockSize
+                                           : this.m_nextBlockSize * 2;
+                this.m_nativeStream.ReleaseDataBlock(this.m_curDataBlockInfo.itemHandle);
+                this.m_curDataBlockInfo.itemHandle = IntPtr.Zero;
+            }
+            else
+            {
+                // Write all the complete records in the block
+                this.m_nativeStream.WriteDataBlock(this.m_curDataBlockInfo.itemHandle, this.m_curLineStart);
+                this.m_numBytesWritten += this.m_curLineStart;
+                this.m_nativeStream.ReleaseDataBlock(this.m_curDataBlockInfo.itemHandle);
+                this.m_curDataBlockInfo.itemHandle = IntPtr.Zero;
+                newDataBlockInfo = this.m_nativeStream.AllocateDataBlock(this.m_curBlockSize);
+                this.m_curLineEnd -= this.m_curLineStart;
+                this.m_curLineStart = 0;
+            }
+            this.m_curDataBlockInfo = newDataBlockInfo;
+            this.m_curDataBlock = newDataBlockInfo.dataBlock;
+            this.m_curBlockSize = newDataBlockInfo.blockSize;
+        }
+
+        // Bytes handed to the stream so far plus bytes pending in the current block.
+        public Int64 Length
+        {
+            get {
+                return this.m_numBytesWritten + this.m_curLineEnd;
+            }
+        }
+
+        public override string ToString()
+        {
+            return this.m_nativeStream.ToString();
+        }
+    }
+
+}
diff --git a/LinqToDryad/DryadVertexEnv.cs b/LinqToDryad/DryadVertexEnv.cs
new file mode 100644
index 0000000..8ce8878
--- /dev/null
+++ b/LinqToDryad/DryadVertexEnv.cs
@@ -0,0 +1,315 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.  You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//
+// (c) Microsoft Corporation.  All rights reserved.
+//
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Globalization;
+using System.Reflection;
+using System.Runtime;
+using System.Diagnostics;
+using Microsoft.Research.DryadLinq;
+
+namespace Microsoft.Research.DryadLinq.Internal
+{
+    // The class encapsulates the external environment in which a
+    // managed query operator executes.
+ public class HpcLinqVertexEnv + { + private const string VERTEX_EXCEPTION_FILENAME = @"VertexException.txt"; + + private IntPtr m_nativeHandle; + private UInt32 m_numberOfInputs; + private UInt32 m_numberOfOutputs; + private Int32 m_nextInputPort; + private Int32 m_nextOutputPort; + private string[] m_argList; + private HpcLinqVertexParams m_vertexParams; + private bool m_useLargeBuffer; + private bool m_keepInputPortOrder; + private bool m_multiThreading; + + public HpcLinqVertexEnv(string args, HpcLinqVertexParams vertexParams) + { + this.m_argList = args.Split('|'); + this.m_nativeHandle = new IntPtr(Int64.Parse(this.m_argList[0], NumberStyles.HexNumber)); + this.m_numberOfInputs = HpcLinqNative.GetNumOfInputs(this.m_nativeHandle); + this.m_numberOfOutputs = HpcLinqNative.GetNumOfOutputs(this.m_nativeHandle); + this.m_nextInputPort = 0; + this.m_nextOutputPort = 0; + this.m_vertexParams = vertexParams; + this.m_useLargeBuffer = vertexParams.UseLargeBuffer; + this.m_keepInputPortOrder = vertexParams.KeepInputPortOrder; + this.m_multiThreading = vertexParams.MultiThreading; + if (this.m_numberOfOutputs > 0) + { + this.SetInitialWriteSizeHint(); + } + + Debug.Assert(vertexParams.InputArity <= this.m_numberOfInputs); + Debug.Assert(vertexParams.OutputArity <= this.m_numberOfOutputs); + } + + public bool MultiThreading + { + get { return m_multiThreading; } + set { m_multiThreading = value; } + } + + internal IntPtr NativeHandle + { + get { return this.m_nativeHandle; } + } + + public UInt32 NumberOfInputs + { + get { return this.m_numberOfInputs; } + } + + public UInt32 NumberOfOutputs + { + get { return this.m_numberOfOutputs; } + } + + public Int32 NumberOfArguments + { + get { return this.m_argList.Length; } + } + + public string GetArgument(Int32 idx) + { + return this.m_argList[idx]; + } + + private bool UseLargeBuffer + { + get { return this.m_useLargeBuffer; } + } + + internal bool KeepInputPortOrder + { + get { return this.m_keepInputPortOrder; } + } + + 
public Int64 VertexId + { + get { + return HpcLinqNative.GetVertexId(this.m_nativeHandle); + } + } + + public HpcVertexReader MakeReader(HpcLinqFactory readerFactory) + { + if (this.m_nextInputPort + 1 < this.m_vertexParams.InputArity) + { + UInt32 portNum = (UInt32)this.m_nextInputPort++; + return new HpcVertexReader(this, readerFactory, portNum); + } + else + { + UInt32 startPort = (UInt32)this.m_nextInputPort; + UInt32 endPort = this.NumberOfInputs; + return new HpcVertexReader(this, readerFactory, startPort, endPort); + } + } + + public HpcVertexWriter MakeWriter(HpcLinqFactory writerFactory) + { + if (this.m_nextOutputPort + 1 < this.m_vertexParams.OutputArity) + { + UInt32 portNum = (UInt32)this.m_nextOutputPort++; + return new HpcVertexWriter(this, writerFactory, portNum); + } + else + { + UInt32 startPort = (UInt32)this.m_nextOutputPort; + UInt32 endPort = this.NumberOfOutputs; + return new HpcVertexWriter(this, writerFactory, startPort, endPort); + } + } + + public static HpcBinaryReader MakeBinaryReader(NativeBlockStream nativeStream) + { + return new HpcBinaryReader(nativeStream); + } + + public static HpcBinaryReader MakeBinaryReader(IntPtr handle, UInt32 port) + { + return new HpcBinaryReader(handle, port); + } + + public static HpcBinaryWriter MakeBinaryWriter(NativeBlockStream nativeStream) + { + return new HpcBinaryWriter(nativeStream); + } + + public static HpcBinaryWriter MakeBinaryWriter(IntPtr handle, UInt32 port, Int32 buffSize) + { + return new HpcBinaryWriter(handle, port, buffSize); + } + + private static Exception s_lastReportedException; + + internal static int ErrorCode { get; set; } + + // + // This method gets called by the generated vertex code, as well as VertexBridge to report exceptions. + // The exception will be dumped to "VertexException.txt" in the working directory. 
+ // + public static void ReportVertexError(Exception e) + { + // We first need to check whether the same exception object was already reported recently, + // and ignore the second call. + + // This will be the case for most vertex exceptions because 1) the generated vertex code catches the exceptions, + // calls ReportVertexError and rethrows, and right after that 2) VertexBridge will receive the same exception + // wrapped in a TargetInvocationException, and call ReportVertexError again after extracting the inner exception. + // + // The second call from the VertexBridge is necessary because some exceptions + // (particularly TypeLoadException due to static ctors) happen in the vertex DLL, + // but just before the try/catch blocks in the vertex entry point (therefore are missed by 1). + if (s_lastReportedException == e) return; + + s_lastReportedException = e; + + // add to HpcLog + DryadLinqLog.Add("Vertex failed with the following exception:"); + DryadLinqLog.Add("{0}", e.ToString()); + + // also write out to the standalone vertex exception file in the working directory + using (StreamWriter exceptionFile = new StreamWriter(VERTEX_EXCEPTION_FILENAME)) + { + exceptionFile.WriteLine(e.ToString()); + } + if (ErrorCode == 0) throw e; + } + + internal unsafe Int32 GetWriteBuffSize() + { + MEMORYSTATUSEX memStatus = new MEMORYSTATUSEX(); + memStatus.dwLength = (UInt32)sizeof(MEMORYSTATUSEX); + UInt64 maxSize = 512 * 1024 * 1024UL; + if (HpcLinqNative.GlobalMemoryStatusEx(ref memStatus)) + { + maxSize = memStatus.ullAvailPhys / 4; + } + if (this.m_vertexParams.RemoteArch == "i386") + { + maxSize = Math.Min(maxSize, 1024 * 1024 * 1024UL); + } + if (this.NumberOfOutputs > 0) + { + maxSize = maxSize / this.NumberOfOutputs; + } + + UInt64 buffSize = (this.UseLargeBuffer) ? 
(256 * 1024 * 1024UL) : (1024 * 1024UL); + if (buffSize > maxSize) buffSize = maxSize; + if (buffSize < (8 * 1024UL)) buffSize = 8 * 1024; + return (Int32)buffSize; + } + + internal Int64 GetInputSize() + { + Int64 totalSize = 0; + for (UInt32 i = 0; i < this.m_numberOfInputs; i++) + { + Int64 channelSize = HpcLinqNative.GetExpectedLength(this.NativeHandle, i); + if (channelSize == -1) return -1; + totalSize += channelSize; + } + return totalSize; + } + + internal void SetInitialWriteSizeHint() + { + Int64 inputSize = this.GetInputSize(); + UInt64 hsize = (inputSize == -1) ? (5 * 1024 * 1024 * 1024UL) : (UInt64)inputSize; + hsize /= this.NumberOfOutputs; + for (UInt32 i = 0; i < this.NumberOfOutputs; i++) + { + HpcLinqNative.SetInitialSizeHint(this.m_nativeHandle, i, hsize); + } + } + + // + // The Vertex Host native layer will use this bridge method to invoke the vertex entry point + // instead of invoking it directly through the CLR host. + // This has the advantage of doing all the assembly load and invoke work for the generated + // vertex assembly to happen in a managed context, so that any type or assembly load exceptions + // can be caught and reported in full detail. + // + private static void VertexBridge(string vertexBridgeArgs) + { + DryadLinqLog.IsOn = true; + DryadLinqLog.Add(".NET runtime version = v{0}.{1}.{2}", + Environment.Version.Major, + Environment.Version.Minor, + Environment.Version.Build); + DryadLinqLog.Add(".NET runtime GC = {0}({1})", + (GCSettings.IsServerGC) ? "ServerGC" : "WorkstationGC", + GCSettings.LatencyMode); + + try + { + string[] splitArgs = vertexBridgeArgs.Split(','); + if (splitArgs.Length != 4) + { + throw new ArgumentException(string.Format(SR.VertexBridgeBadArgs, vertexBridgeArgs), "vertexBridgeArgs"); + } + + // @TODO: Temporary hack to find the vertex DLL from the job dir (which is currently always one level up from the WD). + // As part of bug 12618 we need to pass this down from the VH. 
+ string moduleName = Path.Combine("..", splitArgs[0]); + string className = splitArgs[1]; + string methodName = splitArgs[2]; + string nativeChannelString = splitArgs[3]; + + Assembly vertexAssembly = Assembly.LoadFrom(moduleName); + + DryadLinqLog.Add("Vertex Bridge loaded assembly {0}", vertexAssembly.Location); + + MethodInfo vertexMethod = vertexAssembly.GetType(className).GetMethod(methodName, BindingFlags.Static | BindingFlags.Public); + vertexMethod.Invoke(null, new object[] { nativeChannelString }); + } + catch (Exception e) + { + // + // Any exception that happens in the vertex code will come wrapped in a TargetInvocationException since we're using Invoke(). + // We only want to report the inner exception in this case. + // If the exception is of another type (most likely one coming from the Assembly.LoadFrom() call), then we will report it as is. + // + if (e is TargetInvocationException && e.InnerException != null) + { + ReportVertexError(e.InnerException); + if (ErrorCode == 0) throw e.InnerException; + } + else + { + ReportVertexError(e); + if (ErrorCode == 0) throw; + } + } + } + + } +} diff --git a/LinqToDryad/DryadVertexReader.cs b/LinqToDryad/DryadVertexReader.cs new file mode 100644 index 0000000..fd403dc --- /dev/null +++ b/LinqToDryad/DryadVertexReader.cs @@ -0,0 +1,243 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 

*/

//
// © Microsoft Corporation.  All rights reserved.
//
using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Globalization;
using System.Reflection;
using System.Linq.Expressions;
using System.Linq;
using Microsoft.Research.DryadLinq;

namespace Microsoft.Research.DryadLinq.Internal
{
    // The class encapsulates the external environment in which a
    // managed query operator reads from Dryad channels.
    public class HpcVertexReader<T> : IMultiEnumerable<T>
    {
        private HpcLinqVertexEnv m_dvertexEnv;
        private IntPtr m_nativeHandle;
        private HpcLinqFactory<T> m_readerFactory;
        private UInt32 m_startPort;
        private UInt32 m_numberOfInputs;
        internal HpcRecordReader<T>[] m_readers;
        // Order in which the readers are visited: entry k is the index
        // (relative to m_startPort) of the k-th port to be read.
        internal UInt32[] m_portPermArray;
        // Set once the reader has been consumed; a second consumer is rejected.
        private bool m_isUsed;

        /// <summary>
        /// Creates a reader over the ports [startPort, endPort) of the vertex.
        /// Unless the vertex environment asks to keep the input port order, the ports
        /// are visited in a pseudo-random order seeded with the current process id.
        /// </summary>
        public HpcVertexReader(HpcLinqVertexEnv denv, HpcLinqFactory<T> readerFactory, UInt32 startPort, UInt32 endPort)
        {
            this.m_dvertexEnv = denv;
            this.m_nativeHandle = denv.NativeHandle;
            this.m_readerFactory = readerFactory;
            this.m_startPort = startPort;
            this.m_numberOfInputs = endPort - startPort;
            this.m_portPermArray = new UInt32[this.NumberOfInputs];
            for (UInt32 i = 0; i < this.NumberOfInputs; i++)
            {
                this.m_portPermArray[i] = i;
            }
            if (!denv.KeepInputPortOrder)
            {
                // Shuffle from the back: each step swaps the last not-yet-fixed
                // slot with a randomly chosen slot at or before it.
                Random rdm = new Random(System.Diagnostics.Process.GetCurrentProcess().Id);
                Int32 max = (Int32)this.NumberOfInputs;
                for (UInt32 i = 1; i < this.NumberOfInputs; i++)
                {
                    int idx = rdm.Next(max);
                    UInt32 n = this.m_portPermArray[max-1];
                    this.m_portPermArray[max-1] = this.m_portPermArray[idx];
                    this.m_portPermArray[idx] = n;
                    max--;
                }
            }

            this.m_readers = new HpcRecordReader<T>[this.NumberOfInputs];
            for (UInt32 i = 0; i < this.NumberOfInputs; i++)
            {
                this.m_readers[i] = this.m_readerFactory.MakeReader(this.m_nativeHandle, startPort + i);
            }
            this.m_isUsed = false;
        }

        /// <summary>
        /// Creates a reader over the single port portNum of the vertex.
        /// </summary>
        public HpcVertexReader(HpcLinqVertexEnv denv, HpcLinqFactory<T> readerFactory, UInt32 portNum)
        {
            this.m_dvertexEnv = denv;
            this.m_nativeHandle = denv.NativeHandle;
            this.m_readerFactory = readerFactory;
            this.m_startPort = portNum;
            this.m_numberOfInputs = 1;
            this.m_portPermArray = new UInt32[] { 0 };
            HpcRecordReader<T> reader = readerFactory.MakeReader(this.m_nativeHandle, portNum);
            this.m_readers = new HpcRecordReader<T>[] { reader };
            this.m_isUsed = false;
        }

        public HpcLinqVertexEnv VertexEnv
        {
            get { return this.m_dvertexEnv; }
        }

        public IntPtr NativeHandle
        {
            get { return this.m_nativeHandle; }
        }

        public UInt32 NumberOfInputs
        {
            get { return this.m_numberOfInputs; }
        }

        // The records of the idx-th port of this reader (relative to its start port).
        public IEnumerable<T> this[int idx]
        {
            get { return this.m_readers[idx]; }
        }

        // Sum of the lengths of all the channels of this reader, or -1 as soon as
        // one channel reports a negative length.
        public Int64 GetTotalLength()
        {
            Int64 totalLen = 0;
            for (UInt32 i = 0; i < this.NumberOfInputs; i++)
            {
                Int64 chLen = this.m_readers[i].GetTotalLength();
                if (chLen < 0) return -1;
                totalLen += chLen;
            }
            return totalLen;
        }

        public string GetChannelURI(int idx)
        {
            return this.m_readers[idx].GetChannelURI();
        }

        // Close the internal Dryad readers.
        public void CloseReaders()
        {
            for (UInt32 i = 0; i < this.NumberOfInputs; i++)
            {
                this.m_readers[i].Close();
            }
        }

        // Make this reader into a System.IO.Stream.
        // The ports are concatenated in the order given by m_portPermArray.
        // The reader can be consumed only once, either through this property
        // or through GetEnumerator().
        internal Stream InputStream
        {
            get {
                if (this.m_isUsed)
                {
                    throw new DryadLinqException(HpcLinqErrorCode.ChannelCannotBeReadMoreThanOnce,
                                               SR.ChannelCannotBeReadMoreThanOnce);
                }
                this.m_isUsed = true;

                HpcBinaryReader[] inputStreamArray = new HpcBinaryReader[this.NumberOfInputs];
                for (int i = 0; i < this.NumberOfInputs; i++)
                {
                    NativeBlockStream nativeStream = new HpcLinqChannel(this.m_nativeHandle,
                                                                        this.m_portPermArray[i] + this.m_startPort,
                                                                        true);
                    inputStreamArray[i] = new HpcBinaryReader(nativeStream);
                }
                return new HpcLinqMultiInputStream(inputStreamArray);
            }
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return this.GetEnumerator();
        }

        // Enumerates the records of all the ports, one port after another, in the
        // order given by m_portPermArray. Throws if the reader was already consumed.
        public IEnumerator<T> GetEnumerator()
        {
            if (this.m_isUsed)
            {
                throw new DryadLinqException(HpcLinqErrorCode.ChannelCannotBeReadMoreThanOnce,
                                           SR.ChannelCannotBeReadMoreThanOnce);
            }
            this.m_isUsed = true;

            return new RecordEnumerator(this);
        }

        // Internal enumerator class (sync)
        private class RecordEnumerator : IEnumerator<T>
        {
            private HpcVertexReader<T> m_vertexReader;
            // Reader currently being drained; null until the first MoveNext()
            // and forever when the vertex reader has no port at all.
            private HpcRecordReader<T> m_curReader;
            // Position in m_portPermArray of the next port to open.
            private UInt32 m_nextPortIdx;
            private T m_current;

            public RecordEnumerator(HpcVertexReader<T> reader)
            {
                this.m_vertexReader = reader;
                this.m_nextPortIdx = 0;
                // The first reader is selected lazily by MoveNext(), so that a vertex
                // reader with zero ports enumerates as an empty sequence instead of
                // failing on m_portPermArray[0].
                this.m_curReader = null;
                this.m_current = default(T);
            }

            private UInt32 GetNextPortIdx()
            {
                return this.m_vertexReader.m_portPermArray[this.m_nextPortIdx++];
            }

            public bool MoveNext()
            {
                while (true)
                {
                    if (this.m_curReader != null &&
                        this.m_curReader.ReadRecordSync(ref this.m_current))
                    {
                        return true;
                    }
                    if (this.m_nextPortIdx >= this.m_vertexReader.m_readers.Length)
                    {
                        return false;
                    }

                    // The current port is exhausted (or none was opened yet): move on.
                    UInt32 curPortNum = this.GetNextPortIdx();
                    this.m_curReader = this.m_vertexReader.m_readers[curPortNum];
                }
            }

            object IEnumerator.Current
            {
                get { return this.m_current; }
            }

            public T Current
            {
                get { return this.m_current; }
            }

            public void Reset()
            {
                throw new InvalidOperationException();
            }

            // Disposing the enumerator closes all the readers of the vertex reader.
            void IDisposable.Dispose()
            {
                this.m_vertexReader.CloseReaders();
            }
        }
    }
}
diff --git a/LinqToDryad/DryadVertexWriter.cs b/LinqToDryad/DryadVertexWriter.cs
new file mode 100644
index 0000000..2259e36
--- /dev/null
+++ b/LinqToDryad/DryadVertexWriter.cs
@@ -0,0 +1,180 @@
/*
Copyright (c) Microsoft Corporation

All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License.  You may obtain a copy of the License
at http://www.apache.org/licenses/LICENSE-2.0


THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.


See the Apache Version 2.0 License for specific language governing permissions and
limitations under the License.

*/

//
// © Microsoft Corporation.  All rights reserved.
//
using System;
using System.Collections.Generic;
using System.IO;
using System.Globalization;
using System.Reflection;
using System.Linq.Expressions;
using System.Linq;
using Microsoft.Research.DryadLinq;

namespace Microsoft.Research.DryadLinq.Internal
{
    // The class encapsulates the external environment in which a managed
    // query operator writes to HpcLinq channels.
+ public class HpcVertexWriter + { + private HpcLinqVertexEnv m_dvertexEnv; + private IntPtr m_nativeHandle; + private UInt32 m_startPort; + private UInt32 m_numberOfOutputs; + private HpcLinqFactory m_writerFactory; + private HpcRecordWriter[] m_writers; + + public HpcVertexWriter(HpcLinqVertexEnv denv, HpcLinqFactory writerFactory, UInt32 startPort, UInt32 endPort) + { + this.m_dvertexEnv = denv; + this.m_nativeHandle = denv.NativeHandle; + this.m_startPort = startPort; + this.m_numberOfOutputs = endPort - startPort; + this.m_writerFactory = writerFactory; + this.m_writers = new HpcRecordWriter[this.m_numberOfOutputs]; + Int32 buffSize = this.m_dvertexEnv.GetWriteBuffSize(); + for (UInt32 i = 0; i < this.m_numberOfOutputs; i++) + { + this.m_writers[i] = writerFactory.MakeWriter(this.m_nativeHandle, i + startPort, buffSize); + } + } + + public HpcVertexWriter(HpcLinqVertexEnv denv, HpcLinqFactory writerFactory, UInt32 portNum) + { + this.m_dvertexEnv = denv; + this.m_nativeHandle = denv.NativeHandle; + this.m_startPort = portNum; + this.m_numberOfOutputs = 1; + this.m_writerFactory = writerFactory; + Int32 buffSize = this.m_dvertexEnv.GetWriteBuffSize(); + HpcRecordWriter writer = writerFactory.MakeWriter(this.m_nativeHandle, portNum, buffSize); + this.m_writers = new HpcRecordWriter[] { writer }; + } + + public HpcLinqVertexEnv VertexEnv + { + get { return this.m_dvertexEnv; } + } + + public IntPtr NativeHandle + { + get { return this.m_nativeHandle; } + } + + public UInt32 NumberOfOutputs + { + get { return this.m_numberOfOutputs; } + } + + internal HpcRecordWriter GetWriter(UInt32 portNum) + { + return this.m_writers[portNum]; + } + + public void WriteItemSequence(IEnumerable source) + { + HpcRecordWriter writer = this.m_writers[0]; + + if (m_dvertexEnv.MultiThreading) + { + foreach (T item in source) + { + writer.WriteRecordAsync(item); + } + } + else + { + foreach (T item in source) + { + writer.WriteRecordSync(item); + } + } + this.CloseWriters(); + } + + // 
Write a single item to the output channel. Use sync write. + internal void WriteItem(T item, Int32 portNum) + { + this.m_writers[portNum].WriteRecordSync(item); + } + + public string GetChannelURI(int idx) + { + return this.m_writers[idx].GetChannelURI(); + } + + public Int64 GetChannelLength(int idx) + { + return this.m_writers[idx].GetTotalLength(); + } + + public UInt64 GetChannelFP(int idx) + { + return this.m_writers[idx].GetFingerPrint(); + } + + public void SetCalcFP(int idx) + { + this.m_writers[idx].CalcFP = true; + } + + public Int64 GetTotalLength() + { + Int64 totalLen = 0; + for (UInt32 i = 0; i < this.NumberOfOutputs; i++) + { + Int64 chLen = this.m_writers[i].GetTotalLength(); + if (chLen < 0) return -1; + totalLen += chLen; + } + return totalLen; + } + + internal Stream OutputStream + { + get { + if (this.m_numberOfOutputs != 1) + { + throw new InvalidOperationException(); + } + NativeBlockStream nativeStream = new HpcLinqChannel(this.m_nativeHandle, this.m_startPort, false); + return new HpcBinaryWriterToStreamAdapter(new HpcBinaryWriter(nativeStream)); + } + } + + public void FlushWriters() + { + for (UInt32 i = 0; i < this.NumberOfOutputs; i++) + { + this.m_writers[i].Flush(); + } + } + + public void CloseWriters() + { + for (UInt32 i = 0; i < this.NumberOfOutputs; i++) + { + this.m_writers[i].Close(); + } + } + } +} diff --git a/LinqToDryad/DscClientHelper.cs b/LinqToDryad/DscClientHelper.cs new file mode 100644 index 0000000..4bafb14 --- /dev/null +++ b/LinqToDryad/DscClientHelper.cs @@ -0,0 +1,548 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. +// +using System; +using System.Text; +using System.IO; +using System.Runtime.InteropServices; +using System.Diagnostics; +using System.IO.Compression; +using System.Collections; +using System.Collections.Generic; + +using Microsoft.Win32.SafeHandles; + +using Microsoft.Research.DryadLinq; + +namespace Microsoft.Research.DryadLinq.Internal +{ + + + + internal class DscIOStream : System.IO.Stream + { + private DscService m_dscClient; + private string m_fileSetName; + private FileAccess m_mode; + private FileStream m_fstream; + private DscFileSet m_dscFileSet; + private IEnumerator m_dscFileEnumerator; + private ulong size = 0; + private bool m_atEOF; + private DscCompressionScheme m_compressionScheme; + + public DscIOStream(string streamName, FileAccess access, DscCompressionScheme compressionScheme) + { + if (String.IsNullOrEmpty(streamName)) + { + throw new ArgumentNullException("streamName"); + } + + Uri streamUri = new Uri(streamName); + this.m_dscClient = new DscService(streamUri.Host); + this.m_fileSetName = streamUri.LocalPath; + this.m_mode = access; + this.m_fstream = null; + this.m_atEOF = false; + this.m_compressionScheme = compressionScheme; + + if (access == FileAccess.Read) + { + this.m_dscFileSet = this.m_dscClient.GetFileSet(streamName); + this.m_dscFileEnumerator = this.m_dscFileSet.GetFiles().GetEnumerator(); + } + else if (access == FileAccess.Write) + { + this.m_dscFileSet = this.m_dscClient.CreateFileSet(streamName, 
compressionScheme); + } + else + { + throw new ArgumentException(SR.ReadWriteNotSupported, "access"); + } + } + + public DscIOStream(string streamName, FileAccess access, FileMode createMode, DscCompressionScheme compressionScheme) + { + if (String.IsNullOrEmpty(streamName)) + { + throw new ArgumentNullException("streamName"); + } + + Uri streamUri = new Uri(streamName); + this.m_dscClient = new DscService(streamUri.Host); + this.m_fileSetName = streamUri.LocalPath.TrimStart('/'); + this.m_mode = access; + this.m_compressionScheme = compressionScheme; + + bool streamExists = this.m_dscClient.FileSetExists(this.m_fileSetName); + + if (access == FileAccess.Read) + { + switch (createMode) + { + case FileMode.Open: + case FileMode.OpenOrCreate: + if (!streamExists) + { + throw new FileNotFoundException(String.Format( SR.StreamDoesNotExist , streamName)); + } + break; + + case FileMode.Append: + case FileMode.Create: + case FileMode.CreateNew: + case FileMode.Truncate: + throw new NotSupportedException(); + } + + this.m_dscFileSet = this.m_dscClient.GetFileSet(streamName); + this.m_dscFileEnumerator = this.m_dscFileSet.GetFiles().GetEnumerator(); + } + else if (access == FileAccess.Write) + { + switch (createMode) + { + case FileMode.Append: + if (!streamExists) + { + this.m_dscFileSet = this.m_dscClient.CreateFileSet(this.m_fileSetName, this.m_compressionScheme); + } + break; + case FileMode.Create: + if (streamExists) + { + this.m_dscClient.DeleteFileSet(this.m_fileSetName); + } + this.m_dscFileSet = this.m_dscClient.CreateFileSet(this.m_fileSetName, this.m_compressionScheme); + break; + case FileMode.CreateNew: + if (streamExists) + { + throw new IOException(String.Format(SR.StreamAlreadyExists, streamName)); + } + break; + case FileMode.Truncate: + if (streamExists) + { + this.m_dscClient.DeleteFileSet(this.m_fileSetName); + } + this.m_dscFileSet = this.m_dscClient.CreateFileSet(this.m_fileSetName, this.m_compressionScheme); + break; + case FileMode.Open: + case 
FileMode.OpenOrCreate: // TODO: this should be dealt with correctly, + // although it's not obvious what open should do + throw new NotSupportedException(); + } + } + else + { + throw new ArgumentException(SR.ReadWriteNotSupported, "access"); + } + + this.m_fstream = null; + this.m_atEOF = false; + } + + public override bool CanRead + { + get { + return this.m_mode == FileAccess.Read; + } + } + + public override bool CanSeek + { + get { return false; } + } + + public override bool CanWrite + { + get { + return this.m_mode == FileAccess.Write; + } + } + + public override void Close() + { + try + { + if (this.m_fstream != null && this.m_mode == FileAccess.Write) + { + this.SealPartition(); + } + } + + finally + { + this.m_dscClient.Close(); + } + } + + public override void Flush() + { + if (this.m_fstream != null) + { + this.m_fstream.Flush(); + } + } + + public override long Length + { + get { throw new NotImplementedException(); } + } + + private void OpenForRead() + { + Debug.Assert(this.m_fstream == null); + + if (this.m_dscFileEnumerator.MoveNext()) + { + // TODO(bug 15879): Should failover to other readpath on failure if available + string path = this.m_dscFileEnumerator.Current.ReadPaths[0]; + this.m_fstream = new FileStream(path, FileMode.Open, FileAccess.Read, FileShare.Read, 4 * 65536, false); + } + else + { + this.m_atEOF = true; + } + } + + private void OpenForWrite(bool synchronously) + { + if (this.m_fstream != null) + { + throw new InvalidOperationException(); + } + + // @@TODO: Should try to estimate size + DscFile dscFile = this.m_dscFileSet.AddNewFile(1); + this.m_fstream = new FileStream(dscFile.WritePath, FileMode.Create, FileAccess.Write, FileShare.None, 4 * 65536, synchronously); + } + + internal void SealPartition() + { + if (this.m_fstream != null) + { + this.m_fstream.Close(); + this.m_fstream = null; + + this.m_dscFileSet.Seal(); + } + } + + public override long Position + { + get + { + throw new NotImplementedException(); + } + set + { + 
throw new NotImplementedException(); + } + } + + public override int Read(byte[] buffer, int offset, int count) + { + if (this.m_mode == FileAccess.Write) + { + throw new DryadLinqException(HpcLinqErrorCode.AttemptToReadFromAWriteStream, + SR.AttemptToReadFromAWriteStream); + } + int totalBytesRead = 0; + while (totalBytesRead < count && !this.m_atEOF) + { + if (this.m_fstream == null) + { + this.OpenForRead(); + if (this.m_atEOF) + { + break; // we hit EOF (EOS, really), so fall out of the loop + } + } + int bytesRead = this.m_fstream.Read(buffer, offset + totalBytesRead, count - totalBytesRead); + totalBytesRead += bytesRead; + if (bytesRead == 0) + { + this.m_fstream.Close(); + this.m_fstream = null; + } + } + return totalBytesRead; + } + + internal unsafe int Read(byte* buffer, int bufferSize) + { + int totalBytesRead = 0; + do + { + SafeFileHandle handle; + if (this.m_fstream == null) + { + this.OpenForRead(); + } + if (this.m_atEOF) break; + + handle = this.m_fstream.SafeFileHandle; + int size = 0; + Int32* pBlockSize = &size; + bool success = HpcLinqNative.ReadFile(handle, buffer, (UInt32)bufferSize, (IntPtr)pBlockSize, null); + if (!success) + { + throw new DryadLinqException(HpcLinqErrorCode.ReadFileError, + String.Format(SR.ReadFileError, Marshal.GetLastWin32Error())); + } + totalBytesRead += size; + + if (size == 0) + { + this.m_fstream.Close(); + this.m_fstream = null; + } + } while (totalBytesRead == 0 && !this.m_atEOF); + + return totalBytesRead; + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + public unsafe int Write(byte* buffer, int offset, int count) + { + if (this.m_mode == FileAccess.Read) + { + throw new DryadLinqException(HpcLinqErrorCode.AttemptToReadFromAWriteStream, + SR.AttemptToReadFromAWriteStream); + } + if (this.m_fstream == null) + { + this.OpenForWrite(false); + } + + 
SafeFileHandle handle = this.m_fstream.SafeFileHandle; + int size; + Int32* pBlockSize = &size; + + bool success = HpcLinqNative.WriteFile(handle, buffer, (UInt32)count, (IntPtr)pBlockSize, null); + if (!success) + { + throw new DryadLinqException(HpcLinqErrorCode.WriteFileError, + String.Format(SR.WriteFileError, Marshal.GetLastWin32Error())); + } + + this.size += (ulong)size; + return size; + } + + public override void Write(byte[] buffer, int offset, int count) + { + if (this.m_mode == FileAccess.Read) + { + throw new DryadLinqException(HpcLinqErrorCode.AttemptToReadFromAWriteStream, + SR.AttemptToReadFromAWriteStream); + } + if (this.m_fstream == null) + { + this.OpenForWrite(true); + } + this.size += (ulong)count; + this.m_fstream.Write(buffer, offset, count); + } + } + + /// + /// Handle interaction between DryadLINQ serialization and DSC streams. + /// + internal unsafe class DscBlockStream : NativeBlockStream + { + private const int DefaultBuffSize = 8192*32; + + private DscIOStream m_dscStream; + private DscCompressionScheme m_compressionScheme; + private bool m_isClosed; + private Stream m_compressStream; + + public DscBlockStream(DscIOStream dscStream, DscCompressionScheme scheme) + { + this.m_dscStream = dscStream; + this.m_compressionScheme = scheme; + this.m_isClosed = false; + this.m_compressStream = null; + } + + private void Initialize(string filePath, FileMode mode, FileAccess access, DscCompressionScheme scheme) + { + try + { + this.m_dscStream = new DscIOStream(filePath, access, mode, scheme); + } + catch (Exception e) + { + throw new DryadLinqException(HpcLinqErrorCode.FailedToCreateStream, + String.Format(SR.FailedToCreateStream, filePath), e); + } + this.m_isClosed = false; + this.m_compressionScheme = scheme; + this.m_compressStream = null; + } + + public DscBlockStream(string filePath, FileAccess access, DscCompressionScheme scheme) + { + FileMode mode = (access == FileAccess.Read) ? 
FileMode.Open : FileMode.OpenOrCreate; + this.Initialize(filePath, mode, access, scheme); + } + + public DscBlockStream(string filePath, FileMode mode, FileAccess access, DscCompressionScheme scheme) + { + this.Initialize(filePath, mode, access, scheme); + } + + internal override Int64 GetTotalLength() + { + return (Int64)this.m_dscStream.Length; + } + + // Reads up to DefaultBuffSize bytes into a freshly allocated unmanaged block and + // returns it with blockSize set to the number of bytes read. On success the block + // is handed to the caller (see ReleaseDataBlock); on failure it is freed here so + // the unmanaged allocation does not leak. + internal override DataBlockInfo ReadDataBlock() + { + DataBlockInfo blockInfo; + blockInfo.dataBlock = (byte*)Marshal.AllocHGlobal(DefaultBuffSize); + blockInfo.itemHandle = (IntPtr)blockInfo.dataBlock; + try + { + if (this.m_compressionScheme == DscCompressionScheme.None) + { + blockInfo.blockSize = this.m_dscStream.Read(blockInfo.dataBlock, DefaultBuffSize); + } + else + { + // The decompression stream is created lazily on the first read. + if (this.m_compressStream == null) + { + if (this.m_compressionScheme == DscCompressionScheme.Gzip) + { + this.m_compressStream = new GZipStream(this.m_dscStream, CompressionMode.Decompress); + } + else + { + throw new DryadLinqException(HpcLinqErrorCode.UnknownCompressionScheme, + SR.UnknownCompressionScheme); + } + } + // YY: Made an extra copy here. Could do better.
+ byte[] buffer = new byte[DefaultBuffSize]; + blockInfo.blockSize = this.m_compressStream.Read(buffer, 0, DefaultBuffSize); + fixed (byte* pBuffer = buffer) + { + HpcLinqUtil.memcpy(pBuffer, blockInfo.dataBlock, blockInfo.blockSize); + } + } + } + catch + { + // Nobody else holds the handle yet, so release it before propagating the error. + Marshal.FreeHGlobal(blockInfo.itemHandle); + throw; + } + + return blockInfo; + } + + internal override unsafe bool WriteDataBlock(IntPtr itemHandle, Int32 numBytesToWrite) + { + byte* dataBlock = (byte*)itemHandle; + if (this.m_compressionScheme == DscCompressionScheme.None) + { + Int32 numBytesWritten = 0; + Int32 remainingBytes = numBytesToWrite; + + while (remainingBytes > 0) + { + numBytesWritten = this.m_dscStream.Write(dataBlock, 0, remainingBytes); + dataBlock += numBytesWritten; + remainingBytes -= numBytesWritten; + } + } + else + { + if (this.m_compressStream == null) + { + if (this.m_compressionScheme == DscCompressionScheme.Gzip) + { + this.m_compressStream = new GZipStream(this.m_dscStream, CompressionMode.Compress); + } + else + { + throw new DryadLinqException(HpcLinqErrorCode.UnknownCompressionScheme, + SR.UnknownCompressionScheme); + } + } + // YY: Made an extra copy here. Could do better.
+ byte[] buffer = new byte[numBytesToWrite]; + fixed (byte* pBuffer = buffer) + { + HpcLinqUtil.memcpy(dataBlock, pBuffer, numBytesToWrite); + } + this.m_compressStream.Write(buffer, 0, numBytesToWrite); + } + return true; + } + + internal override void Flush() + { + if (this.m_compressStream != null) + { + this.m_compressStream.Flush(); + } + this.m_dscStream.Flush(); + } + + internal override void Close() + { + if (!this.m_isClosed) + { + this.m_isClosed = true; + if (this.m_compressStream != null) + { + this.m_compressStream.Close(); + } + this.m_dscStream.Close(); + this.m_compressStream = null; + this.m_dscStream = null; + } + } + + internal override unsafe DataBlockInfo AllocateDataBlock(Int32 size) + { + DataBlockInfo blockInfo; + blockInfo.itemHandle = Marshal.AllocHGlobal((IntPtr)size); + blockInfo.dataBlock = (byte*)blockInfo.itemHandle; + blockInfo.blockSize = size; + return blockInfo; + } + + internal override unsafe void ReleaseDataBlock(IntPtr itemHandle) + { + if (itemHandle != IntPtr.Zero) + { + Marshal.FreeHGlobal(itemHandle); + } + } + } +} diff --git a/LinqToDryad/DscStubs.cs b/LinqToDryad/DscStubs.cs new file mode 100644 index 0000000..4aded50 --- /dev/null +++ b/LinqToDryad/DscStubs.cs @@ -0,0 +1,178 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +//-------------------------------------------------------------------------- +// +// +// Fileset compression modes supported by DSC. +// +//-------------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Microsoft.Research.DryadLinq +{ + //YARN + public enum DscCompressionScheme + { + None, + Gzip + } + + public class DscService + { + private string m_headNode; + public DscService(string headNode) + { + m_headNode = headNode; + } + + internal DscFileSet GetFileSet(string streamName) + { + throw new NotImplementedException(); + } + + internal bool FileSetExists(string dscFileSetName) + { + throw new NotImplementedException(); + } + + internal void DeleteFileSet(string dscFileSetName) + { + throw new NotImplementedException(); + } + + internal DscFileSet CreateFileSet(string streamName, DscCompressionScheme compressionScheme) + { + throw new NotImplementedException(); + } + + internal void Close() + { + throw new NotImplementedException(); + } + + public string HostName { get; set; } + } + + public class DscFileSet + { + + internal DscFile AddNewFile(int p) + { + throw new NotImplementedException(); + } + + public DscCompressionScheme CompressionScheme { get; set; } + + internal byte[] GetMetadata(string p) + { + throw new NotImplementedException(); + } + + internal void Seal() + { + throw new NotImplementedException(); + } + + internal IEnumerable GetFiles() + { + throw new NotImplementedException(); + } + + internal bool IsSealed() + { + throw new NotImplementedException(); + } + + internal void SetLeaseEndTime(DateTime dateTime) + { + throw new NotImplementedException(); + } + + internal void SetMetadata(string p1, byte[] p2) + { + throw new NotImplementedException(); + } + } + + public class DscFile + { + public string[] ReadPaths { get; set; } + public string WritePath { get; set; } + } + + internal class DscInstance: IDisposable + { + + public 
DscInstance(Uri uri) + { + throw new NotImplementedException(); + } + + + internal DscStream GetStream(Uri uri) + { + throw new NotImplementedException(); + } + + public void Dispose() + { + throw new NotImplementedException(); + } + } + + internal class DscStream + { + + public long Length { get; set; } + + public int PartitionCount { get; set; } + } + + public class DscException : Exception + { + + } + + public interface IScheduler + { + + void Connect(string headNode); + + void Dispose(); + + IServerVersion GetServerVersion(); + } + + public interface IServerVersion + { + + int Major { get; set; } + + int Minor { get; set; } + + int Build { get; set; } + + int Revision { get; set; } + } +} diff --git a/LinqToDryad/DynamicManager.cs b/LinqToDryad/DynamicManager.cs new file mode 100644 index 0000000..a2d66d7 --- /dev/null +++ b/LinqToDryad/DynamicManager.cs @@ -0,0 +1,205 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. 
+// +using System; +using System.Collections.Generic; +using System.Text; +using System.IO; +using System.Reflection; +using System.Linq; +using System.Linq.Expressions; +using System.CodeDom; +using System.Diagnostics; +using System.Xml; +using Microsoft.Research.DryadLinq.Internal; + +namespace Microsoft.Research.DryadLinq +{ + internal enum DynamicManagerType + { + None, + Splitter, + PartialAggregator, + FullAggregator, + HashDistributor, + RangeDistributor, + Broadcast + } + + internal class DynamicManager + { + internal static DynamicManager None = new DynamicManager(DynamicManagerType.None); + internal static DynamicManager Splitter = new DynamicManager(DynamicManagerType.Splitter); + internal static DynamicManager PartialAggregator = new DynamicManager(DynamicManagerType.PartialAggregator); + internal static DynamicManager Broadcast = new DynamicManager(DynamicManagerType.Broadcast); + + private DynamicManagerType m_managerType; + internal protected List<DryadQueryNode> m_vertexNodes; + private string[] m_vertexNames; + + private DynamicManager(DynamicManagerType type) + { + this.m_managerType = type; + this.m_vertexNodes = new List<DryadQueryNode>(); + this.m_vertexNames = null; + this.AggregationLevels = 0; + // default aggregation has 1 level + if (type == DynamicManagerType.FullAggregator) + { + AggregationLevels = 1; + } + } + + /// <summary> + /// Create a dynamic manager (a Dryad policy manager) with a list of parameter nodes. + /// </summary> + /// <param name="type">Type of dynamic manager to create.</param> + /// <param name="nodes">Nodes that the manager depends on.</param>
+ internal DynamicManager(DynamicManagerType type, List<DryadQueryNode> nodes) + : this(type) + { + this.m_vertexNodes.AddRange(nodes); + } + + /// <summary> + /// Create a dynamic manager with a single parameter node. + /// </summary> + /// <param name="type">Type of manager to create.</param> + /// <param name="node">Node that the manager depends on.</param> + internal DynamicManager(DynamicManagerType type, DryadQueryNode node) + : this(type) + { + this.m_vertexNodes.Add(node); + } + + internal DynamicManagerType ManagerType + { + get { return this.m_managerType; } + } + + // Creates a manager of the given type over a copy of this manager's node list. + internal DynamicManager CreateManager(DynamicManagerType type) + { + return new DynamicManager(type, this.m_vertexNodes); + } + + /// <summary> + /// The aggregation level of the dynamic manager (used for aggregations only). + /// </summary> + internal int AggregationLevels { get; set; } + + internal DryadQueryNode GetVertexNode(int index) + { + return this.m_vertexNodes[index]; + } + + // Inserts node at index; an index of -1 appends it to the end of the list. + internal void InsertVertexNode(int index, DryadQueryNode node) + { + if (index == -1) + { + this.m_vertexNodes.Add(node); + } + else + { + this.m_vertexNodes.Insert(index, node); + } + } + + // Generates one vertex method per parameter node and records the method names, + // index-aligned with m_vertexNodes, for use by CreateElem. + internal virtual void CreateVertexCode() + { + if (this.m_vertexNodes.Count != 0) + { + this.m_vertexNames = new string[this.m_vertexNodes.Count]; + for (int i = 0; i < this.m_vertexNames.Length; i++) + { + CodeMemberMethod vertexMethod = this.m_vertexNodes[i].QueryGen.CodeGen.AddVertexMethod(this.m_vertexNodes[i]); + this.m_vertexNames[i] = vertexMethod.Name; + } + } + } + + // Builds the "DynamicManager" XML element: the manager type, the aggregation level + // when non-zero, and one vertex entry per name recorded by CreateVertexCode. + internal virtual XmlElement CreateElem(XmlDocument queryDoc) + { + XmlElement managerElem = queryDoc.CreateElement("DynamicManager"); + XmlElement elem = queryDoc.CreateElement("Type"); + elem.InnerText = this.ManagerType.ToString(); + managerElem.AppendChild(elem); + + if (AggregationLevels != 0) + { + XmlElement agg = queryDoc.CreateElement("AggregationLevels"); + agg.InnerText = AggregationLevels.ToString(); + managerElem.AppendChild(agg); + } + + if (this.m_vertexNames != null) + { + for (int i = 0; i < this.m_vertexNames.Length; i++) + { + string dllName = this.m_vertexNodes[i].QueryGen.CodeGen.GetDryadLinqDllName(); + XmlElement entryElem = DryadQueryDoc.CreateVertexEntryElem(queryDoc, dllName, this.m_vertexNames[i]); + managerElem.AppendChild(entryElem); + } + } + return managerElem; + } + } + + internal
class DynamicRangeDistributor : DynamicManager + { + private double m_sampleRate; + + internal DynamicRangeDistributor(DryadQueryNode node) + : base(DynamicManagerType.RangeDistributor, node) + { + //@@TODO[P2]: This sample rate used here should really be its own constant. + this.m_sampleRate = HpcLinqSampler.SAMPLE_RATE; + } + + // Intentionally empty: this manager generates no vertex methods. + internal override void CreateVertexCode() + { + } + + // Builds the "DynamicManager" XML element with the manager type, the sample rate and + // the unique id of the parameter node (or of its super node when it has one). + internal override XmlElement CreateElem(XmlDocument queryDoc) + { + XmlElement managerElem = queryDoc.CreateElement("DynamicManager"); + XmlElement elem = queryDoc.CreateElement("Type"); + elem.InnerText = this.ManagerType.ToString(); + managerElem.AppendChild(elem); + + elem = queryDoc.CreateElement("SampleRate"); + // Format with the invariant culture so the decimal separator written to the query + // XML does not depend on the locale of the submitting machine. + elem.InnerText = this.m_sampleRate.ToString(System.Globalization.CultureInfo.InvariantCulture); + managerElem.AppendChild(elem); + + elem = queryDoc.CreateElement("VertexId"); + DryadQueryNode node = this.m_vertexNodes[0]; + if (node.SuperNode != null) + { + node = node.SuperNode; + } + elem.InnerText = Convert.ToString(node.m_uniqueId); + managerElem.AppendChild(elem); + + return managerElem; + } + } +} diff --git a/LinqToDryad/ExpressionMatcher.cs b/LinqToDryad/ExpressionMatcher.cs new file mode 100644 index 0000000..573517b --- /dev/null +++ b/LinqToDryad/ExpressionMatcher.cs @@ -0,0 +1,458 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// © Microsoft Corporation.
All rights reserved. +// +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Text; +using System.IO; +using System.Reflection; +using System.Linq; +using System.Linq.Expressions; +using System.Diagnostics; + +namespace Microsoft.Research.DryadLinq +{ + // This class implements an expression matcher. This is useful for + // many interesting static analysis. Again, it would be nice if this + // is a functionality provided by C# and LINQ. But unfortunately it + // is not. + internal class ExpressionMatcher + { + public static bool Match(Expression e1, Expression e2) + { + return Match(e1, e2, Substitution.Empty); + } + + // return true if e1 subsumes e2 in terms of member access + internal static bool MemberAccessSubsumes(Expression e1, Expression e2) + { + if (Match(e1, e2)) return true; + Expression e = e2; + while (e is MemberExpression) + { + e = ((MemberExpression)e).Expression; + if (Match(e1, e)) return true; + } + return false; + } + + internal static bool Match(Expression e1, Expression e2, Substitution subst) + { + if (!e1.Type.Equals(e2.Type)) return false; + + if (e1 is BinaryExpression) + { + return ((e2 is BinaryExpression) && + MatchBinary((BinaryExpression)e1, (BinaryExpression)e2, subst)); + } + else if (e1 is ConditionalExpression) + { + return ((e2 is ConditionalExpression) && + MatchConditional((ConditionalExpression)e1, (ConditionalExpression)e2, subst)); + } + else if (e1 is ConstantExpression) + { + return ((e2 is ConstantExpression) && + MatchConstant((ConstantExpression)e1, (ConstantExpression)e2, subst)); + } + else if (e1 is InvocationExpression) + { + return ((e2 is InvocationExpression) && + MatchInvocation((InvocationExpression)e1, (InvocationExpression)e2, subst)); + } + else if (e1 is LambdaExpression) + { + return ((e2 is LambdaExpression) && + MatchLambda((LambdaExpression)e1, (LambdaExpression)e2, subst)); + } + else if (e1 is MemberExpression) + { + return ((e2 is MemberExpression) 
&& + MatchMember((MemberExpression)e1, (MemberExpression)e2, subst)); + } + else if (e1 is MethodCallExpression) + { + return ((e2 is MethodCallExpression) && + MatchMethodCall((MethodCallExpression)e1, (MethodCallExpression)e2, subst)); + } + else if (e1 is NewExpression) + { + return ((e2 is NewExpression) && + MatchNew((NewExpression)e1, (NewExpression)e2, subst)); + } + else if (e1 is NewArrayExpression) + { + return ((e2 is NewArrayExpression) && + MatchNewArray((NewArrayExpression)e1, (NewArrayExpression)e2, subst)); + } + else if (e1 is MemberInitExpression) + { + return ((e2 is MemberInitExpression) && + MatchMemberInit((MemberInitExpression)e1, (MemberInitExpression)e2, subst)); + } + else if (e1 is ListInitExpression) + { + return ((e2 is ListInitExpression) && + MatchListInit((ListInitExpression)e1, (ListInitExpression)e2, subst)); + } + else if (e1 is ParameterExpression) + { + return ((e2 is ParameterExpression) && + MatchParameter((ParameterExpression)e1, (ParameterExpression)e2, subst)); + } + else if (e1 is TypeBinaryExpression) + { + return ((e2 is TypeBinaryExpression) && + MatchTypeBinary((TypeBinaryExpression)e1, (TypeBinaryExpression)e2, subst)); + } + else if (e1 is UnaryExpression) + { + return ((e2 is UnaryExpression) && + MatchUnary((UnaryExpression)e1, (UnaryExpression)e2, subst)); + } + + throw new DryadLinqException(HpcLinqErrorCode.ExpressionTypeNotHandled, + String.Format(SR.ExpressionTypeNotHandled, + "ExpressionMatcher", e1.NodeType)); + } + + private static bool MatchInvocation(InvocationExpression e1, + InvocationExpression e2, + Substitution subst) + { + ReadOnlyCollection args1 = e1.Arguments; + ReadOnlyCollection args2 = e2.Arguments; + if (!Match(e1.Expression, e2.Expression, subst) || args1.Count != args2.Count) + { + return false; + } + for (int i = 0; i < args1.Count; i++) + { + if (!Match(args1[i], args2[i], subst)) + { + return false; + } + } + return true; + } + + private static bool MatchBinary(BinaryExpression e1, + 
BinaryExpression e2, + Substitution subst) + { + return (e1.NodeType == e2.NodeType && + Match(e1.Left, e2.Left, subst) && + Match(e1.Right, e2.Right, subst)); + } + + private static bool MatchConditional(ConditionalExpression e1, + ConditionalExpression e2, + Substitution subst) + { + return (Match(e1.Test, e2.Test, subst) && + Match(e1.IfTrue, e2.IfTrue, subst) && + Match(e1.IfFalse, e2.IfFalse, subst)); + } + + private static bool MatchConstant(ConstantExpression e1, + ConstantExpression e2, + Substitution subst) + { + if (e1.Value == null) + { + return (e2.Value == null); + } + else + { + return e1.Value.Equals(e2.Value); + } + } + + private static bool MatchLambda(LambdaExpression e1, + LambdaExpression e2, + Substitution subst) + { + if (e1.Parameters.Count != e2.Parameters.Count) + { + return false; + } + Substitution subst1 = subst; + for (int i = 0, n = e1.Parameters.Count; i < n; i++) + { + if (!e1.Parameters[i].Equals(e2.Parameters[i])) + { + subst1 = subst1.Cons(e1.Parameters[i], e2.Parameters[i]); + } + } + return Match(e1.Body, e2.Body, subst1); + } + + private static bool MatchMember(MemberExpression e1, + MemberExpression e2, + Substitution subst) + { + if (e1.Expression == null) + { + if (e2.Expression != null) return false; + } + else + { + if (e2.Expression == null || + !Match(e1.Expression, e2.Expression, subst)) + { + return false; + } + } + return e1.Member.Equals(e2.Member); + } + + private static bool MatchMethodCall(MethodCallExpression e1, + MethodCallExpression e2, + Substitution subst) + { + if (e1.Method != e2.Method) return false; + + if (e1.Object == null || e2.Object == null) + { + if (e1.Object != e2.Object) return false; + } + else if (!Match(e1.Object, e2.Object, subst) || + e1.Arguments.Count != e2.Arguments.Count) + { + return false; + } + for (int i = 0, n = e1.Arguments.Count; i < n; i++) + { + if (!Match(e1.Arguments[i], e2.Arguments[i], subst)) + { + return false; + } + } + return true; + } + + private static bool 
MatchNew(NewExpression e1, NewExpression e2, Substitution subst) + { + if (e1.Arguments.Count != e2.Arguments.Count) + { + return false; + } + for (int i = 0, n = e1.Arguments.Count; i < n; i++) + { + if (!Match(e1.Arguments[i], e2.Arguments[i], subst)) + { + return false; + } + } + return true; + } + + public static bool MatchNewArray(NewArrayExpression e1, + NewArrayExpression e2, + Substitution subst) + { + if (e1.NodeType != e2.NodeType || + e1.Expressions.Count != e2.Expressions.Count) + { + return false; + } + for (int i = 0, n = e1.Expressions.Count; i < n; i++) + { + if (!Match(e1.Expressions[i], e2.Expressions[i], subst)) + { + return false; + } + } + return true; + } + + public static bool MatchMemberInit(MemberInitExpression e1, + MemberInitExpression e2, + Substitution subst) + { + if (!Match(e1.NewExpression, e2.NewExpression, subst) || + e1.Bindings.Count != e2.Bindings.Count) + { + return false; + } + for (int i = 0, n = e1.Bindings.Count; i < n; i++) + { + if (!MatchMemberBinding(e1.Bindings[i], e2.Bindings[i], subst)) + { + return false; + } + } + return true; + } + + public static bool MatchListInit(ListInitExpression e1, + ListInitExpression e2, + Substitution subst) + { + if (!Match(e1.NewExpression, e2.NewExpression, subst)) + { + return false; + } + if (e1.Initializers.Count != e2.Initializers.Count) + { + return false; + } + for (int i = 0, n = e1.Initializers.Count; i < n; i++) + { + ElementInit init1 = e1.Initializers[i]; + ElementInit init2 = e2.Initializers[i]; + if (!MatchElementInit(init1, init2, subst)) + { + return false; + } + } + return true; + } + + public static bool MatchParameter(ParameterExpression e1, + ParameterExpression e2, + Substitution subst) + { + if (e1.Equals(e2)) return true; + ParameterExpression e = subst.Find(e1); + return (e != null && e.Equals(e2)); + } + + public static bool MatchTypeBinary(TypeBinaryExpression e1, + TypeBinaryExpression e2, + Substitution subst) + { + return (e1.NodeType == 
ExpressionType.TypeIs && + e2.NodeType == ExpressionType.TypeIs && + e1.TypeOperand.Equals(e2.TypeOperand) && + Match(e1.Expression, e2.Expression, subst)); + } + + public static bool MatchUnary(UnaryExpression e1, UnaryExpression e2, Substitution subst) + { + return (e1.NodeType == e2.NodeType && + Match(e1.Operand, e2.Operand, subst)); + } + + private static bool MatchMemberBinding(MemberBinding b1, MemberBinding b2, Substitution subst) + { + if (b1.BindingType != b2.BindingType || + !b1.Member.Equals(b2.Member)) + { + return false; + } + if (b1 is MemberAssignment) + { + return Match(((MemberAssignment)b1).Expression, + ((MemberAssignment)b2).Expression, + subst); + } + else if (b1 is MemberMemberBinding) + { + MemberMemberBinding mmb1 = (MemberMemberBinding)b1; + MemberMemberBinding mmb2 = (MemberMemberBinding)b2; + if (mmb1.Bindings.Count != mmb2.Bindings.Count) + { + return false; + } + for (int i = 0, n = mmb1.Bindings.Count; i < n; i++) + { + if (!MatchMemberBinding(mmb1.Bindings[i], mmb2.Bindings[i], subst)) + { + return false; + } + } + return true; + } + else + { + MemberListBinding mlb1 = (MemberListBinding)b1; + MemberListBinding mlb2 = (MemberListBinding)b2; + if (mlb1.Initializers.Count != mlb2.Initializers.Count) + { + return false; + } + for (int i = 0, n = mlb1.Initializers.Count; i < n; i++) + { + if (!MatchElementInit(mlb1.Initializers[i], mlb2.Initializers[i], subst)) + { + return false; + } + } + return true; + } + } + + private static bool MatchElementInit(ElementInit init1, + ElementInit init2, + Substitution subst) + { + if (init1.AddMethod != init2.AddMethod || + init1.Arguments.Count != init2.Arguments.Count) + { + return false; + } + for (int i = 0; i < init1.Arguments.Count; i++) + { + if (!Match(init1.Arguments[i], init2.Arguments[i], subst)) + { + return false; + } + } + return true; + } + } + + internal class Substitution + { + private ParameterExpression x; + private ParameterExpression y; + private Substitution next; + + public 
static Substitution Empty = new Substitution(null, null, null); + + // Builds one link of the substitution chain: binds x to y and points at the rest of the chain. + private Substitution(ParameterExpression x, ParameterExpression y, Substitution s) + { + this.x = x; + this.y = y; + this.next = s; + } + + // Returns a new chain with the binding (a, b) prepended; this chain is not modified. + public Substitution Cons(ParameterExpression a, ParameterExpression b) + { + return new Substitution(a, b, this); + } + + // Returns the parameter bound to a by the first matching link, walking from this + // link towards Empty, or null when a is not bound anywhere in the chain. + public ParameterExpression Find(ParameterExpression a) + { + Substitution curSubst = this; + while (curSubst != Empty) + { + // Return the y of the link that matched, not the y of the head of the chain. + if (curSubst.x.Equals(a)) return curSubst.y; + curSubst = curSubst.next; + } + return null; + } + } +} diff --git a/LinqToDryad/ExpressionSimplifier.cs b/LinqToDryad/ExpressionSimplifier.cs new file mode 100644 index 0000000..a67da9e --- /dev/null +++ b/LinqToDryad/ExpressionSimplifier.cs @@ -0,0 +1,70 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// © Microsoft Corporation. All rights reserved. +// +using System; +using System.Collections.Generic; +using System.Text; +using System.IO; +using System.Reflection; +using System.Linq; +using System.Linq.Expressions; +using System.Diagnostics; + +namespace Microsoft.Research.DryadLinq +{ + // This class implements a simple evaluator for expression. It may + // evolve into a better and more useful thing: a simplifier that + // goes through an expression to remove all dependency to the + // current local execution context.
This is necessary to ensure + // that expressions can be remotely executed. Even better, this + // simplifier could detect potential "impure" expressions. Again, + // it would be nice if this is a functionality provided by C# and + // LINQ. But unfortunately it is not. + internal abstract class ExpressionSimplifier + { + // Evaluates expr and returns the result boxed as object; implemented by the + // generic subclass, which knows the static result type. + internal abstract object EvalBoxed(Expression expr); + + // Evaluate the expression in the current local execution context. + // Closes ExpressionSimplifier<T> over expr.Type via reflection and delegates to it. + internal static object Evaluate(Expression expr) + { + Type qType = typeof(ExpressionSimplifier<>).MakeGenericType(expr.Type); + ExpressionSimplifier evaluator = (ExpressionSimplifier)Activator.CreateInstance(qType); + return evaluator.EvalBoxed(expr); + } + } + + // Typed evaluator: T is the static type of the expressions it evaluates. + internal class ExpressionSimplifier<T> : ExpressionSimplifier + { + internal override object EvalBoxed(Expression expr) + { + return this.Eval(expr); + } + + // Wraps expr in a parameterless lambda, compiles it and invokes it once. + internal T Eval(Expression expr) + { + Expression<Func<T>> lambda = Expression.Lambda<Func<T>>(expr); + Func<T> func = lambda.Compile(); + return func(); + } + } +} diff --git a/LinqToDryad/ExpressionVisitor.cs b/LinqToDryad/ExpressionVisitor.cs new file mode 100644 index 0000000..b446944 --- /dev/null +++ b/LinqToDryad/ExpressionVisitor.cs @@ -0,0 +1,756 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// © Microsoft Corporation. All rights reserved.
+// +using System; +using System.Collections.Generic; +using System.Text; +using System.Collections.ObjectModel; +using System.Reflection; +using System.Linq; +using System.Linq.Expressions; +using Microsoft.Research.DryadLinq.Internal; + +namespace Microsoft.Research.DryadLinq +{ + // This class implements a generic expression visitor. The code was stolen + // from LINQ source code. + internal abstract class ExpressionVisitor + { + internal ExpressionVisitor() + { + } + + internal virtual Expression Visit(Expression exp) + { + if (exp == null) return exp; + switch (exp.NodeType) + { + case ExpressionType.Negate: + case ExpressionType.NegateChecked: + case ExpressionType.Not: + case ExpressionType.Convert: + case ExpressionType.ConvertChecked: + case ExpressionType.ArrayLength: + case ExpressionType.Quote: + case ExpressionType.TypeAs: + { + return this.VisitUnary((UnaryExpression)exp); + } + case ExpressionType.Add: + case ExpressionType.AddChecked: + case ExpressionType.Subtract: + case ExpressionType.SubtractChecked: + case ExpressionType.Multiply: + case ExpressionType.MultiplyChecked: + case ExpressionType.Divide: + case ExpressionType.Modulo: + case ExpressionType.And: + case ExpressionType.AndAlso: + case ExpressionType.Or: + case ExpressionType.OrElse: + case ExpressionType.LessThan: + case ExpressionType.LessThanOrEqual: + case ExpressionType.GreaterThan: + case ExpressionType.GreaterThanOrEqual: + case ExpressionType.Equal: + case ExpressionType.NotEqual: + case ExpressionType.Coalesce: + case ExpressionType.ArrayIndex: + case ExpressionType.RightShift: + case ExpressionType.LeftShift: + case ExpressionType.ExclusiveOr: + { + return this.VisitBinary((BinaryExpression)exp); + } + case ExpressionType.TypeIs: + { + return this.VisitTypeIs((TypeBinaryExpression)exp); + } + case ExpressionType.Conditional: + { + return this.VisitConditional((ConditionalExpression)exp); + } + case ExpressionType.Constant: + { + return 
this.VisitConstant((ConstantExpression)exp); + } + case ExpressionType.Parameter: + { + return this.VisitParameter((ParameterExpression)exp); + } + case ExpressionType.MemberAccess: + { + return this.VisitMemberAccess((MemberExpression)exp); + } + case ExpressionType.Call: + { + return this.VisitMethodCall((MethodCallExpression)exp); + } + case ExpressionType.Lambda: + { + return this.VisitLambda((LambdaExpression)exp); + } + case ExpressionType.New: + { + return this.VisitNew((NewExpression)exp); + } + case ExpressionType.NewArrayInit: + case ExpressionType.NewArrayBounds: + { + return this.VisitNewArray((NewArrayExpression)exp); + } + case ExpressionType.Invoke: + { + return this.VisitInvocation((InvocationExpression)exp); + } + case ExpressionType.MemberInit: + { + return this.VisitMemberInit((MemberInitExpression)exp); + } + case ExpressionType.ListInit: + { + return this.VisitListInit((ListInitExpression)exp); + } + default: + { + throw new DryadLinqException(HpcLinqErrorCode.ExpressionTypeNotHandled, + String.Format(SR.ExpressionTypeNotHandled, + "ExpressionVisitor", exp.NodeType)); + } + } + } + + internal virtual MemberBinding VisitBinding(MemberBinding binding) + { + switch (binding.BindingType) + { + case MemberBindingType.Assignment: + { + return this.VisitMemberAssignment((MemberAssignment)binding); + } + case MemberBindingType.MemberBinding: + { + return this.VisitMemberMemberBinding((MemberMemberBinding)binding); + } + case MemberBindingType.ListBinding: + { + return this.VisitMemberListBinding((MemberListBinding)binding); + } + default: + { + throw new DryadLinqException(HpcLinqErrorCode.ExpressionTypeNotHandled, + String.Format(SR.ExpressionTypeNotHandled, + "ExpressionVisitor", binding.BindingType)); + } + } + } + + internal virtual ElementInit VisitElementInitializer(ElementInit initializer) + { + ReadOnlyCollection arguments = this.VisitExpressionList(initializer.Arguments); + if (arguments != initializer.Arguments) + { + return 
Expression.ElementInit(initializer.AddMethod, arguments); + } + return initializer; + } + + internal virtual Expression VisitUnary(UnaryExpression u) + { + Expression operand = this.Visit(u.Operand); + if (operand != u.Operand) + { + return Expression.MakeUnary(u.NodeType, operand, u.Type, u.Method); + } + return u; + } + + internal virtual Expression VisitBinary(BinaryExpression b) + { + Expression left = this.Visit(b.Left); + Expression right = this.Visit(b.Right); + if (left != b.Left || right != b.Right) + { + return Expression.MakeBinary(b.NodeType, left, right, b.IsLiftedToNull, b.Method); + } + return b; + } + + internal virtual Expression VisitTypeIs(TypeBinaryExpression b) + { + Expression expr = this.Visit(b.Expression); + if (expr != b.Expression) + { + return Expression.TypeIs(expr, b.TypeOperand); + } + return b; + } + + internal virtual Expression VisitConstant(ConstantExpression c) + { + return c; + } + + internal virtual Expression VisitConditional(ConditionalExpression c) + { + Expression test = this.Visit(c.Test); + Expression ifTrue = this.Visit(c.IfTrue); + Expression ifFalse = this.Visit(c.IfFalse); + if (test != c.Test || ifTrue != c.IfTrue || ifFalse != c.IfFalse) + { + return Expression.Condition(test, ifTrue, ifFalse); + } + return c; + } + + internal virtual Expression VisitParameter(ParameterExpression p) + { + return p; + } + + internal virtual Expression VisitMemberAccess(MemberExpression m) + { + Expression exp = this.Visit(m.Expression); + if (exp != m.Expression) + { + return Expression.MakeMemberAccess(exp, m.Member); + } + return m; + } + + internal virtual Expression VisitMethodCall(MethodCallExpression m) + { + Expression obj = this.Visit(m.Object); + IEnumerable args = this.VisitExpressionList(m.Arguments); + if (obj != m.Object || args != m.Arguments) + { + return Expression.Call(obj, m.Method, args); + } + return m; + } + + internal virtual ReadOnlyCollection VisitExpressionList(ReadOnlyCollection original) + { + List list = 
null; + for (int i = 0, n = original.Count; i < n; i++) + { + Expression p = this.Visit(original[i]); + if (list != null) + { + list.Add(p); + } + else if (p != original[i]) + { + list = new List(n); + for (int j = 0; j < i; j++) + { + list.Add(original[j]); + } + list.Add(p); + } + } + if (list != null) + { + return new ReadOnlyCollection(list); + } + return original; + } + + internal virtual MemberAssignment VisitMemberAssignment(MemberAssignment assignment) + { + Expression e = this.Visit(assignment.Expression); + if (e != assignment.Expression) + { + return Expression.Bind(assignment.Member, e); + } + return assignment; + } + + internal virtual MemberMemberBinding VisitMemberMemberBinding(MemberMemberBinding binding) + { + IEnumerable bindings = this.VisitBindingList(binding.Bindings); + if (bindings != binding.Bindings) + { + return Expression.MemberBind(binding.Member, bindings); + } + return binding; + } + + internal virtual MemberListBinding VisitMemberListBinding(MemberListBinding binding) + { + IEnumerable initializers = this.VisitElementInitializerList(binding.Initializers); + if (initializers != binding.Initializers) + { + return Expression.ListBind(binding.Member, initializers); + } + return binding; + } + + internal virtual IEnumerable VisitBindingList(ReadOnlyCollection original) + { + List list = null; + for (int i = 0, n = original.Count; i < n; i++) + { + MemberBinding b = this.VisitBinding(original[i]); + if (list != null) + { + list.Add(b); + } + else if (b != original[i]) + { + list = new List(n); + for (int j = 0; j < i; j++) + { + list.Add(original[j]); + } + list.Add(b); + } + } + if (list != null) + { + return list; + } + return original; + } + + internal virtual IEnumerable VisitElementInitializerList(ReadOnlyCollection original) + { + List list = null; + for (int i = 0, n = original.Count; i < n; i++) + { + ElementInit init = this.VisitElementInitializer(original[i]); + if (list != null) + { + list.Add(init); + } + else if (init != 
original[i]) + { + list = new List(n); + for (int j = 0; j < i; j++) + { + list.Add(original[j]); + } + list.Add(init); + } + } + if (list != null) + { + return list; + } + return original; + } + + internal virtual Expression VisitLambda(LambdaExpression lambda) + { + Expression body = this.Visit(lambda.Body); + if (body != lambda.Body) + { + return Expression.Lambda(lambda.Type, body, lambda.Parameters); + } + return lambda; + } + + internal virtual NewExpression VisitNew(NewExpression nex) + { + IEnumerable args = this.VisitExpressionList(nex.Arguments); + if (args != nex.Arguments) + { + return Expression.New(nex.Constructor, args); + } + return nex; + } + + internal virtual Expression VisitMemberInit(MemberInitExpression init) + { + NewExpression n = this.VisitNew(init.NewExpression); + IEnumerable bindings = this.VisitBindingList(init.Bindings); + if (n != init.NewExpression || bindings != init.Bindings) + { + return Expression.MemberInit(n, bindings); + } + return init; + } + + internal virtual Expression VisitListInit(ListInitExpression init) + { + NewExpression n = this.VisitNew(init.NewExpression); + IEnumerable initializers = this.VisitElementInitializerList(init.Initializers); + if (n != init.NewExpression || initializers != init.Initializers) + { + return Expression.ListInit(n, initializers); + } + return init; + } + + internal virtual Expression VisitNewArray(NewArrayExpression na) + { + IEnumerable exprs = this.VisitExpressionList(na.Expressions); + if (exprs != na.Expressions) + { + if (na.NodeType == ExpressionType.NewArrayInit) + { + return Expression.NewArrayInit(na.Type.GetElementType(), exprs); + } + else + { + return Expression.NewArrayBounds(na.Type.GetElementType(), exprs); + } + } + return na; + } + + internal virtual Expression VisitInvocation(InvocationExpression iv) + { + IEnumerable args = this.VisitExpressionList(iv.Arguments); + Expression expr = this.Visit(iv.Expression); + if (args != iv.Arguments || expr != iv.Expression) + { + 
return Expression.Invoke(expr, args); + } + return iv; + } + } + + internal sealed class ParameterSubst : ExpressionVisitor + { + private ParameterExpression m_pexpr; + private Expression m_aexpr; + + internal ParameterSubst(ParameterExpression pexpr, Expression aexpr) + { + this.m_pexpr = pexpr; + this.m_aexpr = aexpr; + } + + internal override Expression VisitParameter(ParameterExpression p) + { + return (p == this.m_pexpr) ? this.m_aexpr : p; + } + } + + internal sealed class ExpressionSubst : ExpressionVisitor + { + private Substitution m_paramSubst; + private List m_leftExprList; + private List m_rightExprList; + + internal ExpressionSubst(Substitution paramSubst) + { + this.m_paramSubst = paramSubst; + this.m_leftExprList = new List(2); + this.m_rightExprList = new List(2); + } + + internal void AddSubst(Expression left, Expression right) + { + foreach (Expression expr in this.m_leftExprList) + { + if (ExpressionMatcher.MemberAccessSubsumes(expr, left)) + { + return; + } + } + this.m_leftExprList.Add(left); + this.m_rightExprList.Add(right); + } + + internal override Expression Visit(Expression expr) + { + if (expr == null) return expr; + for (int i = 0; i < this.m_leftExprList.Count; i++) + { + if (ExpressionMatcher.Match(expr, this.m_leftExprList[i], this.m_paramSubst)) + { + return this.m_rightExprList[i]; + } + } + return base.Visit(expr); + } + } + + internal sealed class CombinerSubst : ExpressionVisitor + { + private Expression m_expr; + private ParameterExpression m_keyParam; + private ParameterExpression m_groupParam; + private Expression m_keyExpr; + private Expression[] m_fromExprs; + private Expression[] m_toExprs; + private int m_idx = 0; + + internal CombinerSubst(LambdaExpression lambdaExpr, + ParameterExpression keyValueParam, + Expression[] fromExprs, + Expression[] toExprs) + { + this.m_expr = lambdaExpr.Body; + if (lambdaExpr.Parameters.Count == 1) + { + this.m_keyParam = null; + this.m_groupParam = lambdaExpr.Parameters[0]; + } + else + { 
+ this.m_keyParam = lambdaExpr.Parameters[0]; + this.m_groupParam = lambdaExpr.Parameters[1]; + } + PropertyInfo keyPropInfo = keyValueParam.Type.GetProperty("Key"); + this.m_keyExpr = Expression.Property(keyValueParam, keyPropInfo); + + this.m_fromExprs = fromExprs; + this.m_toExprs = toExprs; + this.m_idx = 0; + } + + internal Expression Visit() + { + return this.Visit(this.m_expr); + } + + internal override Expression VisitMethodCall(MethodCallExpression mcExpr) + { + if (this.m_idx < this.m_fromExprs.Length && + this.m_fromExprs[this.m_idx] == mcExpr) + { + return this.m_toExprs[this.m_idx++]; + } + return base.VisitMethodCall(mcExpr); + } + + internal override Expression VisitMemberAccess(MemberExpression m) + { + if (this.m_keyParam == null && + m.Expression == this.m_groupParam && + m.Member.Name == "Key") + { + return this.m_keyExpr; + } + return base.VisitMemberAccess(m); + } + + internal override Expression VisitParameter(ParameterExpression p) + { + if (this.m_keyParam == p) + { + return this.m_keyExpr; + } + return base.VisitParameter(p); + } + } + + internal sealed class FreeParameters : ExpressionVisitor + { + private HashSet freeParameters; + + internal FreeParameters() + { + this.freeParameters = new HashSet(); + } + + internal HashSet Parameters + { + get { return this.freeParameters; } + } + + internal override Expression VisitParameter(ParameterExpression p) + { + this.freeParameters.Add(p); + return p; + } + + internal override Expression VisitLambda(LambdaExpression lambda) + { + Expression body = this.Visit(lambda.Body); + foreach (ParameterExpression param in lambda.Parameters) + { + this.freeParameters.Remove(param); + } + return lambda; + } + + public override string ToString() + { + StringBuilder sb = new StringBuilder(); + sb.Append("{ "); + bool isFirst = true; + foreach (ParameterExpression p in this.freeParameters) + { + if (isFirst) + { + isFirst = false; + } + else + { + sb.Append(","); + } + if (p.Name != null) + { + 
sb.Append(p.Name); + } + else + { + sb.Append(""); + } + } + sb.Append(" }"); + return sb.ToString(); + } + } + + internal sealed class ExpressionQuerySet : ExpressionVisitor + { + private HashSet m_querySet; + + public ExpressionQuerySet() + { + this.m_querySet = new HashSet(); + } + + internal HashSet QuerySet + { + get { return this.m_querySet; } + } + + internal override Expression Visit(Expression expr) + { + if (expr == null) return expr; + if (HpcLinqExpression.IsConstant(expr)) + { + object val = ExpressionSimplifier.Evaluate(expr); + if (val is IQueryable) + { + this.m_querySet.Add(((IQueryable)val).Expression); + } + return expr; + } + return base.Visit(expr); + } + } + + internal sealed class ReferencedQuerySubst : ExpressionVisitor + { + private Dictionary m_referencedQueryMap; + private int m_idx; + private List> m_referencedQueries; + + public ReferencedQuerySubst(Dictionary referencedQueryMap) + { + this.m_referencedQueryMap = referencedQueryMap; + this.m_idx = 0; + this.m_referencedQueries = new List>(); + } + + internal override Expression Visit(Expression expr) + { + if (expr == null) return expr; + if (HpcLinqExpression.IsConstant(expr)) + { + object val = ExpressionSimplifier.Evaluate(expr); + if (val is IQueryable) + { + QueryNodeInfo nodeInfo; + if (this.m_referencedQueryMap.TryGetValue(((IQueryable)val).Expression, out nodeInfo)) + { + string name = "side__" + this.m_idx; + this.m_idx++; + this.m_referencedQueries.Add(new Pair(name, nodeInfo.queryNode)); + return Expression.Parameter(expr.Type, name); + } + throw new DryadLinqException(HpcLinqErrorCode.UnhandledQuery, + String.Format(SR.UnhandledQuery, HpcLinqExpression.Summarize(expr))); + } + return expr; + } + return base.Visit(expr); + } + + public List> GetReferencedQueries() + { + return this.m_referencedQueries; + } + } + + internal sealed class ExpressionInfo : ExpressionVisitor + { + private bool m_isExpensive; + + public ExpressionInfo(Expression expr) + { + this.m_isExpensive = 
false; + this.Visit(expr); + } + + internal override Expression VisitMethodCall(MethodCallExpression mcExpr) + { + Attribute resourceAttrib = AttributeSystem.GetAttrib(mcExpr, typeof(ResourceAttribute)); + if (resourceAttrib == null) + { + this.m_isExpensive = true; + } + else + { + this.m_isExpensive = this.m_isExpensive || ((ResourceAttribute)resourceAttrib).IsExpensive; + } + return mcExpr; + } + + internal override Expression VisitBinary(BinaryExpression b) + { + if (b.Method == null) + { + return base.VisitBinary(b); + } + else + { + Attribute resourceAttrib = AttributeSystem.GetAttrib(b, typeof(ResourceAttribute)); + if (resourceAttrib == null) + { + this.m_isExpensive = true; + } + else + { + this.m_isExpensive = this.m_isExpensive || ((ResourceAttribute)resourceAttrib).IsExpensive; + } + return b; + } + } + + internal override Expression VisitUnary(UnaryExpression u) + { + if (u.Method == null) + { + return base.VisitUnary(u); + } + else + { + this.m_isExpensive = true; + return u; + } + } + + public bool IsExpensive + { + get { return this.m_isExpensive; } + } + } +} diff --git a/LinqToDryad/ForkTuple.cs b/LinqToDryad/ForkTuple.cs new file mode 100644 index 0000000..ad0e65f --- /dev/null +++ b/LinqToDryad/ForkTuple.cs @@ -0,0 +1,182 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. 
+//
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.IO;
+using System.Globalization;
+using System.Reflection;
+using System.Linq;
+using Microsoft.Research.DryadLinq.Internal;
+
+namespace Microsoft.Research.DryadLinq
+{
+    /// <summary>
+    /// Untyped view of a <see cref="ForkValue{T}"/>: exposes whether a value
+    /// is present and the value itself as an <see cref="object"/>.
+    /// </summary>
+    interface IForkValue
+    {
+        bool HasValue { get; }
+        object Value { get; }
+    }
+
+    /// <summary>
+    /// A single optional value of type <typeparamref name="T"/>.
+    /// <see cref="HasValue"/> is false for a default-initialized instance and
+    /// becomes true once a value is supplied through the constructor or the
+    /// <see cref="Value"/> setter.
+    /// </summary>
+    /// <typeparam name="T">Type of the stored value.</typeparam>
+    [Serializable]
+    public struct ForkValue<T> : IForkValue
+    {
+        private T m_x;
+        private bool m_hasX;
+
+        /// <summary>Creates an instance that holds <paramref name="x"/>.</summary>
+        public ForkValue(T x)
+        {
+            this.m_x = x;
+            this.m_hasX = true;
+        }
+
+        /// <summary>True once a value has been stored.</summary>
+        public bool HasValue
+        {
+            get { return this.m_hasX; }
+        }
+
+        /// <summary>The stored value, boxed when T is a value type.</summary>
+        object IForkValue.Value
+        {
+            get { return this.m_x; }
+        }
+
+        /// <summary>
+        /// The stored value. The getter does not check <see cref="HasValue"/>;
+        /// the setter stores the value and marks it as present.
+        /// </summary>
+        public T Value
+        {
+            get { return this.m_x; }
+            set { this.m_x = value; this.m_hasX = true; }
+        }
+    }
+
+    /// <summary>
+    /// A pair of independently optional values. Each component carries its own
+    /// presence flag, set by the constructor or by the component's setter.
+    /// </summary>
+    /// <typeparam name="T1">Type of the first component.</typeparam>
+    /// <typeparam name="T2">Type of the second component.</typeparam>
+    [Serializable]
+    public struct ForkTuple<T1, T2>
+    {
+        private T1 m_x;
+        private T2 m_y;
+        private bool m_hasX;
+        private bool m_hasY;
+
+        /// <summary>Creates a tuple in which both components are present.</summary>
+        public ForkTuple(T1 x, T2 y)
+        {
+            this.m_x = x;
+            this.m_y = y;
+            this.m_hasX = true;
+            this.m_hasY = true;
+        }
+
+        /// <summary>True once the first component has been stored.</summary>
+        public bool HasFirst
+        {
+            get { return this.m_hasX; }
+        }
+
+        /// <summary>True once the second component has been stored.</summary>
+        public bool HasSecond
+        {
+            get { return this.m_hasY; }
+        }
+
+        /// <summary>First component; assigning it marks it as present.</summary>
+        public T1 First
+        {
+            get { return this.m_x; }
+            set { this.m_x = value; this.m_hasX = true; }
+        }
+
+        /// <summary>Second component; assigning it marks it as present.</summary>
+        public T2 Second
+        {
+            get { return this.m_y; }
+            set { this.m_y = value; this.m_hasY = true; }
+        }
+    }
+
+    /// <summary>
+    /// A triple of independently optional values. Each component carries its
+    /// own presence flag, set by the constructor or by the component's setter.
+    /// </summary>
+    /// <typeparam name="T1">Type of the first component.</typeparam>
+    /// <typeparam name="T2">Type of the second component.</typeparam>
+    /// <typeparam name="T3">Type of the third component.</typeparam>
+    [Serializable]
+    public struct ForkTuple<T1, T2, T3>
+    {
+        private T1 m_x;
+        private T2 m_y;
+        private T3 m_z;
+        private bool m_hasX;
+        private bool m_hasY;
+        private bool m_hasZ;
+
+        /// <summary>Creates a tuple in which all three components are present.</summary>
+        public ForkTuple(T1 x, T2 y, T3 z)
+        {
+            this.m_x = x;
+            this.m_y = y;
+            this.m_z = z;
+            this.m_hasX = true;
+            this.m_hasY = true;
+            this.m_hasZ = true;
+        }
+
+        /// <summary>True once the first component has been stored.</summary>
+        public bool HasFirst
+        {
+            get { return this.m_hasX; }
+        }
+
+        /// <summary>True once the second component has been stored.</summary>
+        public bool HasSecond
+        {
+            get { return this.m_hasY; }
+        }
+
+        /// <summary>True once the third component has been stored.</summary>
+        public bool HasThird
+        {
+            get { return this.m_hasZ; }
+        }
+
+        /// <summary>First component; assigning it marks it as present.</summary>
+        public T1 First
+        {
+            get { return this.m_x; }
+            set { this.m_x = value; this.m_hasX = true; }
+        }
+
+        /// <summary>Second component; assigning it marks it as present.</summary>
+        public T2 Second
+        {
+            get { return this.m_y; }
+            set { this.m_y = value; this.m_hasY = true; }
+        }
+
+        /// <summary>Third component; assigning it marks it as present.</summary>
+        public T3 Third
+        {
+            get { return this.m_z; }
+            set { this.m_z = value; this.m_hasZ = true; }
+        }
+    }
+
+    /// <summary>
+    /// Arbitrary-arity, untyped tuple backed by an array of
+    /// <see cref="IForkValue"/>. Components are addressed by position.
+    /// </summary>
+    internal struct ForkTuple
+    {
+        // The caller's array is stored as-is (no copy). It is null for a
+        // default-initialized instance, in which case both accessors throw.
+        private IForkValue[] m_values;
+
+        /// <summary>Wraps the given component values.</summary>
+        public ForkTuple(params IForkValue[] values)
+        {
+            this.m_values = values;
+        }
+
+        /// <summary>Whether the component at <paramref name="index"/> holds a value.</summary>
+        public bool HasValue(int index)
+        {
+            return this.m_values[index].HasValue;
+        }
+
+        /// <summary>The component at <paramref name="index"/>, as an object.</summary>
+        public object Value(int index)
+        {
+            return this.m_values[index].Value;
+        }
+    }
+}
diff --git a/LinqToDryad/Hash64.cs b/LinqToDryad/Hash64.cs
new file mode 100644
index 0000000..a0c6e8e
--- /dev/null
+++ b/LinqToDryad/Hash64.cs
@@ -0,0 +1,355 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//
+// © Microsoft Corporation. All rights reserved.
+//
+using System;
+using System.IO;
+
+namespace Microsoft.Research.DryadLinq.Internal
+{
+    ///
+    /// A class to compute 64 bit Rabin fingerprints.
+    ///
+    internal class Hash64
+    {
+        private const int LOGZEROBLOCK = 8;
+        private const int ZEROBLOCK = 1 << LOGZEROBLOCK;
+
+        // Polynomial used by the shared Hasher instance below.
+        internal const UInt64 Empty = 0x911498ae0e66bad6UL;
+        // Shared instance built over the Empty polynomial with span 8.
+        internal static readonly Hash64 Hasher = new Hash64(Empty, 8);
+
+        // poly[0] = 0; poly[1] = polynomial
+        private UInt64[] poly = new UInt64[2];
+
+        // bybyte[b,i] is i*X^(64+8*b) mod poly[1]
+        private UInt64[,] bybyte = new UInt64[8,256];
+
+        // powers[i] is X^(8*2^i) mod poly[1]
+        private UInt64[] powers = new UInt64[64];
+
+        // Block of zero bytes fed to Extend() by Concat() for the low bits of a length.
+        private byte[] zeroes = new byte[ZEROBLOCK];
+
+        // bybyte_out[b,i] is i*X^(64+8*(b+span)) mod poly[1]
+        private UInt64[,] bybyte_out = new UInt64[8,256];
+
+        private int span;
+
+        /// <summary>
+        /// Computes the tables needed for fingerprint manipulations.
+        /// Requires that "poly" be the binary representation
+        /// of an irreducible polynomial in GF(2) of degree 64. The X^64 term
+        /// is not represented. The X^0 term is the high order bit, and the
+        /// X^63 term is the low-order bit.
+        /// </summary>
+        /// <param name="poly">The polynomial, encoded as described above.</param>
+        /// <param name="span">
+        /// Used in later calls to SlideWord(). If SlideWord() is not to be
+        /// called, span should be set to zero.
+        /// </param>
+        internal Hash64(UInt64 poly, int span)
+        {
+            this.poly[0] = 0;
+            this.poly[1] = poly;  // This must be initialized early on
+            this.span = span;
+            // bybyte[][] must be initialized before powers[]
+            this.InitByByte(this.bybyte, poly);
+            // zeroes must be initialized before powers[]
+            for (int i = 0; i < this.zeroes.Length; i++) this.zeroes[i] = 0;
+            // The initialization of powers[] must happen after bybyte[][]
+            // and zeroes are initialized because concat uses all of
+            // bybyte[][], zeroes and the prefix of powers[] internally.
+            this.powers[0] = 1ul << 55;
+            // l is the number of zero bytes to append, 2^(i-1). It reaches
+            // 2^62 on the last iteration, so it must be 64 bits wide: a 32-bit
+            // counter wraps to 0 at i = 33 and every later powers[i] would
+            // silently repeat powers[32].
+            UInt64 l = 1;
+            for (int i = 1; i < this.powers.Length; i++, l <<= 1)
+            {
+                this.powers[i] = this.Concat(this.powers[i-1] ^ poly, 0, l);
+            }
+            if (span != 0)
+            {
+                this.InitByByte(this.bybyte_out, this.Concat(0, 0, (uint)(span-1) * 8));
+            }
+        }
+
+        /// <summary>
+        /// Fills an 8x256 lookup table starting from the value "f".
+        /// For each row, the entries at power-of-two column indices are
+        /// produced by repeatedly stepping f through the polynomial; the
+        /// remaining entries are the XOR of the entries of their set bits.
+        /// </summary>
+        private void InitByByte(UInt64[,] bybyte, UInt64 f)
+        {
+            for (int b = 0; b != 8; b++)
+            {
+                bybyte[b,0] = 0;
+                for (int i = 0x80; i != 0; i >>= 1)
+                {
+                    bybyte[b,i] = f;
+                    f = this.poly[f & 1] ^ (f >> 1);
+                }
+                for (int i = 1; i != 256; i <<= 1)
+                {
+                    UInt64 xf = bybyte[b,i];
+                    for (int k = 1; k != i; k++)
+                    {
+                        bybyte[b,i+k] = xf ^ bybyte[b,k];
+                    }
+                }
+            }
+        }
+
+        /// <summary>
+        /// If this instance was generated with polynomial P, "fpa" is the
+        /// fingerprint under P of string A, and 64-bit words
+        /// "data[start, ..., start+len-1]" contain string B, return the
+        /// fingerprint under P of the concatenation of A and B.
+        /// Arrays of words are treated as polynomials. The low-order bit in
+        /// the first word is the highest degree coefficient in the polynomial.
+        /// This routine differs from Extend() on bigendian machines, where the
+        /// byte order within each word is backwards.
+        /// </summary>
+        internal UInt64 ExtendWord(UInt64 fpa, UInt64[] data, int start, int len)
+        {
+            for (int i = start; i != start+len; i++)
+            {
+                fpa ^= data[i];
+                fpa = this.bybyte[7, fpa & 0xff] ^
+                      this.bybyte[6, (fpa >> 8) & 0xff] ^
+                      this.bybyte[5, (fpa >> 16) & 0xff] ^
+                      this.bybyte[4, (fpa >> 24) & 0xff] ^
+                      this.bybyte[3, (fpa >> 32) & 0xff] ^
+                      this.bybyte[2, (fpa >> 40) & 0xff] ^
+                      this.bybyte[1, (fpa >> 48) & 0xff] ^
+                      this.bybyte[0, fpa >> 56];
+            }
+            return fpa;
+        }
+
+        /// <summary>
+        /// If this instance was generated with polynomial P, "a" is the
+        /// fingerprint under P of string A, and "b" is the fingerprint under
+        /// P of string B, which has length "blen" bytes, return the
+        /// fingerprint under P of the concatenation of A and B.
+        /// </summary>
+        internal UInt64 Concat(UInt64 a, UInt64 b, UInt64 blen)
+        {
+            UInt64 x = blen;
+            // The low LOGZEROBLOCK bits of the length are handled by feeding
+            // that many zero bytes through Extend().
+            int low = (int)(x & ((1 << LOGZEROBLOCK)-1));
+            a ^= this.poly[1];
+            if (low != 0)
+            {
+                a = this.Extend(a, this.zeroes, 0, low);
+            }
+            x >>= LOGZEROBLOCK;
+            // Each remaining set bit i of the length multiplies "a" by
+            // powers[i], one polynomial bit at a time.
+            for (int i = LOGZEROBLOCK; x != 0; i++)
+            {
+                if ((x & 1) != 0)
+                {
+                    UInt64 m = 0;
+                    UInt64 e = this.powers[i];
+                    for (UInt64 bit = 1ul << 63; bit != 0; bit >>= 1)
+                    {
+                        if ((e & bit) != 0)
+                        {
+                            m ^= a;
+                        }
+                        a = (a >> 1) ^ this.poly[a & 1];
+                    }
+                    a = m;
+                }
+                x >>= 1;
+            }
+            return a ^ b;
+        }
+
+        /// <summary>
+        /// If this instance was generated with polynomial P, X is some string
+        /// of length "(span-1)*8" bytes (see the constructor), and "fp" is the
+        /// fingerprint under P of word "a" concatenated with X, return the
+        /// fingerprint under P of X concatenated with word "b". The words "a"
+        /// and "b" represent polynomials whose X^0 term is in the high-order bit,
+        /// and whose X^63 term is in the low order bit.
+        /// </summary>
+        internal UInt64 SlideWord(UInt64 fp, UInt64 a, UInt64 b)
+        {
+            a ^= this.poly[1] ^ (1ul << 63);
+            fp ^= this.bybyte_out[7,a & 0xff] ^
+                  this.bybyte_out[6,(a >> 8) & 0xff] ^
+                  this.bybyte_out[5,(a >> 16) & 0xff] ^
+                  this.bybyte_out[4,(a >> 24) & 0xff] ^
+                  this.bybyte_out[3,(a >> 32) & 0xff] ^
+                  this.bybyte_out[2,(a >> 40) & 0xff] ^
+                  this.bybyte_out[1,(a >> 48) & 0xff] ^
+                  this.bybyte_out[0,a >> 56];
+            fp ^= b;
+            fp = this.bybyte[7,fp & 0xff] ^
+                 this.bybyte[6,(fp >> 8) & 0xff] ^
+                 this.bybyte[5,(fp >> 16) & 0xff] ^
+                 this.bybyte[4,(fp >> 24) & 0xff] ^
+                 this.bybyte[3,(fp >> 32) & 0xff] ^
+                 this.bybyte[2,(fp >> 40) & 0xff] ^
+                 this.bybyte[1,(fp >> 48) & 0xff] ^
+                 this.bybyte[0,fp >> 56];
+            return fp;
+        }
+
+        /// <summary>
+        /// If this instance was generated with polynomial P, "fpa" is the
+        /// fingerprint under P of string A, and bytes
+        /// "data[start, ..., start+len-1]" contain string B, return the
+        /// fingerprint under P of the concatenation of A and B. Strings are
+        /// treated as polynomials. The low-order bit in the first byte is the
+        /// highest degree coefficient in the polynomial.
+        /// This routine differs from ExtendWord() in that it will read bytes
+        /// in increasing address order, regardless of the endianness of the
+        /// machine.
+        /// </summary>
+        internal UInt64 Extend(UInt64 fpa, byte[] data, int start, int len)
+        {
+            for (int i = 0; i < len; i++)
+            {
+                fpa = (fpa >> 8) ^ this.bybyte[0,(fpa & 0xff) ^ data[start++]];
+            }
+            return fpa;
+        }
+
+        /// <summary>
+        /// Same as the byte[] overload, reading "len" bytes from the
+        /// unmanaged buffer "data" beginning at offset "start".
+        /// </summary>
+        internal unsafe UInt64 Extend(UInt64 fpa, byte* data, int start, int len)
+        {
+            for (int i = 0; i < len; i++)
+            {
+                fpa = (fpa >> 8) ^ this.bybyte[0,(fpa & 0xff) ^ data[start++]];
+            }
+            return fpa;
+        }
+
+        /// <summary>Extends the fingerprint "fp" by the single byte "b".</summary>
+        internal UInt64 Extend(UInt64 fp, byte b)
+        {
+            return (fp >> 8) ^ this.bybyte[0,(fp & 0xff) ^ b];
+        }
+
+        /// <summary>Extends "fp" by "b" reinterpreted as an unsigned byte.</summary>
+        internal UInt64 Extend(UInt64 fp, sbyte b)
+        {
+            return this.Extend(fp, (byte)b);
+        }
+
+        /// <summary>Extends "fp" by one byte: 1 for true, 0 for false.</summary>
+        internal UInt64 Extend(UInt64 fp, bool b)
+        {
+            byte b1 = (byte)((b) ? 1 : 0);
+            return (fp >> 8) ^ this.bybyte[0,(fp & 0xff) ^ b1];
+        }
+
+        /// <summary>Extends "fp" by the 16-bit code unit "c".</summary>
+        internal UInt64 Extend(UInt64 fp, char c)
+        {
+            return this.Extend(fp, (ushort)c);
+        }
+
+        /// <summary>Extends "fp" by "v" reinterpreted as an unsigned 16-bit value.</summary>
+        internal UInt64 Extend(UInt64 fp, short v)
+        {
+            return this.Extend(fp, (ushort)v);
+        }
+
+        /// <summary>Extends "fp" by the two bytes of "v" in a single table step.</summary>
+        internal UInt64 Extend(UInt64 fp, ushort v)
+        {
+            fp ^= v;
+            return ((fp >> 16) ^
+                    this.bybyte[1, fp & 0xff] ^
+                    this.bybyte[0, (fp >> 8) & 0xff]);
+        }
+
+        /// <summary>Extends "fp" by "v" reinterpreted as an unsigned 32-bit value.</summary>
+        internal UInt64 Extend(UInt64 fp, int v)
+        {
+            return this.Extend(fp, (uint)v);
+        }
+
+        /// <summary>Extends "fp" by the four bytes of "v" in a single table step.</summary>
+        internal UInt64 Extend(UInt64 fp, uint v)
+        {
+            fp ^= v;
+            return ((fp >> 32) ^
+                    (this.bybyte[3, fp & 0xff] ^
+                     this.bybyte[2, (fp >> 8) & 0xff] ^
+                     this.bybyte[1, (fp >> 16) & 0xff] ^
+                     this.bybyte[0, (fp >> 24) & 0xff]));
+        }
+
+        /// <summary>Extends "fp" by "v" reinterpreted as an unsigned 64-bit value.</summary>
+        internal UInt64 Extend(UInt64 fp, long v)
+        {
+            return this.Extend(fp, (UInt64)v);
+        }
+
+        /// <summary>Extends "fp" by the eight bytes of "v" in a single table step.</summary>
+        internal UInt64 Extend(UInt64 fp, UInt64 v)
+        {
+            fp ^= v;
+            return (this.bybyte[7, fp & 0xff] ^
+                    this.bybyte[6, (fp >> 8) & 0xff] ^
+                    this.bybyte[5, (fp >> 16) & 0xff] ^
+                    this.bybyte[4, (fp >> 24) & 0xff] ^
+                    this.bybyte[3, (fp >> 32) & 0xff] ^
+                    this.bybyte[2, (fp >> 40) & 0xff] ^
+                    this.bybyte[1, (fp >> 48) & 0xff] ^
+                    this.bybyte[0, (fp >> 56) & 0xff]);
+        }
+
+        /// <summary>Extends "fp" by the raw 32-bit pattern stored in "v".</summary>
+        internal unsafe UInt64 Extend(UInt64 fp, float v)
+        {
+            uint v1 = *(uint*)&v;
+            return this.Extend(fp, v1);
+        }
+
+        /// <summary>
+        /// Extends "fp" by the two consecutive 64-bit words read from the
+        /// in-memory representation of "v".
+        /// </summary>
+        internal unsafe UInt64 Extend(UInt64 fp, decimal v)
+        {
+            UInt64* vals = (UInt64*)&v;
+            fp = this.Extend(fp, *vals);
+            return this.Extend(fp, *(vals + 1));
+        }
+
+        /// <summary>Extends "fp" by the raw 64-bit pattern stored in "v".</summary>
+        internal unsafe UInt64 Extend(UInt64 fp, double v)
+        {
+            UInt64 v1 = *(UInt64*)&v;
+            return this.Extend(fp, v1);
+        }
+
+        /// <summary>
+        /// Extends "fp" by one byte per character of "s".
+        /// </summary>
+        /// <remarks>
+        /// Only the low 8 bits of each character are hashed, so two strings
+        /// that differ only in the high byte of some character produce the
+        /// same fingerprint. NOTE(review): kept as-is because changing it
+        /// would change every existing string fingerprint.
+        /// </remarks>
+        internal UInt64 Extend(UInt64 fp, string s)
+        {
+            byte[] bytes = new byte[s.Length];
+            for (int i = 0; i < s.Length; i++)
+            {
+                bytes[i] = (byte)(s[i] & 0xff);
+            }
+            return this.Extend(fp, bytes, 0, bytes.Length);
+        }
+
+        /// <summary>
+        /// Extends "fp" by the full contents of the file "filename".
+        /// Two buffers are used: while one block is being fingerprinted the
+        /// read of the next block has already been started on the other.
+        /// </summary>
+        internal UInt64 ExtendFile(UInt64 fp, string filename)
+        {
+            int size = 65536 * 4;
+            byte[] readBuf = new byte[size];
+            byte[] fpBuf = new byte[size];
+
+            ulong fileFP = fp;
+            using (Stream ifs = new FileStream(filename, FileMode.Open, FileAccess.Read, FileShare.Read, size, FileOptions.Asynchronous | FileOptions.SequentialScan))
+            {
+                IAsyncResult readResult = ifs.BeginRead(readBuf, 0, readBuf.Length, null, null);
+                while (true)
+                {
+                    int bytesRead = ifs.EndRead(readResult);
+                    if (bytesRead == 0) break;
+
+                    // Swap buffers so the block just read can be hashed
+                    // while the next read fills the other buffer.
+                    byte[] tmpBuf = fpBuf;
+                    fpBuf = readBuf;
+                    readBuf = tmpBuf;
+                    readResult = ifs.BeginRead(readBuf, 0, readBuf.Length, null, null);
+                    fileFP = this.Extend(fileFP, fpBuf, 0, bytesRead);
+                }
+            }
+            return fileFP;
+        }
+    }
+}
diff --git a/LinqToDryad/HpcJobSubmission.cs b/LinqToDryad/HpcJobSubmission.cs
new file mode 100644
index 0000000..886f627
--- /dev/null
+++ b/LinqToDryad/HpcJobSubmission.cs
@@ -0,0 +1,336 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//
+// © Microsoft Corporation. All rights reserved.
+//
+#if REMOVE_FOR_YARN
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.IO;
+using System.IO.Compression;
+using System.Linq;
+using System.Net;
+using System.Security.Principal;
+using System.Text;
+using System.Xml;
+using Microsoft.Hpc.Scheduler;
+using Microsoft.Hpc.Scheduler.Properties;
+using Microsoft.Hpc.Dryad;
+using Microsoft.Research.DryadLinq.Internal;
+using System.Collections.Specialized;
+
+namespace Microsoft.Research.DryadLinq
+{
+    /// <summary>
+    /// IHpcLinqJobSubmission implementation that forwards to a
+    /// DryadJobSubmission created from the context's scheduler.
+    /// </summary>
+    internal class HpcJobSubmission : IHpcLinqJobSubmission
+    {
+        private HpcLinqContext m_context;
+        private DryadJobSubmission m_job;
+        // Last status computed by GetStatus(); Success and Failure are sticky.
+        private JobStatus m_status;
+
+        /// <summary>
+        /// Copies the job settings from the context's configuration onto the
+        /// underlying job.
+        /// </summary>
+        /// <exception cref="HpcLinqException">
+        /// The configuration specifies JobMinNodes and it is less than 2.
+        /// </exception>
+        internal void Initialize()
+        {
+            this.m_job.FriendlyName = m_context.Configuration.JobFriendlyName;
+
+            // if the user specified MinNodes and it is less than 2, return an error. Otherwise let job run with job template which
+            // must specify a value of 2 or higher
+            if (m_context.Configuration.JobMinNodes.HasValue && m_context.Configuration.JobMinNodes < 2)
+            {
+                throw new HpcLinqException(HpcLinqErrorCode.HpcLinqJobMinMustBe2OrMore,
+                                           SR.HpcLinqJobMinMustBe2OrMore);
+            }
+
+            this.m_job.DryadJobMinNodes = m_context.Configuration.JobMinNodes;
+            this.m_job.DryadJobMaxNodes = m_context.Configuration.JobMaxNodes;
+            this.m_job.DryadNodeGroup = m_context.Configuration.NodeGroup;
+            this.m_job.DryadUserName = m_context.Configuration.JobUsername;
+            this.m_job.DryadPassword = m_context.Configuration.JobPassword;
+            this.m_job.DryadRuntime = m_context.Configuration.JobRuntimeLimit;
+            this.m_job.EnableSpeculativeDuplication = m_context.Configuration.EnableSpeculativeDuplication;
+            this.m_job.RuntimeTraceLevel = (int)m_context.Configuration.RuntimeTraceLevel;
+            this.m_job.GraphManagerNode = m_context.Configuration.GraphManagerNode;
+
+            // The job expects its environment as a NameValueCollection, so
+            // the configured key/value pairs are copied across one by one.
+            System.Collections.Specialized.NameValueCollection collection = new System.Collections.Specialized.NameValueCollection();
+
+            foreach (var keyValuePair in m_context.Configuration.JobEnvironmentVariables)
+            {
+                collection.Add(keyValuePair.Key, keyValuePair.Value);
+            }
+
+            this.m_job.JobEnvironmentVariables = collection;
+        }
+
+        /// <summary>
+        /// True when the underlying job type is Local; setting it switches
+        /// the job type between Local and Cluster.
+        /// </summary>
+        internal bool LocalJM
+        {
+            get
+            {
+                return m_job.Type == DryadJobSubmission.JobType.Local;
+            }
+            set
+            {
+                if (value == true)
+                {
+                    m_job.Type = DryadJobSubmission.JobType.Local;
+                }
+                else
+                {
+                    m_job.Type = DryadJobSubmission.JobType.Cluster;
+                }
+            }
+        }
+
+        /// <summary>Command line of the underlying job.</summary>
+        internal string CommandLine
+        {
+            get
+            {
+                return m_job.CommandLine;
+            }
+            set
+            {
+                m_job.CommandLine = value;
+            }
+        }
+
+        /// <summary>Error message stored on the underlying job.</summary>
+        public string ErrorMsg
+        {
+            get
+            {
+                return m_job.ErrorMessage;
+            }
+            private set
+            {
+                m_job.ErrorMessage = value;
+            }
+        }
+
+        /// <summary>
+        /// Creates a submission bound to <paramref name="context"/>; the
+        /// status starts as NotSubmitted.
+        /// </summary>
+        internal HpcJobSubmission(HpcLinqContext context)
+        {
+            this.m_context = context;
+            this.m_status = JobStatus.NotSubmitted;
+
+            //@@TODO[P0] pass the runtime to the DryadJobSubmission so that it can use the scheduler instance.
+            //@@TODO: Merge DryadJobSubmission into Ms.Hpc.Linq. Until then make sure Context is not disposed before DryadJobSubmission.
+            this.m_job = new DryadJobSubmission(m_context.GetIScheduler());
+        }
+
+        /// <summary>
+        /// Sets a job option. Only "cmdline" is recognised.
+        /// </summary>
+        /// <exception cref="HpcLinqException">Any other option name.</exception>
+        public void AddJobOption(string fieldName, string fieldVal)
+        {
+            if (fieldName == "cmdline")
+            {
+                m_job.CommandLine = fieldVal;
+            }
+            else
+            {
+                throw new HpcLinqException(HpcLinqErrorCode.JobOptionNotImplemented,
+                                           String.Format(SR.JobOptionNotImplemented, fieldName, fieldVal));
+            }
+        }
+
+        /// <summary>Adds a local file to the underlying job.</summary>
+        public void AddLocalFile(string fileName)
+        {
+            m_job.AddFileToJob(fileName);
+        }
+
+        /// <summary>
+        /// Currently a no-op: the request is not forwarded to the job.
+        /// </summary>
+        public void AddRemoteFile(string fileName)
+        {
+            // NOTE(review): the message is formatted and then dropped, so
+            // callers get no signal that the file was ignored. It looks like
+            // a throw or a log call was intended here - confirm before
+            // changing, since callers may rely on this not throwing.
+            string msg = String.Format("HpcJobSubmission.AddRemoteFile({0}) not implemented", fileName);
+        }
+
+        /// <summary>
+        /// Maps the scheduler's job state to a JobStatus and caches it.
+        /// Once Success or Failure has been observed it is returned without
+        /// querying the job again.
+        /// </summary>
+        public JobStatus GetStatus()
+        {
+            if (this.m_status == JobStatus.Success ||
+                this.m_status == JobStatus.Failure )
+            {
+                return this.m_status;
+            }
+
+            if (this.m_job == null)
+            {
+                return JobStatus.NotSubmitted;
+            }
+
+            switch (this.m_job.State)
+            {
+                case JobState.ExternalValidation:
+                case JobState.Queued:
+                case JobState.Submitted:
+                case JobState.Validating:
+                {
+                    this.m_status = JobStatus.Waiting;
+                    break;
+                }
+                case JobState.Configuring:
+                case JobState.Running:
+                case JobState.Canceling:
+                case JobState.Finishing:
+                {
+                    this.m_status = JobStatus.Running;
+                    break;
+                }
+                case JobState.Failed:
+                    // a job only fails if the job manager fails.
+                {
+                    ISchedulerCollection tasks = this.m_job.Job.GetTaskList(null, null, false);
+                    if (tasks.Count < 1)
+                    {
+                        this.ErrorMsg = this.m_job.ErrorMessage;
+                        this.m_status = JobStatus.Failure;
+                    }
+                    else
+                    {
+                        // The first task is treated as the job manager.
+                        ISchedulerTask jm = tasks[0] as ISchedulerTask;
+                        if (jm == null)
+                        {
+                            // Not a task object: there is nothing to inspect,
+                            // so report failure with the job's own message
+                            // instead of dereferencing null.
+                            this.m_status = JobStatus.Failure;
+                        }
+                        else if (jm.State == TaskState.Finished)
+                        {
+                            // The job is marked Failed but its first task
+                            // finished: reported as Success.
+                            this.m_status = JobStatus.Success;
+                        }
+                        else
+                        {
+                            this.m_status = JobStatus.Failure;
+                            this.ErrorMsg = "JM error: " + jm.ErrorMessage;
+                        }
+                    }
+                    break;
+                }
+                case JobState.Canceled:
+                {
+                    this.ErrorMsg = this.m_job.ErrorMessage;
+                    this.m_status = JobStatus.Failure;
+                    break;
+                }
+                case JobState.Finished:
+                {
+                    this.m_status = JobStatus.Success;
+                    break;
+                }
+            }
+
+            return this.m_status;
+        }
+
+        /// <summary>
+        /// Submits the job.
+        /// </summary>
+        /// <exception cref="HpcLinqException">
+        /// No head node is configured, or the underlying submission threw
+        /// (the original exception is attached as the inner exception).
+        /// </exception>
+        public void SubmitJob()
+        {
+            // Verify that the head node is set
+            if (m_context.Configuration.HeadNode == null)
+            {
+                throw new HpcLinqException(HpcLinqErrorCode.ClusterNameMustBeSpecified,
+                                           SR.ClusterNameMustBeSpecified);
+            }
+
+            try
+            {
+                this.m_job.SubmitJob();
+            }
+            catch (Exception e)
+            {
+                throw new HpcLinqException(HpcLinqErrorCode.SubmissionFailure,
+                                           String.Format(SR.SubmissionFailure, m_context.Configuration.HeadNode), e);
+            }
+        }
+
+        /// <summary>
+        /// Cancels the job unless it is already in a terminal or
+        /// not-submitted state, in which case that state is returned.
+        /// The cached status is not updated by the cancellation; a later
+        /// GetStatus() call re-reads the scheduler state.
+        /// </summary>
+        public JobStatus TerminateJob()
+        {
+            JobStatus status = GetStatus();
+            switch (status)
+            {
+                case JobStatus.Failure:
+                case JobStatus.NotSubmitted:
+                case JobStatus.Success:
+                case JobStatus.Cancelled:
+                    // Nothing to do.
+                    return status;
+                default:
+                    break;
+            }
+
+            this.m_job.CancelJob();
+            return JobStatus.Cancelled;
+        }
+
+        /// <summary>
+        /// Returns the scheduler id of the job.
+        /// </summary>
+        /// <exception cref="InvalidOperationException">No job is available.</exception>
+        public int GetJobId()
+        {
+            if (m_job == null || m_job.Job == null)
+            {
+                throw new InvalidOperationException("(internal) GetJobId called when no job is available");
+            }
+            return m_job.Job.Id;
+        }
+    }
+}
+#else
+namespace Microsoft.Research.DryadLinq
+{
+    /// <summary>
+    /// Placeholder used when REMOVE_FOR_YARN is not defined: the constructor
+    /// stores the context and every other member throws
+    /// NotImplementedException.
+    /// </summary>
+    internal class HpcJobSubmission : IHpcLinqJobSubmission
+    {
+        private HpcLinqContext m_context;
+
+        public HpcJobSubmission(HpcLinqContext context)
+        {
+            m_context = context;
+        }
+
+        public void AddJobOption(string fieldName, string fieldVal)
+        {
+            throw new System.NotImplementedException();
+        }
+
+        public void AddLocalFile(string fileName)
+        {
+            throw new System.NotImplementedException();
+        }
+
+        public void AddRemoteFile(string fileName)
+        {
+            throw new System.NotImplementedException();
+        }
+
+        public string ErrorMsg
+        {
+            get { throw new System.NotImplementedException(); }
+        }
+
+        public JobStatus GetStatus()
+        {
+            throw new System.NotImplementedException();
+        }
+
+        public void SubmitJob()
+        {
+            throw new System.NotImplementedException();
+        }
+
+        public JobStatus TerminateJob()
+        {
+            throw new System.NotImplementedException();
+        }
+
+        public int GetJobId()
+        {
+            throw new System.NotImplementedException();
+        }
+
+        internal void Initialize()
+        {
+            throw new System.NotImplementedException();
+        }
+    }
+}
+#endif
diff --git a/LinqToDryad/HpcLinqCache.cs b/LinqToDryad/HpcLinqCache.cs
new file mode 100644
index 0000000..3824f25
--- /dev/null
+++ b/LinqToDryad/HpcLinqCache.cs
@@ -0,0 +1,205 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading;
+using System.IO;
+
+namespace Microsoft.Research.DryadLinq.Internal
+{
+    /// <summary>
+    /// A process-wide, string-keyed cache of values with manual reference counting.
+    /// </summary>
+    /// <remarks>
+    /// All state is static and guarded by a lock on the backing dictionary. Once
+    /// <see cref="Initalize"/> has been called, a timer periodically evicts entries whose
+    /// reference count is zero, but only while available physical memory is low.
+    /// </remarks>
+    public class HpcLinqCache
+    {
+        // Delay before the first eviction pass and interval between passes, in milliseconds.
+        private const int EvictionPeriod = 10000;
+
+        // Eviction is attempted only while available physical memory is below this many bytes (4 GB).
+        private const ulong LowMemoryThreshold = 4UL * 1024 * 1024 * 1024;
+
+        private static Dictionary<string, CacheRecord> s_cache = new Dictionary<string, CacheRecord>();
+        private static bool s_isInitialized = false;
+        private static Timer s_evictionTimer;
+
+        /// <summary>
+        /// Starts the periodic eviction timer. Calls after the first one are no-ops.
+        /// </summary>
+        /// <remarks>
+        /// The misspelled name is kept because it is part of the public surface.
+        /// </remarks>
+        public void Initalize()
+        {
+            lock (s_cache)
+            {
+                if (!s_isInitialized)
+                {
+                    s_evictionTimer = new Timer(new TimerCallback(DoEviction),
+                                                null,
+                                                EvictionPeriod, EvictionPeriod);
+                    s_isInitialized = true;
+                }
+            }
+        }
+
+        /// <summary>
+        /// Adds a value under <paramref name="key"/> unless the key is already cached.
+        /// </summary>
+        /// <param name="key">The cache key.</param>
+        /// <param name="val">The value to cache; written back as a sequence of <paramref name="elemType"/>.</param>
+        /// <param name="elemType">The element type used to close the generic cache record.</param>
+        /// <param name="factory">The factory used to create record writers for the element type.</param>
+        public static void Add(string key, object val, Type elemType, object factory)
+        {
+            lock (s_cache)
+            {
+                if (!s_cache.ContainsKey(key))
+                {
+                    CacheRecord rec = CacheRecord.Create(val, elemType, factory);
+                    s_cache.Add(key, rec);
+                }
+            }
+        }
+
+        /// <summary>
+        /// Returns true if an entry with the given key is currently cached.
+        /// </summary>
+        public static bool Contains(string key)
+        {
+            lock (s_cache)
+            {
+                return s_cache.ContainsKey(key);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a cached value. On a hit the entry's reference count is incremented and its
+        /// last-access time is refreshed; callers release it with <see cref="DecRefCount"/>.
+        /// </summary>
+        /// <param name="key">The cache key.</param>
+        /// <param name="val">Receives the cached value, or null when the key is absent.</param>
+        /// <returns>true if the key was found.</returns>
+        public static bool TryGet(string key, out object val)
+        {
+            lock (s_cache)
+            {
+                val = null;
+                CacheRecord rec;
+                bool found = s_cache.TryGetValue(key, out rec);
+                if (found)
+                {
+                    val = rec.Value;
+                    rec.LastAccessed = DateTime.Now;
+                    rec.RefCount++;
+                }
+                return found;
+            }
+        }
+
+        /// <summary>
+        /// Releases one reference on the entry. A missing key or a count that is already zero is
+        /// logged and otherwise ignored.
+        /// </summary>
+        public static void DecRefCount(string key)
+        {
+            lock (s_cache)
+            {
+                CacheRecord rec;
+                bool found = s_cache.TryGetValue(key, out rec);
+                if (!found)
+                {
+                    DryadLinqLog.Add("Can't find the cache entry with key {0}.", key);
+                }
+                else if (rec.RefCount > 0)
+                {
+                    rec.RefCount--;
+                }
+                else
+                {
+                    DryadLinqLog.Add("The reference count of the cache entry {0} is already 0.",
+                                     key);
+                }
+            }
+        }
+
+        /// <summary>
+        /// Timer callback: performs one eviction pass and returns.
+        /// </summary>
+        /// <remarks>
+        /// The timer is periodic, so this method must not loop; looping here would pin a
+        /// thread-pool thread at full CPU for every tick.
+        /// </remarks>
+        private unsafe static void DoEviction(object stateInfo)
+        {
+            try
+            {
+                MEMORYSTATUSEX memStatus = new MEMORYSTATUSEX();
+                memStatus.dwLength = (UInt32)sizeof(MEMORYSTATUSEX);
+
+                // Perform eviction only when feeling memory pressure
+                if (!HpcLinqNative.GlobalMemoryStatusEx(ref memStatus) ||
+                    memStatus.ullAvailPhys >= LowMemoryThreshold)
+                {
+                    return;
+                }
+
+                lock (s_cache)
+                {
+                    // Snapshot the keys first: removing from a Dictionary while enumerating it
+                    // invalidates the enumerator.
+                    List<string> unreferenced = s_cache.Where(x => x.Value.RefCount == 0)
+                                                       .Select(x => x.Key)
+                                                       .ToList();
+                    foreach (string key in unreferenced)
+                    {
+                        s_cache.Remove(key);
+                    }
+                }
+            }
+            catch (Exception e)
+            {
+                DryadLinqLog.Add("Exception occurred when performing cache eviction: {0}.",
+                                 e.Message);
+            }
+        }
+
+        /// <summary>
+        /// Untyped view of a cache entry: the value plus its bookkeeping.
+        /// </summary>
+        private abstract class CacheRecord
+        {
+            private object m_value;
+            private DateTime m_lastAccessed;
+            private int m_refCount;
+
+            public CacheRecord(object value)
+            {
+                this.m_value = value;
+                this.m_lastAccessed = DateTime.Now;
+                this.m_refCount = 0;
+            }
+
+            /// <summary>
+            /// Creates the <c>CacheRecord&lt;T&gt;</c> whose T is <paramref name="elemType"/>.
+            /// </summary>
+            public static CacheRecord
+                Create(object value, Type elemType, object factory)
+            {
+                // The record type, not the element type, is what has the (value, factory) constructor.
+                Type type = typeof(CacheRecord<>).MakeGenericType(elemType);
+                return (CacheRecord)Activator.CreateInstance(
+                                        type, new object[] { value, factory });
+            }
+
+            public object Value
+            {
+                get { return this.m_value; }
+            }
+
+            // Refreshed on every TryGet hit; nothing in this class reads it for eviction.
+            public DateTime LastAccessed
+            {
+                get { return this.m_lastAccessed; }
+                set { this.m_lastAccessed = value; }
+            }
+
+            // Number of outstanding TryGet hits not yet released through DecRefCount.
+            public int RefCount
+            {
+                get { return this.m_refCount; }
+                set { this.m_refCount = value; }
+            }
+
+            /// <summary>
+            /// Writes the cached value to <paramref name="stream"/>.
+            /// </summary>
+            public abstract void Write(NativeBlockStream stream);
+        }
+
+        /// <summary>
+        /// Cache entry whose value is a sequence of <typeparamref name="T"/>.
+        /// </summary>
+        private class CacheRecord<T> : CacheRecord
+        {
+            private HpcLinqFactory<T> m_factory;
+
+            public CacheRecord(object value, object factory)
+                : base(value)
+            {
+                this.m_factory = (HpcLinqFactory<T>)factory;
+            }
+
+            /// <summary>
+            /// Writes every element of the cached sequence to the stream, then closes the writer.
+            /// </summary>
+            public override void Write(NativeBlockStream stream)
+            {
+                HpcRecordWriter<T> writer = this.m_factory.MakeWriter(stream);
+                IEnumerable<T> elems = (IEnumerable<T>)this.Value;
+                foreach (var x in elems)
+                {
+                    writer.WriteRecordSync(x);
+                }
+                writer.Close();
+            }
+        }
+    }
+}
diff --git a/LinqToDryad/HpcLinqConfiguration.cs b/LinqToDryad/HpcLinqConfiguration.cs
new file mode 100644
index 0000000..d4b2097
--- /dev/null
+++ b/LinqToDryad/HpcLinqConfiguration.cs
@@ -0,0 +1,567 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Collections.Specialized;
+using System.Linq;
+using System.Text;
+using System.Runtime.CompilerServices;
+
+[assembly: InternalsVisibleTo("DistributedDandelion")]
+
+namespace Microsoft.Research.DryadLinq
+{
+    ///
+    /// Configuration information for a HPC Query.
+    ///
+    public sealed class HpcLinqConfiguration
+    {
+        internal bool _isReadOnly;
+
+        private HpcLinqStringList _resourcesToAdd = new HpcLinqStringList();
+        private HpcLinqStringList _resourcesToRemove = new HpcLinqStringList();
+
+        private DscCompressionScheme _intermediateDataCompressionScheme = DscCompressionScheme.Gzip;
+        private DscCompressionScheme _outputCompressionScheme = DscCompressionScheme.None;
+        private bool _compileForVertexDebugging = false; // Ship PDBs + No optimization
+        private string _headNode;
+        private string _hdfsNameNode;
+        private int _hdfsNameNodeHttpPort = 8033; //TODO - Read Config
+        private string _jobFriendlyName;
+        private int? _jobMinNodes;
+        private int? _jobMaxNodes;
+        private string _nodeGroup;
+        private int? _jobRuntimeLimit;
+        private bool _localDebug = false;
+        private bool _orderPreserving = true;
+        private string _jobUsername = null;
+        private string _jobPassword = null;
+        private HpcQueryTraceLevel _runtimeTraceLevel = HpcQueryTraceLevel.Error;
+        private string _graphManagerNode;
+        private bool _enableSpeculativeDuplication = false;
+        private HpcLinqStringDictionary _jobEnvironmentVariables = new HpcLinqStringDictionary();
+        private bool _selectAndWherePreserveOrder = false;
+        private bool _matchClientNetFrameworkVersion = false;
+        private bool _multiThreading = true;
+
+        //Set these values using YARN_HOME and DRYAD_HOME environment variables
+        private string _yarnHome = Environment.GetEnvironmentVariable("YARN_HOME");
+        private string _dryadHome = Environment.GetEnvironmentVariable("DRYAD_HOME");
+
+        // Guard called first by every setter that must be rejected on a read-only configuration.
+        private void ThrowIfReadOnly()
+        {
+            if (_isReadOnly)
+            {
+                throw new NotSupportedException(SR.ConfigReadonly);
+            }
+        }
+
+        /// <summary>
+        /// Gets the value indicating whether the HpcLinqConfiguration is read-only.
+        /// </summary>
+        /// <remarks>
+        /// When <see cref="IsReadOnly"/> is true, every property except JobFriendlyName will throw a
+        /// <see cref="NotSupportedException"/> from its setter.
+        /// </remarks>
+        public bool IsReadOnly
+        {
+            get { return _isReadOnly; }
+        }
+
+        /// <summary>
+        /// Gets or sets the compression scheme applied to data passed between stages in a HPC Query.
+        /// </summary>
+        /// <remarks>
+        /// The default is <see cref="DscCompressionScheme.Gzip"/>.
+        /// </remarks>
+        public DscCompressionScheme IntermediateDataCompressionScheme
+        {
+            get { return _intermediateDataCompressionScheme; }
+            set
+            {
+                ThrowIfReadOnly();
+                _intermediateDataCompressionScheme = value;
+            }
+        }
+
+        /// <summary>
+        /// Gets or sets the value specifying the compression scheme for output data.
+        /// </summary>
+        /// <remarks>
+        /// The default is <see cref="DscCompressionScheme.None"/>.
+        /// </remarks>
+        public DscCompressionScheme OutputDataCompressionScheme
+        {
+            get { return _outputCompressionScheme; }
+            set
+            {
+                ThrowIfReadOnly();
+                _outputCompressionScheme = value;
+            }
+        }
+
+        /// <summary>
+        /// Gets or sets the value specifying whether to compile code that supports debugging vertex tasks that execute on a HPC Server cluster.
+        /// </summary>
+        /// <remarks>
+        /// If true, vertex code will be compiled with no code-level optimizations and a PDB will be generated.
+        /// Also, the query execution job will look for and include the PDB associated with every DLL resource
+        /// that is part of the submitted job.
+        /// The default is false.
+        /// </remarks>
+        public bool CompileForVertexDebugging
+        {
+            get { return _compileForVertexDebugging; }
+            set
+            {
+                ThrowIfReadOnly();
+                _compileForVertexDebugging = value;
+            }
+        }
+
+        /// <summary>
+        /// Gets or sets the bin directory for Dryad.
+        /// </summary>
+        /// <remarks>
+        /// Initialized from the DRYAD_HOME environment variable.
+        /// </remarks>
+        public string DryadHomeDirectory
+        {
+            get { return _dryadHome; }
+            set
+            {
+                ThrowIfReadOnly();
+                _dryadHome = value;
+            }
+        }
+
+        /// <summary>
+        /// Gets or sets the home directory for Yarn.
+        /// </summary>
+        /// <remarks>
+        /// Initialized from the YARN_HOME environment variable.
+        /// </remarks>
+        public string YarnHomeDirectory
+        {
+            get { return _yarnHome; }
+            set
+            {
+                ThrowIfReadOnly();
+                _yarnHome = value;
+            }
+        }
+
+        /// <summary>
+        /// Gets or sets the head node for the HPC Server used to execute the HPC Query job.
+        /// </summary>
+        public string HeadNode
+        {
+            get { return _headNode; }
+            set
+            {
+                ThrowIfReadOnly();
+                _headNode = value;
+            }
+        }
+
+        /// <summary>
+        /// Gets or sets the namenode for the HDFS.
+        /// </summary>
+        public string HdfsNameNode
+        {
+            get { return _hdfsNameNode; }
+            set
+            {
+                ThrowIfReadOnly();
+                _hdfsNameNode = value;
+            }
+        }
+
+        /// <summary>
+        /// Gets or sets the HTTP port used by the namenode for the HDFS.
+        /// </summary>
+        /// <remarks>
+        /// The default is 8033.
+        /// </remarks>
+        public int HdfsNameNodeHttpPort
+        {
+            get { return _hdfsNameNodeHttpPort; }
+            set
+            {
+                ThrowIfReadOnly();
+                _hdfsNameNodeHttpPort = value;
+            }
+        }
+
+        /// <summary>
+        /// Gets the collection of environment variables associated with the HPC Query job.
+        /// </summary>
+        public IDictionary<string, string> JobEnvironmentVariables
+        {
+            get { return _jobEnvironmentVariables; }
+        }
+
+        /// <summary>
+        /// Gets or sets the descriptive name used to describe the HPC Query job.
+        /// </summary>
+        /// <remarks>
+        /// The default is null (no name). May be overriden by cluster settings such as node templates.
+        /// This property can be altered even when <see cref="IsReadOnly"/> is true.
+        /// </remarks>
+        public string JobFriendlyName
+        {
+            get { return _jobFriendlyName; }
+            set
+            {
+                // Deliberately no ThrowIfReadOnly(): the name stays editable on a read-only copy.
+                _jobFriendlyName = value;
+            }
+        }
+
+        /// <summary>
+        /// Gets or sets the minimum number of cluster nodes that the HPC Server job will use.
+        /// </summary>
+        /// <remarks>
+        /// The default is null (no lower limit). May be overriden by cluster settings such as node templates.
+        /// </remarks>
+        public int? JobMinNodes
+        {
+            get { return _jobMinNodes; }
+            set
+            {
+                ThrowIfReadOnly();
+                _jobMinNodes = value;
+            }
+        }
+
+        /// <summary>
+        /// Gets or sets the maximum number of cluster nodes that the HPC Server job will use.
+        /// </summary>
+        /// <remarks>
+        /// The default is null (no upper limit). May be overriden by cluster settings such as node templates.
+        /// </remarks>
+        public int? JobMaxNodes
+        {
+            get { return _jobMaxNodes; }
+            set
+            {
+                ThrowIfReadOnly();
+                _jobMaxNodes = value;
+            }
+        }
+
+        /// <summary>
+        /// Gets or sets the name of the compute node group that the HPC Server job will use.
+        /// </summary>
+        /// <remarks>
+        /// <para>
+        /// Creation and management of nodes groups is performed using the HPC Cluster Manager.
+        /// </para>
+        /// <para>
+        /// The default is null (no node group restriction). May be overriden by cluster settings such as node templates.
+        /// </para>
+        /// </remarks>
+        public string NodeGroup
+        {
+            get { return _nodeGroup; }
+            set
+            {
+                ThrowIfReadOnly();
+                _nodeGroup = value;
+            }
+        }
+
+
+        /// <summary>
+        /// Gets or sets the maximum execution time for the HPC Query job, in seconds.
+        /// </summary>
+        /// <remarks>
+        /// The default is null (no runtime limit). May be overriden by cluster settings such as node templates.
+        /// </remarks>
+        public int? JobRuntimeLimit
+        {
+            get { return _jobRuntimeLimit; }
+            set
+            {
+                ThrowIfReadOnly();
+                _jobRuntimeLimit = value;
+            }
+        }
+
+        /// <summary>
+        /// Enables or disables speculative duplication of vertices based on runtime performance analysis.
+        /// </summary>
+        /// <remarks>
+        /// The default is false.
+        /// </remarks>
+        public bool EnableSpeculativeDuplication
+        {
+            get { return _enableSpeculativeDuplication; }
+            set
+            {
+                ThrowIfReadOnly();
+                _enableSpeculativeDuplication = value;
+            }
+        }
+
+        /// <summary>
+        /// Gets or sets the value specifying whether to use Local debugging mode.
+        /// </summary>
+        /// <remarks>
+        /// <para>
+        /// If true, the HPC Query will execute in the current AppDomain via LINQ-to-Objects.
+        /// This mode is particularly useful for debugging user-functions before attempting cluster execution.
+        /// LocalDebug mode accesses DSC as usual for input and output data.
+        /// </para>
+        /// <para>
+        /// LocalDebug mode does not perform vertex-code compilation, nor is a job submitted to HPC Server.
+        /// </para>
+        /// <para>The default is false.</para>
+        /// </remarks>
+        public bool LocalDebug
+        {
+            get { return _localDebug; }
+            set
+            {
+                ThrowIfReadOnly();
+                _localDebug = value;
+            }
+        }
+
+        /// <summary>
+        /// Get the list of resources to add to the HPC job used to execute a HPC Query.
+        /// </summary>
+        /// <remarks>
+        /// <para>
+        /// During query submission, some resources will be detected and added automatically. It is only necessary
+        /// to add resources that are not detected automatically.
+        /// </para>
+        /// <para>
+        /// Each resource should be a complete path to a file-based resource accessible from the local machine.
+        /// </para>
+        /// </remarks>
+        public IList<string> ResourcesToAdd
+        {
+            get { return _resourcesToAdd; }
+        }
+
+        /// <summary>
+        /// Get the list of resources to remove from the HPC job used to execute a HPC Query.
+        /// </summary>
+        /// <remarks>
+        /// <para>
+        /// During query submission, some resources will be detected and added automatically.
+        /// Remove resources that are detected automatically but that are not required for job execution.
+        /// </para>
+        /// <para>
+        /// Each resource should be a complete path to a file-based resource accessible from the local machine.
+        /// </para>
+        /// </remarks>
+        public IList<string> ResourcesToRemove
+        {
+            get { return _resourcesToRemove; }
+        }
+
+        /// <summary>
+        /// Gets or sets the RunAs user name for jobs submitted to HPC Server.
+        /// </summary>
+        /// <remarks>
+        /// The default is null (use the credentials of the current Thread)
+        /// </remarks>
+        public string JobUsername
+        {
+            get { return _jobUsername; }
+            set
+            {
+                ThrowIfReadOnly();
+                _jobUsername = value;
+            }
+        }
+
+        /// <summary>
+        /// Gets or sets the RunAs password for jobs submitted to HPC Server.
+        /// </summary>
+        /// <remarks>
+        /// The default is null (use the credentials of the current Thread)
+        /// </remarks>
+        public string JobPassword
+        {
+            get { return _jobPassword; }
+            set
+            {
+                ThrowIfReadOnly();
+                _jobPassword = value;
+            }
+        }
+
+        /// <summary>
+        /// Gets or sets the trace level to use for HPC Query jobs.
+        /// </summary>
+        /// <remarks>
+        /// <para>
+        /// The RuntimeTraceLevel affects the logs produced by all components associated with the execution
+        /// of a HPC Query job.
+        /// </para>
+        /// <para>The default is HpcQueryTraceLevel.Error</para>
+        /// </remarks>
+        public HpcQueryTraceLevel RuntimeTraceLevel
+        {
+            get { return _runtimeTraceLevel; }
+            set
+            {
+                ThrowIfReadOnly();
+                _runtimeTraceLevel = value;
+            }
+        }
+
+#if YARN_MISSING_FEATURE
+        /// <summary>
+        /// Gets or sets the node that should be used for running the HPC Query Graph Manager task.
+        /// </summary>
+        /// <remarks>
+        /// If null, the Graph Manager task will run on an arbitrary machine that is allocated to the HPC Query job.
+        /// </remarks>
+        public string GraphManagerNode
+        {
+            get { return _graphManagerNode; }
+            set
+            {
+                ThrowIfReadOnly();
+                _graphManagerNode = value;
+            }
+        }
+#endif
+
+        /// <summary>
+        /// Gets or sets whether certain operators will preserve item ordering.
+        /// When true, the Select, SelectMany and Where operators will preserve item ordering;
+        /// otherwise, they may shuffle the input items as they are processed.
+        /// </summary>
+        /// <remarks>
+        /// The default is false.
+        /// </remarks>
+        public bool SelectiveOrderPreservation
+        {
+            get { return _selectAndWherePreserveOrder; }
+            set
+            {
+                ThrowIfReadOnly();
+                _selectAndWherePreserveOrder = value;
+            }
+        }
+
+        /// <summary>
+        /// Configures query jobs to be launched on the cluster nodes against a .NET framework version
+        /// matching that of the client process. This should only be set if all cluster nodes are known to have
+        /// the same .NET version as the client.
+        /// When set to false (default), the vertex code will be compiled and run against .NET Framework 3.5.
+        /// </summary>
+        public bool MatchClientNetFrameworkVersion
+        {
+            get { return _matchClientNetFrameworkVersion; }
+            set
+            {
+                ThrowIfReadOnly();
+                _matchClientNetFrameworkVersion = value;
+            }
+        }
+
+        /// <summary>
+        /// Gets or sets whether user-defined methods and custom serializers may be called on multiple threads of a single process.
+        /// </summary>
+        /// <remarks>
+        /// This option affects the internal behavior of individual queries and applies to both the client process (for serialization and local-debug mode)
+        /// and to vertex processes.
+        /// This option does not have any serializing effect for queries that are submitted concurrently by one or more client processes.
+        /// If true, user-defined methods may be called concurrently.
+        /// If false, user-defined methods will be called without concurrency.
+        /// The default is true.
+        /// </remarks>
+        public bool AllowConcurrentUserDelegatesInSingleProcess
+        {
+            get { return _multiThreading; }
+            set
+            {
+                ThrowIfReadOnly();
+                _multiThreading = value;
+            }
+        }
+
+        /// <summary>
+        /// Initializes a new instance of the HpcLinqConfiguration class.
+        /// </summary>
+        public HpcLinqConfiguration()
+        {
+            CommonInit();
+        }
+
+        /// <summary>
+        /// Initializes a new instance of the HpcLinqConfiguration class.
+        /// </summary>
+        /// <param name="headNode">The head node for the HPC Server used to execute the HPC Query job.
+        /// It is also used as the HDFS namenode.</param>
+        public HpcLinqConfiguration(string headNode)
+        {
+            _headNode = headNode;
+            _hdfsNameNode = headNode; //default
+            CommonInit();
+        }
+
+        /// <summary>
+        /// Initializes a new instance of the HpcLinqConfiguration class.
+        /// </summary>
+        /// <param name="headNode">The head node for the HPC Server used to execute the HPC Query job.</param>
+        /// <param name="hdfsNameNode">The namenode for the HDFS.</param>
+        public HpcLinqConfiguration(string headNode, string hdfsNameNode)
+        {
+            _headNode = headNode;
+            _hdfsNameNode = hdfsNameNode;
+            CommonInit();
+        }
+
+        // Shared constructor tail. Re-reads the same environment variables the field
+        // initializers above already read.
+        private void CommonInit()
+        {
+            _yarnHome = Environment.GetEnvironmentVariable("YARN_HOME");
+            _dryadHome = Environment.GetEnvironmentVariable("DRYAD_HOME");
+        }
+
+        /// <summary>
+        /// Creates a read-only snapshot of this configuration. The resource lists and the
+        /// environment-variable dictionary are cloned, so later changes to this instance
+        /// are not visible through the copy.
+        /// </summary>
+        internal HpcLinqConfiguration MakeImmutableCopy()
+        {
+            HpcLinqConfiguration newConfig = new HpcLinqConfiguration();
+
+            newConfig._isReadOnly = true;
+
+            newConfig._jobEnvironmentVariables = this._jobEnvironmentVariables.GetImmutableClone();
+
+            newConfig._resourcesToAdd = this._resourcesToAdd.GetImmutableClone();
+            newConfig._resourcesToRemove = this._resourcesToRemove.GetImmutableClone();
+
+            newConfig._intermediateDataCompressionScheme = this._intermediateDataCompressionScheme;
+            newConfig._outputCompressionScheme = this._outputCompressionScheme;
+            newConfig._compileForVertexDebugging = this._compileForVertexDebugging;
+            newConfig._headNode = this._headNode;
+            newConfig._hdfsNameNode = this._hdfsNameNode;
+            newConfig._hdfsNameNodeHttpPort = this._hdfsNameNodeHttpPort;
+            newConfig._jobFriendlyName = this._jobFriendlyName;
+            newConfig._jobMinNodes = this._jobMinNodes;
+            newConfig._jobMaxNodes = this._jobMaxNodes;
+            newConfig._nodeGroup = this._nodeGroup;
+            newConfig._jobRuntimeLimit = this._jobRuntimeLimit;
+            newConfig._localDebug = this._localDebug;
+            newConfig._orderPreserving = this._orderPreserving;
+            newConfig._jobUsername = this._jobUsername;
+            newConfig._jobPassword = this._jobPassword;
+            newConfig._runtimeTraceLevel = this._runtimeTraceLevel;
+            newConfig._graphManagerNode = this._graphManagerNode;
+            newConfig._selectAndWherePreserveOrder = this._selectAndWherePreserveOrder;
+            newConfig._matchClientNetFrameworkVersion = this._matchClientNetFrameworkVersion;
+            newConfig._enableSpeculativeDuplication = this._enableSpeculativeDuplication;
+            newConfig._multiThreading = this._multiThreading;
+
+            newConfig._dryadHome = this._dryadHome;
+            newConfig._yarnHome = this._yarnHome;
+
+            return newConfig;
+        }
+    }
+}
diff --git a/LinqToDryad/HpcLinqContext.cs b/LinqToDryad/HpcLinqContext.cs
new file mode 100644
index 0000000..164d546
--- /dev/null
+++ b/LinqToDryad/HpcLinqContext.cs
@@ -0,0 +1,304 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Diagnostics;
+using System.Linq.Expressions;
+using System.Reflection;
+using System.IO;
+using Microsoft.Research.DryadLinq.Internal;
+
+namespace Microsoft.Research.DryadLinq
+{
+
+
+    ///
+    /// Represents the context necessary to prepare and execute a HPC LINQ Query,
+    ///
+    ///
+    ///
+    /// HpcLinqContext is the main entry point for the HPC LINQ framework.
The context
+    /// that is maintained by a HpcLinqContext instance includes configuration information, a connection to a DSC Service
+    /// that can be used during execution, and a connection to a HPC Server.
+    ///
+    ///
+    /// A HpcLinqContext may be reused by multiple queries and query executions.
+    ///
+    ///
+    /// A HpcLinqContext may hold open connections to DSC and a HPC Server. To release these connections, call
+    /// HpcLinqContext.Dispose()
+    ///
+    ///
+    public sealed class HpcLinqContext : IDisposable
+    {
+        // A null _configuration is the "disposed" marker; see ThrowIfDisposed().
+        private HpcLinqConfiguration _configuration;
+        private HpcQueryRuntime _runtime;
+        private DscService _dscService;
+        private string _hdfsServiceNode;
+        private Version _clientVersion;
+        private Version _serverVersion;
+
+        /// <summary>
+        /// Gets the configuration associated with this HpcLinqContext.
+        /// </summary>
+        /// <remarks>
+        /// The Configuration object returned will be read-only.
+        /// </remarks>
+        public HpcLinqConfiguration Configuration
+        {
+            get {
+                ThrowIfDisposed();
+                return _configuration;
+            }
+        }
+
+        /// <summary>
+        /// Gets the DscService associated with this HpcLinqContext.
+        /// </summary>
+        public DscService DscService
+        {
+            get
+            {
+                ThrowIfDisposed();
+                return _dscService;
+            }
+        }
+
+        /// <summary>
+        /// Gets the HdfsService associated with this HpcLinqContext.
+        /// </summary>
+        public string HdfsService
+        {
+            get
+            {
+                ThrowIfDisposed();
+                return _hdfsServiceNode;
+            }
+        }
+
+        // internal: the runtime associated with this HpcLinqContext.
+        internal HpcQueryRuntime Runtime {
+            get
+            {
+                ThrowIfDisposed();
+                return _runtime;
+            }
+        }
+
+
+
+        /// <summary>
+        /// Version of the HpcLinq client components
+        /// </summary>
+        public Version ClientVersion {
+            get {
+                ThrowIfDisposed();
+                if (_clientVersion == null)
+                {
+                    LoadClientVersion(); // thread-safe
+                }
+                return _clientVersion;
+            }
+        }
+
+        /// <summary>
+        /// Version of the HpcLinq server components
+        /// </summary>
+        public Version ServerVersion {
+            get
+            {
+                ThrowIfDisposed();
+                if (_serverVersion == null)
+                {
+                    LoadServerVersion(); // thread-safe
+                }
+                return _serverVersion;
+            }
+        }
+
+        /// <summary>
+        /// Initializes a new instance of the HpcLinqContext class.
+        /// </summary>
+        /// <param name="configuration">Configuration information.</param>
+        /// <remarks>
+        /// Connections will be opened to DSC and HPC Server using configuration.HeadNode.
+        /// The connections will be opened regardless of whether DSC is used and/or whether
+        /// configuration.LocalDebug is true
+        /// </remarks>
+        /// <exception cref="ArgumentNullException"><paramref name="configuration"/> is null.</exception>
+        /// <exception cref="DryadLinqException">configuration.HeadNode is null.</exception>
+        public HpcLinqContext(HpcLinqConfiguration configuration)
+        {
+            if (configuration == null)
+            {
+                throw new ArgumentNullException("configuration");
+            }
+
+            // Verify that the head node is set
+            if (configuration.HeadNode == null)
+            {
+                throw new DryadLinqException(HpcLinqErrorCode.ClusterNameMustBeSpecified,
+                                             SR.ClusterNameMustBeSpecified);
+            }
+
+            HpcLinqConfiguration snapshot = configuration.MakeImmutableCopy();
+            HpcQueryRuntime runtime = new HpcQueryRuntime(snapshot.HeadNode);
+            try
+            {
+                _dscService = new DscService(snapshot.HeadNode);
+            }
+            catch
+            {
+                // Nobody can Dispose() a context whose constructor threw, so release the
+                // runtime here before propagating the failure.
+                runtime.Dispose();
+                throw;
+            }
+
+            _configuration = snapshot;
+            _runtime = runtime;
+            _hdfsServiceNode = snapshot.HdfsNameNode;
+        }
+
+        // Reads the file version of the executing assembly into _clientVersion.
+        private void LoadClientVersion()
+        {
+            try
+            {
+                Assembly asm = Assembly.GetExecutingAssembly();
+                _clientVersion = new Version(FileVersionInfo.GetVersionInfo(asm.Location).FileVersion);
+            }
+            catch (Exception ex)
+            {
+                throw new DryadLinqException(HpcLinqErrorCode.CouldNotGetClientVersion,
+                                             SR.CouldNotGetClientVersion, ex);
+            }
+        }
+
+        // Asks the scheduler for the server version and stores it in _serverVersion.
+        private void LoadServerVersion()
+        {
+            try
+            {
+                IServerVersion version = this.GetIScheduler().GetServerVersion();
+                _serverVersion = new Version(version.Major, version.Minor, version.Build, version.Revision);
+            }
+            catch (Exception ex)
+            {
+                throw new DryadLinqException(HpcLinqErrorCode.CouldNotGetServerVersion,
+                                             SR.CouldNotGetServerVersion, ex);
+            }
+        }
+
+        /// <summary>
+        /// Open a DSC fileset as a LINQ-to-HPC IQueryable{T}.
+        /// </summary>
+        /// <typeparam name="T">The type of the records in the table.</typeparam>
+        /// <param name="fileSetName">The name of the DSC fileset.</param>
+        /// <returns>An IQueryable{T} representing the data and associated with the HPC LINQ query provider.</returns>
+        /// <exception cref="DryadLinqException">
+        /// The fileset is not sealed, contains no files, or could not be opened.
+        /// </exception>
+        public IQueryable<T> FromDsc<T>(string fileSetName)
+        {
+            ThrowIfDisposed();
+
+            string fullPath = DataPath.MakeDscStreamUri(_dscService, fileSetName);
+
+            try {
+                DscFileSet fs = _dscService.GetFileSet(fileSetName);
+                if (!fs.IsSealed())
+                {
+                    throw new DryadLinqException(HpcLinqErrorCode.FileSetMustBeSealed,
+                                                 SR.FileSetMustBeSealed);
+                }
+
+                int fileCount = fs.GetFiles().Count();
+                if (fileCount < 1)
+                {
+                    throw new DryadLinqException(HpcLinqErrorCode.FileSetMustHaveAtLeastOneFile,
+                                                 SR.FileSetMustHaveAtLeastOneFile);
+                }
+
+            }
+            catch (DscException dscEx){
+                throw new DryadLinqException(HpcLinqErrorCode.FileSetCouldNotBeOpened,
+                                             SR.FileSetCouldNotBeOpened, dscEx);
+            }
+
+            DryadLinqQuery<T> q = DataProvider.GetPartitionedTable<T>(this, fullPath);
+            q.CheckAndInitialize();   // force the data-info checks.
+            return q;
+        }
+
+        /// <summary>
+        /// Open a HDFS fileset as an IQueryable{T}.
+        /// </summary>
+        /// <typeparam name="T">The type of the records in the table.</typeparam>
+        /// <param name="fileSetName">The name of the HDFS fileset.</param>
+        /// <returns>An IQueryable{T} representing the data and associated with the HPC LINQ query provider.</returns>
+        public IQueryable<T> FromHdfs<T>(string fileSetName)
+        {
+            ThrowIfDisposed();
+
+            string fullPath = DataPath.MakeHdfsStreamUri(_hdfsServiceNode, fileSetName);
+            return DataProvider.GetPartitionedTable<T>(this, fullPath);
+        }
+
+        /// <summary>
+        /// Converts an IEnumerable{T} to a LINQ-to-HPC IQueryable{T}.
+        /// </summary>
+        /// <typeparam name="T">The type of the records in the table.</typeparam>
+        /// <param name="data">The source data.</param>
+        /// <returns>An IQueryable{T} representing the data and associated with the HPC LINQ query provider.</returns>
+        /// <remarks>
+        /// The source data will be serialized to a DSC fileset using the LINQ-to-HPC serialization approach.
+        /// The resulting fileset will have an auto-generated name and a temporary lease.
+        /// </remarks>
+        public IQueryable<T> FromEnumerable<T>(IEnumerable<T> data)
+        {
+            // Fail before doing any work, as FromDsc and FromHdfs do.
+            ThrowIfDisposed();
+
+            string fileSetName = DataPath.MakeUniqueTemporaryDscFileSetName();
+            DscCompressionScheme compressionScheme = Configuration.IntermediateDataCompressionScheme;
+            DryadLinqMetaData metadata = DryadLinqMetaData.ForLocalDebug(this, typeof(T), fileSetName, compressionScheme);
+            return DataProvider.IngressTemporaryDataDirectlyToDsc(this, data, fileSetName, metadata, compressionScheme);
+        }
+
+
+
+        /// <summary>
+        /// Releases all resources used by the HpcLinqContext.
+        /// </summary>
+        /// <remarks>
+        /// Safe to call more than once. After disposal the public members throw from ThrowIfDisposed().
+        /// </remarks>
+        public void Dispose()
+        {
+            _configuration = null;
+            if (_runtime != null)
+            {
+                _runtime.Dispose();
+                _runtime = null;
+            }
+            if (_dscService != null)
+            {
+                _dscService.Close();
+                _dscService = null;
+            }
+        }
+
+        // Returns the (non-disposed) context behind a query provider.
+        internal static HpcLinqContext GetContext(DryadLinqProviderBase provider)
+        {
+            HpcLinqContext context = provider.Context;
+            Debug.Assert(context != null, "A context should always be associated with a HpcLinqQuery");
+            context.ThrowIfDisposed();
+            return context;
+        }
+
+        // Return IScheduler reference for internal use
+        internal IScheduler GetIScheduler()
+        {
+            // _runtime is null after Dispose(); report that as ContextDisposed, not a null dereference.
+            ThrowIfDisposed();
+            return _runtime.GetIScheduler();
+        }
+
+        // Throws DryadLinqException(ContextDisposed) once Dispose() has cleared _configuration.
+        internal void ThrowIfDisposed()
+        {
+            if (_configuration == null)
+            {
+                throw new DryadLinqException(HpcLinqErrorCode.ContextDisposed,
+                                             SR.ContextDisposed);
+            }
+        }
+    }
+}
diff --git a/LinqToDryad/HpcLinqStringDictionary.cs b/LinqToDryad/HpcLinqStringDictionary.cs
new file mode 100644
index 0000000..8084a71
--- /dev/null
+++ b/LinqToDryad/HpcLinqStringDictionary.cs
@@ -0,0 +1,170 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Collections.Specialized;
+using System.Linq;
+using System.Text;
+
+namespace Microsoft.Research.DryadLinq
+{
+    /// <summary>
+    /// HpcLinq specific (string,string) dictionary that supports a readonly flag.
+    /// </summary>
+    /// <remarks>
+    /// Every mutating member throws <see cref="NotSupportedException"/> once the instance is
+    /// read-only; read members always work.
+    /// </remarks>
+    internal class HpcLinqStringDictionary : IDictionary<string, string>
+    {
+        private Dictionary<string, string> _store = new Dictionary<string, string>();
+        private bool _isReadOnly = false;
+
+        /// <summary>
+        /// Creatable only from HpcLinq assembly
+        /// </summary>
+        internal HpcLinqStringDictionary()
+        {
+
+        }
+
+        //this is useful to transfer data to HpcJobSubmission class
+        //NOTE: this hands out the live store, so writes through it bypass the read-only flag.
+        internal Dictionary<string, string> BackingStore
+        {
+            get { return _store; }
+        }
+
+        // Guard called first by every mutating member.
+        private void ThrowIfReadOnly()
+        {
+            if (_isReadOnly)
+            {
+                throw new NotSupportedException(SR.HpcLinqStringDictionaryReadonly);
+            }
+        }
+
+        public void Add(string key, string value)
+        {
+            ThrowIfReadOnly();
+            _store.Add(key, value);
+        }
+
+        public bool ContainsKey(string key)
+        {
+            return _store.ContainsKey(key);
+        }
+
+        public ICollection<string> Keys
+        {
+            get { return _store.Keys; }
+        }
+
+        public bool Remove(string key)
+        {
+            ThrowIfReadOnly();
+            return _store.Remove(key);
+        }
+
+        public bool TryGetValue(string key, out string value)
+        {
+            return _store.TryGetValue(key, out value);
+        }
+
+        /// <summary>
+        /// Gets the values stored in the dictionary.
+        /// </summary>
+        public ICollection<string> Values
+        {
+            get { return _store.Values; }
+        }
+
+        public string this[string key]
+        {
+            get
+            {
+                return _store[key];
+            }
+            set
+            {
+                ThrowIfReadOnly();
+                _store[key] = value;
+            }
+        }
+
+        public void Add(KeyValuePair<string, string> item)
+        {
+            ThrowIfReadOnly();
+            _store.Add(item.Key, item.Value);
+        }
+
+        public void Clear()
+        {
+            ThrowIfReadOnly();
+            _store.Clear();
+        }
+
+        public bool Contains(KeyValuePair<string, string> item)
+        {
+            // Go through the dictionary's own ICollection implementation (a key lookup);
+            // calling Contains on _store directly binds to the linear LINQ extension.
+            return ((IDictionary<string, string>)_store).Contains(item);
+        }
+
+        public void CopyTo(KeyValuePair<string, string>[] array, int arrayIndex)
+        {
+            ((IDictionary<string, string>)_store).CopyTo(array, arrayIndex);
+        }
+
+        public int Count
+        {
+            get { return _store.Count; }
+        }
+
+        public bool IsReadOnly
+        {
+            get { return this._isReadOnly; }
+        }
+
+        public bool Remove(KeyValuePair<string, string> item)
+        {
+            ThrowIfReadOnly();
+            return ((IDictionary<string, string>)_store).Remove(item);
+        }
+
+        public IEnumerator<KeyValuePair<string, string>> GetEnumerator()
+        {
+            return _store.GetEnumerator();
+        }
+
+        IEnumerator IEnumerable.GetEnumerator()
+        {
+            return GetEnumerator();
+        }
+
+        /// <summary>
+        /// Creates a read-only copy of this dictionary. The copy has its own store, so later
+        /// changes to this instance are not visible through it.
+        /// </summary>
+        internal HpcLinqStringDictionary GetImmutableClone()
+        {
+            HpcLinqStringDictionary clone = new HpcLinqStringDictionary();
+
+            foreach ( var keyValuePair in this._store)
+            {
+                clone.Add(keyValuePair);
+            }
+
+            // Flip the flag only after copying; Add() would throw otherwise.
+            clone._isReadOnly = true;
+
+            return clone;
+        }
+    }
+}
diff --git a/LinqToDryad/HpcLinqStringList.cs b/LinqToDryad/HpcLinqStringList.cs
new file mode 100644
index 0000000..4583830
--- /dev/null
+++ b/LinqToDryad/HpcLinqStringList.cs
@@ -0,0 +1,141 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Collections.Specialized;
+using System.Linq;
+using System.Text;
+
+namespace Microsoft.Research.DryadLinq
+{
+    /// <summary>
+    /// HpcLinq specific list-of-string that supports a readonly flag.
+    /// </summary>
+    /// <remarks>
+    /// Every mutating member throws <see cref="NotSupportedException"/> once the instance is
+    /// read-only; read members always work.
+    /// </remarks>
+    internal class HpcLinqStringList : IList<string>
+    {
+        private List<string> _store = new List<string>();
+        private bool _isReadOnly = false;
+
+        /// <summary>
+        /// Creatable only from HpcLinq assembly
+        /// </summary>
+        internal HpcLinqStringList()
+        {
+        }
+
+        // Guard called first by every mutating member.
+        private void ThrowIfReadOnly()
+        {
+            if (_isReadOnly)
+            {
+                // @TODO: Post June'11 RTM. Get its own res-string. Current msg is appropriate, however.
+                throw new NotSupportedException(SR.HpcLinqStringDictionaryReadonly);
+            }
+        }
+
+        /// <summary>
+        /// Creates a read-only copy of this list. The copy has its own store, so later changes
+        /// to this instance are not visible through it.
+        /// </summary>
+        internal HpcLinqStringList GetImmutableClone()
+        {
+            HpcLinqStringList clone = new HpcLinqStringList();
+            clone._store.AddRange(_store);
+            clone._isReadOnly = true;
+
+            return clone;
+        }
+
+        public int IndexOf(string item)
+        {
+            return _store.IndexOf(item);
+        }
+
+        public void Insert(int index, string item)
+        {
+            ThrowIfReadOnly();
+            _store.Insert(index, item);
+        }
+
+        public void RemoveAt(int index)
+        {
+            ThrowIfReadOnly();
+            _store.RemoveAt(index);
+        }
+
+        public string this[int index]
+        {
+            get
+            {
+                return _store[index];
+            }
+            set
+            {
+                ThrowIfReadOnly();
+                _store[index] = value;
+            }
+        }
+
+        public void Add(string item)
+        {
+            ThrowIfReadOnly();
+            _store.Add(item);
+        }
+
+        public void Clear()
+        {
+            ThrowIfReadOnly();
+            _store.Clear();
+        }
+
+        public bool Contains(string item)
+        {
+            return _store.Contains(item);
+        }
+
+        public void CopyTo(string[] array, int arrayIndex)
+        {
+            _store.CopyTo(array, arrayIndex);
+        }
+
+        public int Count
+        {
+            get { return _store.Count; }
+        }
+
+        public bool IsReadOnly
+        {
+            get { return _isReadOnly; }
+        }
+
+        public bool Remove(string item)
+        {
+            ThrowIfReadOnly();
+            return _store.Remove(item);
+        }
+
+        public IEnumerator<string> GetEnumerator()
+        {
+            return _store.GetEnumerator();
+        }
+
+        IEnumerator IEnumerable.GetEnumerator()
+        {
+            return GetEnumerator();
+        }
+    }
+}
diff --git a/LinqToDryad/IAssociative.cs b/LinqToDryad/IAssociative.cs
new file mode 100644
index 0000000..62b14c1
--- /dev/null
+++ b/LinqToDryad/IAssociative.cs
@@ -0,0 +1,49 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Microsoft.Research.DryadLinq
+{
+    /// <summary>
+    /// A seed value plus a binary combiner over <typeparamref name="TAccumulate"/>.
+    /// </summary>
+    /// <typeparam name="TAccumulate">The type of the accumulated value.</typeparam>
+    public interface IAssociative<TAccumulate>
+    {
+        /// <summary>Returns the initial accumulator value.</summary>
+        TAccumulate Seed();
+
+        /// <summary>Combines two accumulator values into one.</summary>
+        TAccumulate RecursiveAccumulate(TAccumulate acc, TAccumulate val);
+    }
+
+    /// <summary>
+    /// Static forwarders to a single shared <typeparamref name="TAssoc"/> instance, created
+    /// once per closed generic type with its parameterless constructor.
+    /// </summary>
+    public static class GenericAssociative<TAssoc, TAccumulate>
+        where TAssoc : IAssociative<TAccumulate>, new()
+    {
+        private static TAssoc a = new TAssoc();
+
+        public static TAccumulate Seed()
+        {
+            return a.Seed();
+        }
+
+        public static TAccumulate RecursiveAccumulate(TAccumulate acc, TAccumulate val)
+        {
+            return a.RecursiveAccumulate(acc, val);
+        }
+    }
+}
diff --git a/LinqToDryad/IDecomposable.cs b/LinqToDryad/IDecomposable.cs
new file mode 100644
index 0000000..aee2678
--- /dev/null
+++ b/LinqToDryad/IDecomposable.cs
@@ -0,0 +1,70 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
//
// © Microsoft Corporation.  All rights reserved.
//
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Microsoft.Research.DryadLinq
{
    /// <summary>
    /// Contract for an aggregation split into stages: seed an accumulator from
    /// the first source element, fold further source elements into it, merge
    /// partial accumulators, and turn the final accumulator into the result.
    /// </summary>
    public interface IDecomposable<TSource, TAccumulate, TResult>
    {
        void Initialize(object state);
        TAccumulate Seed(TSource val);
        TAccumulate Accumulate(TAccumulate acc, TSource val);
        TAccumulate RecursiveAccumulate(TAccumulate acc, TAccumulate val);
        TResult FinalReduce(TAccumulate val);
    }

    /// <summary>
    /// Static forwarding wrapper around a single shared
    /// <typeparamref name="TDecomposable"/> instance created with its
    /// parameterless constructor.
    /// </summary>
    // NOTE(review): the type-parameter list was lost in the source text; the names
    // are taken from the member signatures, the order is assumed - confirm against
    // callers.
    public static class GenericDecomposable<TDecomposable, TSource, TAccumulate, TResult>
        where TDecomposable : IDecomposable<TSource, TAccumulate, TResult>, new()
    {
        private static TDecomposable d = new TDecomposable();

        public static void Initialize(object state)
        {
            d.Initialize(state);
        }

        public static TAccumulate Seed(TSource val)
        {
            return d.Seed(val);
        }

        public static TAccumulate Accumulate(TAccumulate acc, TSource val)
        {
            return d.Accumulate(acc, val);
        }

        public static TAccumulate RecursiveAccumulate(TAccumulate acc, TAccumulate val)
        {
            return d.RecursiveAccumulate(acc, val);
        }

        public static TResult FinalReduce(TAccumulate val)
        {
            return d.FinalReduce(val);
        }
    }
}
//
// © Microsoft Corporation.  All rights reserved.
//
using System;

namespace Microsoft.Research.DryadLinq
{
    /// <summary>
    /// Status of a Dryad job computing a set of PartitionedTables.
    /// </summary>
    internal enum JobStatus
    {
        /// <summary>Job has not been submitted yet.</summary>
        NotSubmitted,

        /// <summary>Job is waiting in the scheduler queue.</summary>
        Waiting,

        /// <summary>Job is running on the cluster.</summary>
        Running,

        /// <summary>Job has completed successfully.</summary>
        Success,

        /// <summary>Job execution failed.</summary>
        Failure,

        /// <summary>Job has been cancelled by user.</summary>
        Cancelled
    }

    /// <summary>
    /// Operations a job-submission backend exposes: configure the job
    /// (options, local and remote files), submit or terminate it, and query
    /// its id, status and error message.
    /// </summary>
    internal interface IHpcLinqJobSubmission
    {
        // Job configuration.
        void AddJobOption(string fieldName, string fieldVal);
        void AddLocalFile(string fileName);
        void AddRemoteFile(string fileName);

        // Job lifetime.
        void SubmitJob();
        JobStatus TerminateJob();

        // Job inspection.
        int GetJobId();
        JobStatus GetStatus();
        string ErrorMsg { get; }
    }
}
//
// © Microsoft Corporation.  All rights reserved.
//
using System;
using System.Collections;
using System.Collections.Generic;
using System.Text;
using System.Reflection;
using System.Diagnostics;

namespace Microsoft.Research.DryadLinq
{
    /// <summary>
    /// A single line of text. Equality and ordering are ordinal comparisons of
    /// the wrapped string; enumerating a record yields its characters.
    /// </summary>
    /// <remarks>
    /// Equals and GetHashCode are overridden for better performance.
    /// Because this is a struct, <c>default(LineRecord)</c> carries a null
    /// <see cref="Line"/>; equality, hashing and enumeration tolerate that
    /// instead of throwing <see cref="NullReferenceException"/>.
    /// </remarks>
    [Serializable]
    public struct LineRecord : IComparable, IComparable<LineRecord>, IEnumerable<char>, IEnumerable, IEquatable<LineRecord>
    {
        private string _line;

        public string Line
        {
            get { return _line; }
            internal set { _line = value; }
        }

        public LineRecord(string line)
        {
            _line = line;
        }

        public override bool Equals(Object obj)
        {
            if (!(obj is LineRecord)) return false;
            return this.Equals((LineRecord)obj);
        }

        /// <summary>Ordinal, null-tolerant comparison of the two lines.</summary>
        public bool Equals(LineRecord val)
        {
            // The static overload is the null-safe equivalent of the ordinal
            // instance call String.Equals(string) used previously.
            return String.Equals(this.Line, val.Line, StringComparison.Ordinal);
        }

        public static bool Equals(LineRecord a, LineRecord b)
        {
            return a.Equals(b);
        }

        public static bool operator ==(LineRecord a, LineRecord b)
        {
            return a.Equals(b);
        }

        public static bool operator !=(LineRecord a, LineRecord b)
        {
            return !a.Equals(b);
        }

        public static bool operator <(LineRecord a, LineRecord b)
        {
            return a.CompareTo(b) < 0;
        }

        public static bool operator >(LineRecord a, LineRecord b)
        {
            return a.CompareTo(b) > 0;
        }

        public override int GetHashCode()
        {
            // A record with no line hashes to 0 rather than throwing.
            return (this.Line == null) ? 0 : this.Line.GetHashCode();
        }

        /// <summary>
        /// Ordinal comparison against a boxed <see cref="LineRecord"/>.
        /// </summary>
        /// <exception cref="ArgumentException">
        /// <paramref name="val"/> is neither null nor a <see cref="LineRecord"/>.
        /// </exception>
        public int CompareTo(Object val)
        {
            if (val == null) return 1;

            if (!(val is LineRecord))
            {
                throw new ArgumentException(SR.CompareArgIncorrect, "val");
            }

            return StringComparer.Ordinal.Compare(this.Line, ((LineRecord)val).Line);
        }

        public int CompareTo(LineRecord val)
        {
            return StringComparer.Ordinal.Compare(this.Line, val.Line);
        }

        IEnumerator<char> IEnumerable<char>.GetEnumerator()
        {
            // A record with no line enumerates as empty.
            return (this.Line ?? String.Empty).GetEnumerator();
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return (this.Line ?? String.Empty).GetEnumerator();
        }

        public override String ToString()
        {
            return this.Line;
        }
    }
}
//
// © Microsoft Corporation.  All rights reserved.
//
using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Globalization;
using System.Reflection;
using System.Linq;
using System.Diagnostics;
using Microsoft.Win32.SafeHandles;
using System.Runtime.InteropServices;
using Microsoft.Research.DryadLinq;

// ---- LinqToDryad/MultiBlockStream.cs ----
namespace Microsoft.Research.DryadLinq.Internal
{
    /// <summary>
    /// Presents a list of files as one block stream, reading them one after
    /// another. Optionally inserts a "\r\n" block between files when the
    /// previous data did not end with one, so line-based parsing does not
    /// join the last line of one file with the first line of the next.
    /// </summary>
    // The class directly talks to a list of NTFS/Cosmos files.
    internal unsafe class MultiBlockStream : NativeBlockStream
    {
        private List<string[]> m_srcList;           // each source is represented by an array of alternative paths (i.e. replica paths)
        private string m_associatedDscStreamName;   // stored here only to provide a better exception message in case of IO errors
        private DscCompressionScheme m_compressionScheme;
        private int m_curIdx;                       // index of the NEXT source to open
        private NativeBlockStream m_curStream;

        // Immutable two-byte ['\r', '\n'] block handed out as the synthetic
        // newline. It is allocated once and never freed: the previous scheme
        // freed a shared static block on the next ReadDataBlock call of ANY
        // instance, which could release memory another stream's consumer was
        // still reading.
        private static readonly byte* s_newlineByteBlock = CreateNewlineBlock();

        // re bug: 16011, initialize trailing bytes as a newline pair so that we
        // don't add a newline if the first file (or files) are empty.
        byte[] m_trailingBytesOfData = new byte[2] { (byte)'\r', (byte)'\n' };
        private bool m_appendNewLinesToFiles;

        /// <summary>
        /// Opens the first source immediately.
        /// </summary>
        /// <exception cref="DryadLinqException">
        /// <paramref name="srcList"/> is empty, or no path alternative of the
        /// first source can be opened.
        /// </exception>
        internal MultiBlockStream(List<string[]> srcList,
                                  string associatedDscStreamName,
                                  FileAccess access,
                                  DscCompressionScheme scheme,
                                  bool appendNewLinesToFiles)
        {
            this.m_srcList = srcList;
            m_associatedDscStreamName = associatedDscStreamName;
            if (srcList.Count == 0)
            {
                throw new DryadLinqException(HpcLinqErrorCode.MultiBlockEmptyPartitionList,
                                             SR.MultiBlockEmptyPartitionList);
            }
            this.m_compressionScheme = scheme;
            this.m_curIdx = 0;
            this.m_curStream = this.GetStream(this.m_curIdx++, access);
            this.m_appendNewLinesToFiles = appendNewLinesToFiles;
        }

        private static byte* CreateNewlineBlock()
        {
            byte* block = (byte*)Marshal.AllocHGlobal(2);
            block[0] = (byte)'\r';
            block[1] = (byte)'\n';
            return block;
        }

        /// <summary>
        /// Opens source <paramref name="idx"/>, trying its path alternatives in
        /// order and returning the first that opens. Only the failure of the
        /// last alternative is reported.
        /// </summary>
        private NativeBlockStream GetStream(int idx, FileAccess access)
        {
            HpcLinqFileStream fileStream = null;
            string[] pathAlternatives = this.m_srcList[idx];

            for (int i = 0; i < pathAlternatives.Length; i++)
            {
                bool bLastIter = (i == pathAlternatives.Length - 1);
                string curSrcPath = pathAlternatives[i];
                try
                {
                    fileStream = new HpcLinqFileStream(curSrcPath, access, this.m_compressionScheme);
                }
                catch (Exception exp)
                {
                    // if we have more path alternatives to try we will continue,
                    // otherwise we'll propagate the exception from this last attempt
                    if (bLastIter)
                    {
                        // if we caught the HpcLinqException thrown by HpcLinqFileStream.Initialize,
                        // we want to propagate its inner exception (which contains the actual IO error)
                        // otherwise we'll attach the exception we caught as is
                        Exception innerException = (exp is DryadLinqException) ? exp.InnerException : exp;
                        throw new DryadLinqException(HpcLinqErrorCode.MultiBlockCannotAccesFilePath,
                                                     String.Format(SR.MultiBlockCannotAccesFilePath,
                                                                   curSrcPath, m_associatedDscStreamName),
                                                     innerException);
                    }
                }

                // if the attempt to initialize an HpcLinqFileStream with this path succeeded we'll return the object
                if (fileStream != null) break;
            }

            return fileStream;
        }

        /// <summary>Sum of the lengths of all sources; each is opened and closed in turn.</summary>
        internal override unsafe Int64 GetTotalLength()
        {
            Int64 totalLen = 0;
            for (int i = 0; i < this.m_srcList.Count; i++)
            {
                NativeBlockStream ns = this.GetStream(i, FileAccess.Read);
                try
                {
                    totalLen += ns.GetTotalLength();
                }
                finally
                {
                    // do not leave the probe stream open if the length query fails
                    ns.Close();
                }
            }
            return totalLen;
        }

        /// <summary>
        /// Returns the next non-empty block, advancing to the next source when
        /// the current one is exhausted. An empty block is returned only when
        /// the last source is exhausted.
        /// </summary>
        internal override unsafe DataBlockInfo ReadDataBlock()
        {
            while (true)
            {
                DataBlockInfo dataBlockInfo = this.m_curStream.ReadDataBlock();

                if (dataBlockInfo.blockSize == 0 && this.m_curIdx == m_srcList.Count)
                {
                    // data has been exhausted. We return the empty block and the caller knows what to do.
                    return dataBlockInfo;
                }

                // normal case.. record the last two bytes for newline tracking, and return the block.
                if (dataBlockInfo.blockSize > 0)
                {
                    if (m_appendNewLinesToFiles)
                    {
                        if (dataBlockInfo.blockSize >= 2)
                        {
                            m_trailingBytesOfData[0] = dataBlockInfo.dataBlock[dataBlockInfo.blockSize - 2];
                            m_trailingBytesOfData[1] = dataBlockInfo.dataBlock[dataBlockInfo.blockSize - 1];
                        }
                        else
                        {
                            Debug.Assert(dataBlockInfo.blockSize == 1);
                            // CASE: dataBlockInfo.blockSize == 1
                            // shift left.
                            // We must do this otherwise the following data could fail to be identified
                            // Blocks = [.........\r] [\n]
                            m_trailingBytesOfData[0] = m_trailingBytesOfData[1];    // shift
                            m_trailingBytesOfData[1] = dataBlockInfo.dataBlock[0];  // record the single element.
                        }
                    }

                    return dataBlockInfo;
                }

                // we only get here when a file is fully consumed and it wasn't the last file in the set.
                this.m_curStream.ReleaseDataBlock(dataBlockInfo.itemHandle);
                this.m_curStream.Close();
                this.m_curStream = this.GetStream(this.m_curIdx++, FileAccess.Read);

                // if the data stream didn't end with a newline-pair, emit one so that
                // LineRecord-parsing will work correctly. The next call reads the
                // next real block of data.
                if (m_appendNewLinesToFiles)
                {
                    //@@TODO[p3]: we currently only observe and insert \r\n pairs.
                    //            Unicode may have other types of newline to consider.
                    if (m_trailingBytesOfData[0] != '\r' || m_trailingBytesOfData[1] != '\n')
                    {
                        // The data handed out now ends with the synthetic pair; record
                        // that, otherwise an empty file following this one would
                        // trigger a second, spurious newline.
                        m_trailingBytesOfData[0] = (byte)'\r';
                        m_trailingBytesOfData[1] = (byte)'\n';

                        DataBlockInfo dummyblock = new DataBlockInfo();
                        dummyblock.blockSize = 2;
                        dummyblock.dataBlock = s_newlineByteBlock;
                        dummyblock.itemHandle = IntPtr.Zero;    // not owned by any underlying stream

                        return dummyblock;
                    }
                }
            }
        }

        internal override unsafe bool WriteDataBlock(IntPtr itemHandle, Int32 numBytesToWrite)
        {
            return this.m_curStream.WriteDataBlock(itemHandle, numBytesToWrite);
        }

        internal override void Flush()
        {
            this.m_curStream.Flush();
        }

        internal override void Close()
        {
            this.m_curStream.Close();
        }

        internal override unsafe DataBlockInfo AllocateDataBlock(Int32 size)
        {
            return this.m_curStream.AllocateDataBlock(size);
        }

        internal override unsafe void ReleaseDataBlock(IntPtr itemHandle)
        {
            this.m_curStream.ReleaseDataBlock(itemHandle);
        }

        internal override unsafe string GetURI()
        {
            // the only use of this is to be put in an error msg.. We don't have an
            // appropriate URI, so we just use the empty string. (bug 13970)
            return "";
        }
    }
}

// ---- LinqToDryad/MultiEnumerable.cs ----
// NOTE(review): generic type-parameter lists were lost in the source text. Arity
// follows from the members (First/Second/Third); the parameter names below are
// chosen here and carry no meaning for callers.
namespace Microsoft.Research.DryadLinq.Internal
{
    public interface IMultiEnumerable
    {
        UInt32 NumberOfInputs { get; }
    }

    public interface IMultiEnumerable<T> : IEnumerable<T>
    {
        UInt32 NumberOfInputs { get; }
        IEnumerable<T> this[int index] { get; }
    }

    public interface IMultiEnumerable<T1, T2> : IMultiEnumerable
    {
        IEnumerable<T1> First { get; }
        IEnumerable<T2> Second { get; }
    }

    public interface IMultiEnumerable<T1, T2, T3> : IMultiEnumerable
    {
        IEnumerable<T1> First { get; }
        IEnumerable<T2> Second { get; }
        IEnumerable<T3> Third { get; }
    }

    /// <summary>
    /// A fixed array of same-typed sequences, addressable by index; enumerating
    /// it yields every element of every sequence in array order.
    /// </summary>
    public class MultiEnumerable<T> : IMultiEnumerable<T>
    {
        private IEnumerable<T>[] m_enumList;

        public MultiEnumerable(IEnumerable<T>[] enumList)
        {
            this.m_enumList = enumList;
        }

        public UInt32 NumberOfInputs
        {
            get { return (UInt32)this.m_enumList.Length; }
        }

        /// <exception cref="DryadLinqException">
        /// <paramref name="index"/> is negative or not less than the number of inputs.
        /// </exception>
        public IEnumerable<T> this[int index]
        {
            get
            {
                // the lower bound is checked too, so a negative index reports the
                // same error as a too-large one instead of a raw array exception
                if (index >= 0 && index < this.m_enumList.Length)
                {
                    return this.m_enumList[index];
                }

                //@@TODO: throw ArgumentOutOfRangeException?
                throw new DryadLinqException(HpcLinqErrorCode.IndexOutOfRange, SR.IndexOutOfRange);
            }
        }

        IEnumerator<T> IEnumerable<T>.GetEnumerator()
        {
            return this.GetEnumerator();
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return this.GetEnumerator();
        }

        public IEnumerator<T> GetEnumerator()
        {
            foreach (var x in this.m_enumList)
            {
                foreach (var y in x)
                {
                    yield return y;
                }
            }
        }
    }

    /// <summary>Pair of independently typed sequences.</summary>
    public class MultiEnumerable<T1, T2> : IMultiEnumerable<T1, T2>
    {
        private IEnumerable<T1> m_first;
        private IEnumerable<T2> m_second;

        public MultiEnumerable(IEnumerable<T1> first, IEnumerable<T2> second)
        {
            this.m_first = first;
            this.m_second = second;
        }

        public UInt32 NumberOfInputs
        {
            get { return 2; }
        }

        public IEnumerable<T1> First
        {
            get { return this.m_first; }
        }

        public IEnumerable<T2> Second
        {
            get { return this.m_second; }
        }
    }

    /// <summary>Triple of independently typed sequences.</summary>
    public class MultiEnumerable<T1, T2, T3> : IMultiEnumerable<T1, T2, T3>
    {
        private IEnumerable<T1> m_first;
        private IEnumerable<T2> m_second;
        private IEnumerable<T3> m_third;

        public MultiEnumerable(IEnumerable<T1> first, IEnumerable<T2> second, IEnumerable<T3> third)
        {
            this.m_first = first;
            this.m_second = second;
            this.m_third = third;
        }

        public UInt32 NumberOfInputs
        {
            get { return 3; }
        }

        public IEnumerable<T1> First
        {
            get { return this.m_first; }
        }

        public IEnumerable<T2> Second
        {
            get { return this.m_second; }
        }

        public IEnumerable<T3> Third
        {
            get { return this.m_third; }
        }
    }
}
//
// © Microsoft Corporation.  All rights reserved.
//
using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Globalization;
using System.Reflection;
using System.Linq.Expressions;
using System.Linq;
using System.Diagnostics;
using Microsoft.Research.DryadLinq.Internal;

// NOTE(review): generic type-argument lists were lost in the source text. The
// names T, K, R1, R2, R3 are taken from the bodies (typeof(T), K[] keys,
// typeof(R1) ...); which of them each interface/class declares is reconstructed
// from member usage - confirm against the declarations of ForkChoose and
// DryadLinqLocalQuery.
namespace Microsoft.Research.DryadLinq
{
    public interface IMultiQueryable
    {
        Type ElementType(int index);
        Expression Expression { get; }
        IQueryProvider Provider { get; }
        UInt32 NumberOfInputs { get; }
    }

    public interface IKeyedMultiQueryable<T, K> : IMultiQueryable
    {
        IQueryable<T> this[K key] { get; }
        K[] Keys { get; }
    }

    public interface IMultiQueryable<R1, R2> : IMultiQueryable
    {
        IQueryable<R1> First { get; }
        IQueryable<R2> Second { get; }
    }

    public interface IMultiQueryable<R1, R2, R3> : IMultiQueryable
    {
        IQueryable<R1> First { get; }
        IQueryable<R2> Second { get; }
        IQueryable<R3> Third { get; }
    }

    /// <summary>
    /// A set of same-typed queries addressed by key. It is backed either by
    /// already-materialized local sequences (<c>m_enumerables</c>) or by a
    /// query expression (<c>m_queryExpression</c>); exactly one of the two is
    /// set by the public constructors.
    /// </summary>
    public class MultiQueryable<T, K> : IKeyedMultiQueryable<T, K>
    {
        private IQueryable<T> m_source;
        private Expression m_queryExpression;
        private IMultiEnumerable<T> m_enumerables;
        private K[] m_keys;
        private Dictionary<K, int> m_keyMap;    // key -> position in m_keys

        public MultiQueryable(IQueryable<T> source, K[] keys, IMultiEnumerable<T> enumerables)
            : this(source, keys, null, enumerables)
        {
        }

        public MultiQueryable(IQueryable<T> source, K[] keys, Expression queryExpr)
            : this(source, keys, queryExpr, null)
        {
        }

        // Shared initialization for both public constructors: copies the keys
        // (so later changes to the caller's array are not observed) and builds
        // the key -> index map.
        private MultiQueryable(IQueryable<T> source,
                               K[] keys,
                               Expression queryExpr,
                               IMultiEnumerable<T> enumerables)
        {
            this.m_source = source;
            this.m_queryExpression = queryExpr;
            this.m_enumerables = enumerables;
            this.m_keys = new K[keys.Length];
            this.m_keyMap = new Dictionary<K, int>(keys.Length);
            for (int i = 0; i < keys.Length; i++)
            {
                this.m_keys[i] = keys[i];
                this.m_keyMap.Add(keys[i], i);
            }
        }

        public Type ElementType(int index)
        {
            return typeof(T);
        }

        public Expression Expression
        {
            get { return this.m_queryExpression; }
        }

        public IQueryProvider Provider
        {
            get { return this.m_source.Provider; }
        }

        public UInt32 NumberOfInputs
        {
            get
            {
                if (this.m_enumerables != null)
                {
                    return (UInt32)this.m_enumerables.NumberOfInputs;
                }
                return (UInt32)this.m_keys.Length;
            }
        }

        public K[] Keys
        {
            get { return this.m_keys; }
        }

        /// <exception cref="DryadLinqException">The key is not one of <see cref="Keys"/>.</exception>
        public IQueryable<T> this[K key]
        {
            get
            {
                int index;
                if (this.m_keyMap.TryGetValue(key, out index))
                {
                    if (this.m_enumerables != null)
                    {
                        var q = this.m_enumerables[index].AsQueryable();
                        return new DryadLinqLocalQuery<T>(this.Provider, q);
                    }
                    // NOTE(review): explicit type arguments may have been stripped here.
                    return this.ForkChoose(index);
                }

                //@@TODO: throw ArgumentOutOfRangeException?
                throw new DryadLinqException(HpcLinqErrorCode.MultiQueryableKeyOutOfRange,
                                             SR.MultiQueryableKeyOutOfRange);
            }
        }
    }

    /// <summary>Two differently typed queries derived from one source query.</summary>
    public class MultiQueryable<T, R1, R2> : IMultiQueryable<R1, R2>
    {
        private IQueryable<T> m_source;
        private Expression m_queryExpression;
        private IMultiEnumerable<R1, R2> m_enumerables;

        public MultiQueryable(IQueryable<T> source, IMultiEnumerable<R1, R2> enumerables)
        {
            this.m_source = source;
            this.m_queryExpression = null;
            this.m_enumerables = enumerables;
        }

        public MultiQueryable(IQueryable<T> source, Expression queryExpr)
        {
            this.m_source = source;
            this.m_queryExpression = queryExpr;
            this.m_enumerables = null;
        }

        /// <exception cref="DryadLinqException"><paramref name="index"/> is not 0 or 1.</exception>
        public Type ElementType(int index)
        {
            if (index == 0)
            {
                return typeof(R1);
            }
            else if (index == 1)
            {
                return typeof(R2);
            }
            else
            {
                //@@TODO: throw ArgumentOutOfRangeException?
                throw new DryadLinqException(HpcLinqErrorCode.IndexOutOfRange,
                                             SR.IndexOutOfRange);
            }
        }

        public Expression Expression
        {
            get { return this.m_queryExpression; }
        }

        public IQueryProvider Provider
        {
            get { return this.m_source.Provider; }
        }

        public UInt32 NumberOfInputs
        {
            get { return 2; }
        }

        public IQueryable<R1> First
        {
            get
            {
                if (this.m_enumerables != null)
                {
                    var q = this.m_enumerables.First.AsQueryable();
                    return new DryadLinqLocalQuery<R1>(this.Provider, q);
                }
                // NOTE(review): explicit type arguments may have been stripped here.
                return this.ForkChoose(0);
            }
        }

        public IQueryable<R2> Second
        {
            get
            {
                if (this.m_enumerables != null)
                {
                    var q = this.m_enumerables.Second.AsQueryable();
                    return new DryadLinqLocalQuery<R2>(this.Provider, q);
                }
                // NOTE(review): explicit type arguments may have been stripped here.
                return this.ForkChoose(1);
            }
        }
    }

    /// <summary>Three differently typed queries derived from one source query.</summary>
    public class MultiQueryable<T, R1, R2, R3> : IMultiQueryable<R1, R2, R3>
    {
        private IQueryable<T> m_source;
        private Expression m_queryExpression;
        private IMultiEnumerable<R1, R2, R3> m_enumerables;

        public MultiQueryable(IQueryable<T> source, Expression queryExpr)
        {
            this.m_source = source;
            this.m_queryExpression = queryExpr;
            this.m_enumerables = null;
        }

        public MultiQueryable(IQueryable<T> source, IMultiEnumerable<R1, R2, R3> enumerables)
        {
            this.m_source = source;
            this.m_queryExpression = null;
            this.m_enumerables = enumerables;
        }

        /// <exception cref="DryadLinqException"><paramref name="index"/> is not 0, 1 or 2.</exception>
        public Type ElementType(int index)
        {
            if (index == 0)
            {
                return typeof(R1);
            }
            else if (index == 1)
            {
                return typeof(R2);
            }
            else if (index == 2)
            {
                return typeof(R3);
            }
            else
            {
                //@@TODO: throw ArgumentOutOfRangeException?
                throw new DryadLinqException(HpcLinqErrorCode.IndexOutOfRange,
                                             SR.IndexOutOfRange);
            }
        }

        public Expression Expression
        {
            get { return this.m_queryExpression; }
        }

        public IQueryProvider Provider
        {
            get { return this.m_source.Provider; }
        }

        public UInt32 NumberOfInputs
        {
            get { return 3; }
        }

        public IQueryable<R1> First
        {
            get
            {
                if (this.m_enumerables != null)
                {
                    var q = this.m_enumerables.First.AsQueryable();
                    return new DryadLinqLocalQuery<R1>(this.Provider, q);
                }
                // NOTE(review): explicit type arguments may have been stripped here.
                return this.ForkChoose(0);
            }
        }

        public IQueryable<R2> Second
        {
            get
            {
                if (this.m_enumerables != null)
                {
                    var q = this.m_enumerables.Second.AsQueryable();
                    return new DryadLinqLocalQuery<R2>(this.Provider, q);
                }
                // NOTE(review): explicit type arguments may have been stripped here.
                return this.ForkChoose(1);
            }
        }

        public IQueryable<R3> Third
        {
            get
            {
                if (this.m_enumerables != null)
                {
                    var q = this.m_enumerables.Third.AsQueryable();
                    return new DryadLinqLocalQuery<R3>(this.Provider, q);
                }
                // NOTE(review): explicit type arguments may have been stripped here.
                return this.ForkChoose(2);
            }
        }
    }
}
//
// © Microsoft Corporation.  All rights reserved.
//
using System;
using System.Collections;
using System.Collections.Generic;
using System.Security;
using System.Threading;
using System.Runtime.InteropServices;
using Microsoft.Win32.SafeHandles;
using System.Diagnostics;
using System.IO;
using Microsoft.Research.DryadLinq;
using Microsoft.Research.DryadLinq.Internal;

namespace Microsoft.Research.DryadLinq.Internal
{
    /// <summary>
    /// Describes one native data block: pointer to the bytes, their count, and
    /// the handle used to write or release the block.
    /// </summary>
    internal unsafe struct DataBlockInfo
    {
        internal byte* dataBlock;
        internal Int32 blockSize;
        internal IntPtr itemHandle;
    }

    // this type is public on the outside but all its members are marked internal
    // because generated vertex code needs to pass around references to it but
    // doesn't call any methods, nor should client code.
    public abstract class NativeBlockStream
    {
        internal abstract Int64 GetTotalLength();

        internal abstract unsafe DataBlockInfo ReadDataBlock();

        internal abstract unsafe bool WriteDataBlock(IntPtr itemHandle, Int32 numBytesToWrite);

        internal abstract unsafe DataBlockInfo AllocateDataBlock(Int32 size);

        internal abstract unsafe void ReleaseDataBlock(IntPtr itemHandle);

        internal abstract void Flush();

        internal abstract void Close();

        // The three members below are optional capabilities; the base
        // implementations report "not supported".

        internal virtual string GetURI()
        {
            throw new DryadLinqException(HpcLinqErrorCode.GetURINotSupported,
                                         SR.GetURINotSupported);
        }

        internal virtual void SetCalcFP()
        {
            throw new DryadLinqException(HpcLinqErrorCode.SetCalcFPNotSupported,
                                         SR.SetCalcFPNotSupported);
        }

        internal virtual UInt64 GetFingerPrint()
        {
            throw new DryadLinqException(HpcLinqErrorCode.GetFPNotSupported,
                                         SR.GetFPNotSupported);
        }
    }

    /// <summary>
    /// Block stream over one port of a native vertex, identified by the vertex
    /// handle and the port number; all work is forwarded to HpcLinqNative.
    /// </summary>
    internal sealed class HpcLinqChannel : NativeBlockStream
    {
        private IntPtr m_vertexInfo;
        private UInt32 m_portNum;
        private bool m_isInput;
        private bool m_isClosed;

        internal HpcLinqChannel(IntPtr vertexInfo, UInt32 portNum, bool isInput)
        {
            this.m_vertexInfo = vertexInfo;
            this.m_portNum = portNum;
            this.m_isInput = isInput;
            this.m_isClosed = false;
        }

        // Safety net: closes the channel if the owner never did.
        ~HpcLinqChannel()
        {
            this.Close();
        }

        internal IntPtr NativeHandle
        {
            get { return this.m_vertexInfo; }
        }

        internal UInt32 PortNumber
        {
            get { return this.m_portNum; }
        }

        /// <summary>Expected length of an input channel.</summary>
        /// <exception cref="NotImplementedException">The channel is an output channel.</exception>
        internal override unsafe Int64 GetTotalLength()
        {
            if (this.m_isInput)
            {
                return HpcLinqNative.GetExpectedLength(this.m_vertexInfo, this.m_portNum);
            }
            else
            {
                throw new NotImplementedException();
            }
        }

        /// <exception cref="DryadLinqException">The native allocator returned a null handle.</exception>
        internal override unsafe DataBlockInfo AllocateDataBlock(Int32 size)
        {
            DataBlockInfo blockInfo;
            blockInfo.itemHandle =
                HpcLinqNative.AllocateDataBlock(this.m_vertexInfo, size, &blockInfo.dataBlock);
            blockInfo.blockSize = size;
            if (blockInfo.itemHandle == IntPtr.Zero)
            {
                throw new DryadLinqException(HpcLinqErrorCode.FailedToAllocateNewNativeBuffer,
                                             String.Format(SR.FailedToAllocateNewNativeBuffer, size));
            }
            // DryadLinqLog.Add("Allocated data block {0} of {1} bytes.", blockInfo.itemHandle, size);
            return blockInfo;
        }

        internal override unsafe void ReleaseDataBlock(IntPtr itemHandle)
        {
            // a zero handle denotes a block that is not owned by the native side
            if (itemHandle != IntPtr.Zero)
            {
                HpcLinqNative.ReleaseDataBlock(this.m_vertexInfo, itemHandle);
            }
            // DryadLinqLog.Add("Released data block {0}.", itemHandle);
        }

        /// <exception cref="DryadLinqException">
        /// The native read reported a non-zero error code; the code is also
        /// stored in HpcLinqVertexEnv.ErrorCode.
        /// </exception>
        internal override unsafe DataBlockInfo ReadDataBlock()
        {
            DataBlockInfo blockInfo;
            Int32 errorCode = 0;
            blockInfo.itemHandle = HpcLinqNative.ReadDataBlock(this.m_vertexInfo,
                                                               this.m_portNum,
                                                               &blockInfo.dataBlock,
                                                               &blockInfo.blockSize,
                                                               &errorCode);
            if (errorCode != 0)
            {
                HpcLinqVertexEnv.ErrorCode = errorCode;
                throw new DryadLinqException(HpcLinqErrorCode.FailedToReadFromInputChannel,
                                             String.Format(SR.FailedToReadFromInputChannel,
                                                           this.m_portNum, errorCode));
            }
            return blockInfo;
        }

        /// <summary>
        /// Writes the block; a request for zero or fewer bytes is a no-op that
        /// reports success.
        /// </summary>
        /// <exception cref="DryadLinqException">The native write reported failure.</exception>
        internal override unsafe bool WriteDataBlock(IntPtr itemHandle, Int32 numBytesToWrite)
        {
            bool success = true;
            if (numBytesToWrite > 0)
            {
                success = HpcLinqNative.WriteDataBlock(this.m_vertexInfo,
                                                       this.m_portNum,
                                                       itemHandle,
                                                       numBytesToWrite);

                if (!success)
                {
                    throw new DryadLinqException(HpcLinqErrorCode.FailedToWriteToOutputChannel,
                                                 String.Format(SR.FailedToWriteToOutputChannel,
                                                               this.m_portNum));
                }
            }
            return success;
        }

        internal override void SetCalcFP()
        {
            throw new DryadLinqException(HpcLinqErrorCode.SetCalcFPNotSupported,
                                         SR.SetCalcFPNotSupported);
        }

        internal override UInt64 GetFingerPrint()
        {
            throw new DryadLinqException(HpcLinqErrorCode.GetFPNotSupported,
                                         SR.GetFPNotSupported);
        }

        internal override unsafe string GetURI()
        {
            IntPtr uriPtr;
            if (this.m_isInput)
            {
                uriPtr = HpcLinqNative.GetInputChannelURI(this.m_vertexInfo, this.m_portNum);
            }
            else
            {
                uriPtr = HpcLinqNative.GetOutputChannelURI(this.m_vertexInfo, this.m_portNum);
            }
            return Marshal.PtrToStringAnsi(uriPtr);
        }

        internal override void Flush()
        {
            HpcLinqNative.Flush(this.m_vertexInfo, this.m_portNum);
        }

        /// <summary>
        /// Flushes and closes the native channel. Only the first call does any
        /// work; every call suppresses finalization.
        /// </summary>
        internal override void Close()
        {
            if (!this.m_isClosed)
            {
                this.m_isClosed = true;
                try
                {
                    this.Flush();
                }
                finally
                {
                    // m_isClosed is already set, so a failing Flush would otherwise
                    // leave the native channel open with no way to close it later.
                    HpcLinqNative.Close(this.m_vertexInfo, this.m_portNum);
                }
                string ctype = (this.m_isInput) ? "Input" : "Output";
                DryadLinqLog.Add(ctype + " channel {0} was closed.", this.m_portNum);
            }
            GC.SuppressFinalize(this);
        }

        public override string ToString()
        {
            return "DryadChannel[" + PortNumber + "]";
        }
    }
}
+
+*/
+
+using System;
+using Microsoft.Hpc.Dryad;
+
+namespace Microsoft.Research.DryadLinq
+{
+    /// <summary>
+    /// Trace levels for HpcQuery runtime
+    /// </summary>
+    public enum HpcQueryTraceLevel : int
+    {
+        // Use internal constants since a public type cannot be easily shared across DLLs without resulting in ambiguities
+        Off = Constants.traceOffNum,
+        Critical = Constants.traceCriticalNum,
+        Error = Constants.traceErrorNum,
+        Warning = Constants.traceWarningNum,
+        Information = Constants.traceInfoNum,
+        Verbose = Constants.traceVerboseNum
+    }
+}
diff --git a/LinqToDryad/SR.Designer.cs b/LinqToDryad/SR.Designer.cs
new file mode 100644
index 0000000..8607389
--- /dev/null
+++ b/LinqToDryad/SR.Designer.cs
@@ -0,0 +1,1953 @@
+//------------------------------------------------------------------------------
+// <auto-generated>
+// This code was generated by a tool.
+// Runtime Version:4.0.30319.18033
+//
+// Changes to this file may cause incorrect behavior and will be lost if
+// the code is regenerated.
+// </auto-generated>
+//------------------------------------------------------------------------------
+
+namespace Microsoft.Research.DryadLinq {
+    using System;
+
+
+    /// <summary>
+    /// A strongly-typed resource class, for looking up localized strings, etc.
+    /// </summary>
+    // This class was auto-generated by the StronglyTypedResourceBuilder
+    // class via a tool like ResGen or Visual Studio.
+    // To add or remove a member, edit your .ResX file then rerun ResGen
+    // with the /str option, or rebuild your VS project.
+    [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "4.0.0.0")]
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute()]
+    [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()]
+    internal class SR {
+
+        private static global::System.Resources.ResourceManager resourceMan;
+
+        private static global::System.Globalization.CultureInfo resourceCulture;
+
+        [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")]
+        internal SR() {
+        }
+
+        /// <summary>
+        /// Returns the cached ResourceManager instance used by this class.
+        /// </summary>
+        [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)]
+        internal static global::System.Resources.ResourceManager ResourceManager {
+            get {
+                if (object.ReferenceEquals(resourceMan, null)) {
+                    global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Microsoft.Research.DryadLinq.SR", typeof(SR).Assembly);
+                    resourceMan = temp;
+                }
+                return resourceMan;
+            }
+        }
+
+        /// <summary>
+        /// Overrides the current thread's CurrentUICulture property for all
+        /// resource lookups using this strongly typed resource class.
+        /// </summary>
+        [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)]
+        internal static global::System.Globalization.CultureInfo Culture {
+            get {
+                return resourceCulture;
+            }
+            set {
+                resourceCulture = value;
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Internal error: AddVertexMethod on {0} not handled..
+        /// </summary>
+        internal static string AddVertexNotHandled {
+            get {
+                return ResourceManager.GetString("AddVertexNotHandled", resourceCulture);
+            }
+        }
+
+        ///
+        /// Looks up a localized string similar to Aggregate: No elements..
+        ///
+        internal static string AggregateNoElements {
+            get {
+                return ResourceManager.GetString("AggregateNoElements", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to The aggregate operator {0} is not supported..
+        /// </summary>
+        internal static string AggregateOperatorNotSupported {
+            get {
+                return ResourceManager.GetString("AggregateOperatorNotSupported", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Aggregation operator '{0}' can only work on objects that implement IComparable..
+        /// </summary>
+        internal static string AggregationOperatorRequiresIComparable {
+            get {
+                return ResourceManager.GetString("AggregationOperatorRequiresIComparable", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to This query instance has already been submitted. To submit a query more than once, create a new IQueryable&lt;&gt; instance..
+        /// </summary>
+        internal static string AlreadySubmitted {
+            get {
+                return ResourceManager.GetString("AlreadySubmitted", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to A query instance has already been submitted. To submit a query more than once, create a new IQueryable&lt;&gt; instance..
+        /// </summary>
+        internal static string AlreadySubmittedInMaterialize {
+            get {
+                return ResourceManager.GetString("AlreadySubmittedInMaterialize", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Length of array {0} must be greater than or equal to {1} ({2} + {3})..
+        /// </summary>
+        internal static string ArrayLengthVsCountAndOffset {
+            get {
+                return ResourceManager.GetString("ArrayLengthVsCountAndOffset", resourceCulture);
+            }
+        }
+
+        ///
+        /// Looks up a localized string similar to A method tagged [Associative] should take two parameters of type T and return type T. Method={0}..
+        ///
+        internal static string AssociativeMethodHasWrongForm {
+            get {
+                return ResourceManager.GetString("AssociativeMethodHasWrongForm", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Associative class must have a public parameterless constructor. Class={0}..
+        /// </summary>
+        internal static string AssociativeTypeDoesNotHavePublicDefaultCtor {
+            get {
+                return ResourceManager.GetString("AssociativeTypeDoesNotHavePublicDefaultCtor", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Associative class must implement IDecomposable&lt;,,&gt; or IDecomposableRecursive&lt;,,&gt;. Class={0}..
+        /// </summary>
+        internal static string AssociativeTypeDoesNotImplementInterface {
+            get {
+                return ResourceManager.GetString("AssociativeTypeDoesNotImplementInterface", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Associative class should implement only one IAssociativeRecursive interface. Class={0}..
+        /// </summary>
+        internal static string AssociativeTypeImplementsTooManyInterfaces {
+            get {
+                return ResourceManager.GetString("AssociativeTypeImplementsTooManyInterfaces", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Associative class must be public. Class={0}..
+        /// </summary>
+        internal static string AssociativeTypeMustBePublic {
+            get {
+                return ResourceManager.GetString("AssociativeTypeMustBePublic", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Associative class types must match the function that it decorates. Method={0}. Class={1}..
+        /// </summary>
+        internal static string AssociativeTypesDoNotMatch {
+            get {
+                return ResourceManager.GetString("AssociativeTypesDoNotMatch", resourceCulture);
+            }
+        }
+
+        ///
+        /// Looks up a localized string similar to A LINQ to HPC query that is submitted should involve at least one operator..
+        ///
+        internal static string AtLeastOneOperatorRequired {
+            get {
+                return ResourceManager.GetString("AtLeastOneOperatorRequired", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Attempt to read from a stream that was opened for writing..
+        /// </summary>
+        internal static string AttemptToReadFromAWriteStream {
+            get {
+                return ResourceManager.GetString("AttemptToReadFromAWriteStream", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Internal error: The auto-generated LINQ to HPC vertex assembly was missing..
+        /// </summary>
+        internal static string AutogeneratedAssemblyMissing {
+            get {
+                return ResourceManager.GetString("AutogeneratedAssemblyMissing", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Average: No elements.
+        /// </summary>
+        internal static string AverageNoElements {
+            get {
+                return ResourceManager.GetString("AverageNoElements", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to The number of separators passed to AssumeRangePartition must equal src.nPartitions - 1. nRangeSeparators={0}, Expected={1}..
+        /// </summary>
+        internal static string BadSeparatorCount {
+            get {
+                return ResourceManager.GetString("BadSeparatorCount", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Job cannot be submitted because either the client program or one of its dependencies is targetted for 32 bit execution ({0}). To correct this problem make sure your binaries are compiled as 'x64' or 'AnyCPU'..
+        /// </summary>
+        internal static string Binaries32BitNotSupported {
+            get {
+                return ResourceManager.GetString("Binaries32BitNotSupported", resourceCulture);
+            }
+        }
+
+        ///
+        /// Looks up a localized string similar to The branch {0} of Fork is not used..
+        ///
+        internal static string BranchOfForkNotUsed {
+            get {
+                return ResourceManager.GetString("BranchOfForkNotUsed", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Internal error processing anonymous type in query..
+        /// </summary>
+        internal static string BugInHandlingAnonymousClass {
+            get {
+                return ResourceManager.GetString("BugInHandlingAnonymousClass", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Failed to access path {0}.
+        /// </summary>
+        internal static string CannotAccesFilePath {
+            get {
+                return ResourceManager.GetString("CannotAccesFilePath", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Internal error: Can't add Tee to a node with more than one outputs..
+        /// </summary>
+        internal static string CannotAddTeeToNode {
+            get {
+                return ResourceManager.GetString("CannotAddTeeToNode", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Internal error: cannot attach a pipeline.
+        /// </summary>
+        internal static string CannotAttach {
+            get {
+                return ResourceManager.GetString("CannotAttach", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Internal error: cannot be empty.
+        /// </summary>
+        internal static string CannotBeEmpty {
+            get {
+                return ResourceManager.GetString("CannotBeEmpty", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Internal error: Can't be used for reference type..
+        /// </summary>
+        internal static string CannotBeUsedForReferenceType {
+            get {
+                return ResourceManager.GetString("CannotBeUsedForReferenceType", resourceCulture);
+            }
+        }
+
+        ///
+        /// Looks up a localized string similar to Internal error: Can't call on PartitionInfo of type {0}.
+        ///
+        internal static string CannotCallPartitionInfoOnType {
+            get {
+                return ResourceManager.GetString("CannotCallPartitionInfoOnType", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Can't concat two datasets with different compression schemes..
+        /// </summary>
+        internal static string CannotConcatDatasetsWithDifferentCompression {
+            get {
+                return ResourceManager.GetString("CannotConcatDatasetsWithDifferentCompression", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Internal error: Can't create partition node based on random partition..
+        /// </summary>
+        internal static string CannotCreatePartitionNodeRandom {
+            get {
+                return ResourceManager.GetString("CannotCreatePartitionNodeRandom", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Can't create multiple tables with different compression schemes..
+        /// </summary>
+        internal static string CannotCreateTablesWithDifferentCompression {
+            get {
+                return ResourceManager.GetString("CannotCreateTablesWithDifferentCompression", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Auto-serialization is not supported for types containing circular references. Type = {0}..
+        /// </summary>
+        internal static string CannotHandleCircularTypes {
+            get {
+                return ResourceManager.GetString("CannotHandleCircularTypes", resourceCulture);
+            }
+        }
+
+        ///
+        /// Looks up a localized string similar to Auto-serialization is not supported for type {0} because it derives from the non-primitive type {1}. Consider using a custom serializer for {0}. Please note this auto-serialization rule may have exceptions for built-in types. Please see product documentation for details..
+        ///
+        internal static string CannotHandleDerivedtypes {
+            get {
+                return ResourceManager.GetString("CannotHandleDerivedtypes", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Auto-serialization is not supported for types containing fields of System.Object, System.Object[] or other collections of System.Object. Type = {0}..
+        /// </summary>
+        internal static string CannotHandleObjectFields {
+            get {
+                return ResourceManager.GetString("CannotHandleObjectFields", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Auto-serialization is not supported for type that has subtypes or derives from a non-primitive type. Type = {0}..
+        /// </summary>
+        internal static string CannotHandleSubtypes {
+            get {
+                return ResourceManager.GetString("CannotHandleSubtypes", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Internal error: Can't have more than one output channel..
+        /// </summary>
+        internal static string CannotHaveMoreThanOneOutput {
+            get {
+                return ResourceManager.GetString("CannotHaveMoreThanOneOutput", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Cannot read query plan for job: {0}..
+        /// </summary>
+        internal static string CannotReadQueryPlan {
+            get {
+                return ResourceManager.GetString("CannotReadQueryPlan", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Internal error: cannot rebuild optimized query expression for node..
+        /// </summary>
+        internal static string CannotRebuildOptimizedQueryExpression {
+            get {
+                return ResourceManager.GetString("CannotRebuildOptimizedQueryExpression", resourceCulture);
+            }
+        }
+
+        ///
+        /// Looks up a localized string similar to Internal error: Cannot reset this IEnumerator..
+        ///
+        internal static string CannotResetIEnumerator {
+            get {
+                return ResourceManager.GetString("CannotResetIEnumerator", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to HpcLinqQuery IQueryable objects cannot be added to object store..
+        /// </summary>
+        internal static string CannotSerializeHpcLinqQuery {
+            get {
+                return ResourceManager.GetString("CannotSerializeHpcLinqQuery", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Cannot serialize object store due to non-serializable object. Type = {0}..
+        /// </summary>
+        internal static string CannotSerializeObject {
+            get {
+                return ResourceManager.GetString("CannotSerializeObject", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to A LINQ to HPC channel cannot be read more than once. For example, the delegate in Apply() may only enumerate its input once..
+        /// </summary>
+        internal static string ChannelCannotBeReadMoreThanOnce {
+            get {
+                return ResourceManager.GetString("ChannelCannotBeReadMoreThanOnce", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to The ClusterName configuration must be set to the name of a Windows HPC Server head node..
+        /// </summary>
+        internal static string ClusterNameMustBeSpecified {
+            get {
+                return ResourceManager.GetString("ClusterNameMustBeSpecified", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Argument is not a LineRecord..
+        /// </summary>
+        internal static string CompareArgIncorrect {
+            get {
+                return ResourceManager.GetString("CompareArgIncorrect", resourceCulture);
+            }
+        }
+
+        ///
+        /// Looks up a localized string similar to If a comparer expression is not provided, TElement must implement IEquatable or override both Equals() and GetHashCode(). TElement={0}..
+        ///
+        internal static string ComparerExpressionMustBeSpecifiedOrElementTypeMustBeIEquatable {
+            get {
+                return ResourceManager.GetString("ComparerExpressionMustBeSpecifiedOrElementTypeMustBeIEquatable", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to If a key-comparer is not provided, TKey must implement IComparable. TKey={0}..
+        /// </summary>
+        internal static string ComparerMustBeSpecifiedOrKeyTypeMustBeIComparable {
+            get {
+                return ResourceManager.GetString("ComparerMustBeSpecifiedOrKeyTypeMustBeIComparable", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to If a key-comparer is not provided, TKey must override GetHashCode() and either implement IEquatable or override Equals(). TKey={0}..
+        /// </summary>
+        internal static string ComparerMustBeSpecifiedOrKeyTypeMustBeIEquatable {
+            get {
+                return ResourceManager.GetString("ComparerMustBeSpecifiedOrKeyTypeMustBeIEquatable", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to The configuration object is read-only..
+        /// </summary>
+        internal static string ConfigReadonly {
+            get {
+                return ResourceManager.GetString("ConfigReadonly", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to The HpcLinqContext has been disposed..
+        /// </summary>
+        internal static string ContextDisposed {
+            get {
+                return ResourceManager.GetString("ContextDisposed", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Could not get version information for HpcLinq client component. See inner exception..
+        /// </summary>
+        internal static string CouldNotGetClientVersion {
+            get {
+                return ResourceManager.GetString("CouldNotGetClientVersion", resourceCulture);
+            }
+        }
+
+        ///
+        /// Looks up a localized string similar to Could not get version information for HpcLinq server component. See inner exception..
+        ///
+        internal static string CouldNotGetServerVersion {
+            get {
+                return ResourceManager.GetString("CouldNotGetServerVersion", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Error creating DSC data from local debug mode. See inner exception for details..
+        /// </summary>
+        internal static string CreatingDscDataFromLocalDebugFailed {
+            get {
+                return ResourceManager.GetString("CreatingDscDataFromLocalDebugFailed", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Custom serializer type must either be a class or a struct that implements IHpcSerializer({1}). Type = {0}.
+        /// </summary>
+        internal static string CustomSerializerMustBeClassOrStruct {
+            get {
+                return ResourceManager.GetString("CustomSerializerMustBeClassOrStruct", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Custom serializer type must have a public default constructor. Type = {0}..
+        /// </summary>
+        internal static string CustomSerializerMustSupportDefaultCtor {
+            get {
+                return ResourceManager.GetString("CustomSerializerMustSupportDefaultCtor", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Decomposition class must have a public parameterless constructor. Class={0}..
+        /// </summary>
+        internal static string DecomposerTypeDoesNotHavePublicDefaultCtor {
+            get {
+                return ResourceManager.GetString("DecomposerTypeDoesNotHavePublicDefaultCtor", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Decomposition class must implement IDecomposable&lt;,,&gt; or IDecomposableRecursive&lt;,,&gt;. Class={0}..
+        /// </summary>
+        internal static string DecomposerTypeDoesNotImplementInterface {
+            get {
+                return ResourceManager.GetString("DecomposerTypeDoesNotImplementInterface", resourceCulture);
+            }
+        }
+
+        ///
+        /// Looks up a localized string similar to Decomposition class should implement only one decomposable interface. Class={0}..
+        ///
+        internal static string DecomposerTypeImplementsTooManyInterfaces {
+            get {
+                return ResourceManager.GetString("DecomposerTypeImplementsTooManyInterfaces", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Decomposition class must be public. Class={0}..
+        /// </summary>
+        internal static string DecomposerTypeMustBePublic {
+            get {
+                return ResourceManager.GetString("DecomposerTypeMustBePublic", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Decomposition class types must match the function that it decorates. Method={0}. Class={1}..
+        /// </summary>
+        internal static string DecomposerTypesDoNotMatch {
+            get {
+                return ResourceManager.GetString("DecomposerTypesDoNotMatch", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to The job did not complete successfully. Refer to HPC Cluster Manager and/or HPC APIs for more detail..
+        /// </summary>
+        internal static string DidNotCompleteSuccessfully {
+            get {
+                return ResourceManager.GetString("DidNotCompleteSuccessfully", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to DistinctAttribute: Comparer {0} is not defined..
+        /// </summary>
+        internal static string DistinctAttributeComparerNotDefined {
+            get {
+                return ResourceManager.GetString("DistinctAttributeComparerNotDefined", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Internal Error: The Distinct operator can only take at most 2 arguments..
+        /// </summary>
+        internal static string DistinctOnlyTakesTwoArgs {
+            get {
+                return ResourceManager.GetString("DistinctOnlyTakesTwoArgs", resourceCulture);
+            }
+        }
+
+        ///
+        /// Looks up a localized string similar to DryadLINQ requires the DRYAD_HOME environment variable to be set to the Dryad binary folder..
+        ///
+        internal static string DryadHomeMustBeSpecified {
+            get {
+                return ResourceManager.GetString("DryadHomeMustBeSpecified", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to DSC fileset error: {0}..
+        /// </summary>
+        internal static string DSCStreamError {
+            get {
+                return ResourceManager.GetString("DSCStreamError", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Internal error: Cannot have dynamic manager of type {0}..
+        /// </summary>
+        internal static string DynamicManagerType {
+            get {
+                return ResourceManager.GetString("DynamicManagerType", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Read failure: End of stream encountered while reading {0}. Data may be corrupt or does not match the file set compression scheme..
+        /// </summary>
+        internal static string EndOfStreamEncountered {
+            get {
+                return ResourceManager.GetString("EndOfStreamEncountered", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Error reading metadata..
+        /// </summary>
+        internal static string ErrorReadingMetadata {
+            get {
+                return ResourceManager.GetString("ErrorReadingMetadata", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Error writing metadata..
+        /// </summary>
+        internal static string ErrorWritingMetadata {
+            get {
+                return ResourceManager.GetString("ErrorWritingMetadata", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to The expression must be a method call: [{0}]..
+        /// </summary>
+        internal static string ExpressionMustBeMethodCall {
+            get {
+                return ResourceManager.GetString("ExpressionMustBeMethodCall", resourceCulture);
+            }
+        }
+
+        ///
+        /// Looks up a localized string similar to {0} cannot handle expression of type {1}..
+        ///
+        internal static string ExpressionTypeNotHandled {
+            get {
+                return ResourceManager.GetString("ExpressionTypeNotHandled", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Failed to allocate a new native data block of size {0}..
+        /// </summary>
+        internal static string FailedToAllocateNewNativeBuffer {
+            get {
+                return ResourceManager.GetString("FailedToAllocateNewNativeBuffer", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Failed to build {0}. See the client side log ({1}) for compilation error messages..
+        /// </summary>
+        internal static string FailedToBuild {
+            get {
+                return ResourceManager.GetString("FailedToBuild", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Failed to create DSC fileset: {0}..
+        /// </summary>
+        internal static string FailedToCreateStream {
+            get {
+                return ResourceManager.GetString("FailedToCreateStream", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Failed to deserialize object from object store..
+        /// </summary>
+        internal static string FailedToDeserialize {
+            get {
+                return ResourceManager.GetString("FailedToDeserialize", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Failed to get read paths for DSC fileset {0}..
+        /// </summary>
+        internal static string FailedToGetReadPathsForStream {
+            get {
+                return ResourceManager.GetString("FailedToGetReadPathsForStream", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Failed to get stream properties for file set {0}..
+        /// </summary>
+        internal static string FailedToGetStreamProps {
+            get {
+                return ResourceManager.GetString("FailedToGetStreamProps", resourceCulture);
+            }
+        }
+
+        ///
+        /// Looks up a localized string similar to Native channel failed to read from input channel at port {0}. Win32 error = {1}..
+        ///
+        internal static string FailedToReadFromInputChannel {
+            get {
+                return ResourceManager.GetString("FailedToReadFromInputChannel", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Internal error: Failed to remove the Merge node..
+        /// </summary>
+        internal static string FailedToRemoveMergeNode {
+            get {
+                return ResourceManager.GetString("FailedToRemoveMergeNode", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Native channel failed to write to output channel at port {0}..
+        /// </summary>
+        internal static string FailedToWriteToOutputChannel {
+            get {
+                return ResourceManager.GetString("FailedToWriteToOutputChannel", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Failure in Distinct..
+        /// </summary>
+        internal static string FailureInDistinct {
+            get {
+                return ResourceManager.GetString("FailureInDistinct", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Failure during Except..
+        /// </summary>
+        internal static string FailureInExcept {
+            get {
+                return ResourceManager.GetString("FailureInExcept", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Failure in hash based GroupBy..
+        /// </summary>
+        internal static string FailureInHashGroupBy {
+            get {
+                return ResourceManager.GetString("FailureInHashGroupBy", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Failure in hash based GroupJoin..
+        /// </summary>
+        internal static string FailureInHashGroupJoin {
+            get {
+                return ResourceManager.GetString("FailureInHashGroupJoin", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Failure in hash based Join..
+        /// </summary>
+        internal static string FailureInHashJoin {
+            get {
+                return ResourceManager.GetString("FailureInHashJoin", resourceCulture);
+            }
+        }
+
+        ///
+        /// Looks up a localized string similar to Failure during Intersect..
+        ///
+        internal static string FailureInIntersect {
+            get {
+                return ResourceManager.GetString("FailureInIntersect", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Failure in {0}..
+        /// </summary>
+        internal static string FailureInOperator {
+            get {
+                return ResourceManager.GetString("FailureInOperator", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Failure in ordered GroupBy..
+        /// </summary>
+        internal static string FailureInOrderedGroupBy {
+            get {
+                return ResourceManager.GetString("FailureInOrderedGroupBy", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Failure during sorting..
+        /// </summary>
+        internal static string FailureInSort {
+            get {
+                return ResourceManager.GetString("FailureInSort", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Failure in sort based GroupBy..
+        /// </summary>
+        internal static string FailureInSortGroupBy {
+            get {
+                return ResourceManager.GetString("FailureInSortGroupBy", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Apply: Failure in user-defined function..
+        /// </summary>
+        internal static string FailureInUserApplyFunction {
+            get {
+                return ResourceManager.GetString("FailureInUserApplyFunction", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Could not open FileSet and/or read its properties..
+        /// </summary>
+        internal static string FileSetCouldNotBeOpened {
+            get {
+                return ResourceManager.GetString("FileSetCouldNotBeOpened", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to FileSet must be sealed before calling FromDsc()..
+        /// </summary>
+        internal static string FileSetMustBeSealed {
+            get {
+                return ResourceManager.GetString("FileSetMustBeSealed", resourceCulture);
+            }
+        }
+
+        ///
+        /// Looks up a localized string similar to FileSet must have at least one file..
+        ///
+        internal static string FileSetMustHaveAtLeastOneFile {
+            get {
+                return ResourceManager.GetString("FileSetMustHaveAtLeastOneFile", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Fingerprint was disabled..
+        /// </summary>
+        internal static string FingerprintDisabled {
+            get {
+                return ResourceManager.GetString("FingerprintDisabled", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to First: No elements..
+        /// </summary>
+        internal static string FirstNoElementsFirst {
+            get {
+                return ResourceManager.GetString("FirstNoElementsFirst", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Error serializing object store. See inner exception..
+        /// </summary>
+        internal static string GeneralSerializeFailure {
+            get {
+                return ResourceManager.GetString("GeneralSerializeFailure", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to Error returned from GetFileSizeEx: {0}..
+        /// </summary>
+        internal static string GetFileSizeError {
+            get {
+                return ResourceManager.GetString("GetFileSizeError", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to GetFP() is not implemented for this stream type..
+        /// </summary>
+        internal static string GetFPNotSupported {
+            get {
+                return ResourceManager.GetString("GetFPNotSupported", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to GetURI() is not implemented for this stream type..
+        /// </summary>
+        internal static string GetURINotSupported {
+            get {
+                return ResourceManager.GetString("GetURINotSupported", resourceCulture);
+            }
+        }
+
+        /// <summary>
+        /// Looks up a localized string similar to All inputs to homomorphic apply must have same number of input partitions.
+        /// </summary>
+        internal static string HomomorphicApplyNeedsSamePartitionCount {
+            get {
+                return ResourceManager.GetString("HomomorphicApplyNeedsSamePartitionCount", resourceCulture);
+            }
+        }
+
+        ///
+        /// Looks up a localized string similar to JobMinNodes must be greater than 1..
+ /// + internal static string HpcLinqJobMinMustBe2OrMore { + get { + return ResourceManager.GetString("HpcLinqJobMinMustBe2OrMore", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The collection is read-only.. + /// + internal static string HpcLinqStringDictionaryReadonly { + get { + return ResourceManager.GetString("HpcLinqStringDictionaryReadonly", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Internal error: Illegal type of dynamic manager in dynamic node. + /// + internal static string IllegalDynamicManagerType { + get { + return ResourceManager.GetString("IllegalDynamicManagerType", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The URI is not well formed: {0}. + /// + internal static string IllFormedUri { + get { + return ResourceManager.GetString("IllFormedUri", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The arguments in the table URI are not well formed.. + /// + internal static string IllFormedUriArguments { + get { + return ResourceManager.GetString("IllFormedUriArguments", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Index out of range.. + /// + internal static string IndexOutOfRange { + get { + return ResourceManager.GetString("IndexOutOfRange", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Index overflowed range of Int32.. + /// + internal static string IndexTooSmall { + get { + return ResourceManager.GetString("IndexTooSmall", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Internal error: InputArity must equal Children.Length.. + /// + internal static string InputArityMustEqualChildren { + get { + return ResourceManager.GetString("InputArityMustEqualChildren", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The input expression must be a LINQ to HPC source.. 
+ /// + internal static string InputMustBeHpcLinqSource { + get { + return ResourceManager.GetString("InputMustBeHpcLinqSource", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Input data type cannot be an anonymous type.. + /// + internal static string InputTypeCannotBeAnonymous { + get { + return ResourceManager.GetString("InputTypeCannotBeAnonymous", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Internal error: Can't be used for value type.. + /// + internal static string Internal_CannotBeUsedForValueType { + get { + return ResourceManager.GetString("Internal_CannotBeUsedForValueType", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The arguments 'partitionKeys' and 'isDescending' are inconsistent.. + /// + internal static string IsDescendingIsInconsistent { + get { + return ResourceManager.GetString("IsDescendingIsInconsistent", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to JobOption({0}, {1}) not implemented.. + /// + internal static string JobOptionNotImplemented { + get { + return ResourceManager.GetString("JobOptionNotImplemented", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Repeated server error when querying job status.. + /// + internal static string JobStatusQueryError { + get { + return ResourceManager.GetString("JobStatusQueryError", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The job to create this HpcLinqQuery(T) failed with error: {0}.. + /// + internal static string JobToCreateTableFailed { + get { + return ResourceManager.GetString("JobToCreateTableFailed", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The job to create this HpcLinqQuery(T) was canceled by the user.. 
+ /// + internal static string JobToCreateTableWasCanceled { + get { + return ResourceManager.GetString("JobToCreateTableWasCanceled", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Key not found in collection.. + /// + internal static string KeyNotFound { + get { + return ResourceManager.GetString("KeyNotFound", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Last: No elements.. + /// + internal static string LastNoElements { + get { + return ResourceManager.GetString("LastNoElements", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Max: No elements. + /// + internal static string MaxNoElements { + get { + return ResourceManager.GetString("MaxNoElements", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to RecordType does not match with file set metadata. Use FromDsc<T> with matching T, or an overload that supresses type-check. RecordType={0}, DscStream.recordType={1}.. + /// + internal static string MetadataRecordType { + get { + return ResourceManager.GetString("MetadataRecordType", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Min: No elements. + /// + internal static string MinNoElements { + get { + return ResourceManager.GetString("MinNoElements", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Failed to access file {0}, which is part of DSC stream {1}. If this file has other replicas, they were all attempted but could not be accessed either.. + /// + internal static string MultiBlockCannotAccesFilePath { + get { + return ResourceManager.GetString("MultiBlockCannotAccesFilePath", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The partition list of the table was empty.. 
+ /// + internal static string MultiBlockEmptyPartitionList { + get { + return ResourceManager.GetString("MultiBlockEmptyPartitionList", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Multiple query outputs are targeted to the same DSC fileset name {0}. + /// + internal static string MultipleOutputsWithSameDscUri { + get { + return ResourceManager.GetString("MultipleOutputsWithSameDscUri", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Key out of range.. + /// + internal static string MultiQueryableKeyOutOfRange { + get { + return ResourceManager.GetString("MultiQueryableKeyOutOfRange", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Internal error: Must specify the file name for the output assembly.. + /// + internal static string MustSpecifyOutputAssemblyFileName { + get { + return ResourceManager.GetString("MustSpecifyOutputAssemblyFileName", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The query must be created from a HpcLinqContext object and only use LINQ to HPC operators.. + /// + internal static string MustStartFromContext { + get { + return ResourceManager.GetString("MustStartFromContext", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Internal error: Negative length in memcopy.. + /// + internal static string NegativeLengthInMemcopy { + get { + return ResourceManager.GetString("NegativeLengthInMemcopy", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to sources[{0}] is not a LINQ to HPC query. A LINQ to HPC IQueryable should be created via an HpcLinqContext object and use only LINQ to HPC operators.. + /// + internal static string NotAHpcLinqQuery { + get { + return ResourceManager.GetString("NotAHpcLinqQuery", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to KeySelector function must be non-null.. 
+ /// + internal static string NullKeySelector { + get { + return ResourceManager.GetString("NullKeySelector", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Internal error: The result and element selectors must be non-null. + /// + internal static string NullSelector { + get { + return ResourceManager.GetString("NullSelector", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Fork: The number of keys must match the number of output ports.. + /// + internal static string NumberOfKeysMustEqualNumOutputPorts { + get { + return ResourceManager.GetString("NumberOfKeysMustEqualNumOutputPorts", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The number of keys must be equal to PartitionCount-1.. + /// + internal static string NumKeys { + get { + return ResourceManager.GetString("NumKeys", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Property only available for physical data.. + /// + internal static string OnlyAvailableForPhysicalData { + get { + return ResourceManager.GetString("OnlyAvailableForPhysicalData", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Internal error: OpenForWrite called when fs was not null.. + /// + internal static string OpenForWriteError { + get { + return ResourceManager.GetString("OpenForWriteError", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The operator '{0}' encountered in expression isn't a valid LINQ to HPC operator.. + /// + internal static string OperatorNotSupported { + get { + return ResourceManager.GetString("OperatorNotSupported", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Output data type cannot be an anonymous type.. 
+ /// + internal static string OutputTypeCannotBeAnonymous { + get { + return ResourceManager.GetString("OutputTypeCannotBeAnonymous", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The output DSC fileset name {0} is also used as the query source or as one of the referenced data sources.. + /// + internal static string OutputUriAlsoQueryInput { + get { + return ResourceManager.GetString("OutputUriAlsoQueryInput", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The partition count must be greater than 0.. + /// + internal static string PartitionCountMustBePositive { + get { + return ResourceManager.GetString("PartitionCountMustBePositive", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The return value of partitionFunc exceeded the number of ports.. + /// + internal static string PartitionFuncReturnValueExceedsNumPorts { + get { + return ResourceManager.GetString("PartitionFuncReturnValueExceedsNumPorts", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The provided range-partition keys are not consistenly ascending or descending.. + /// + internal static string PartitionKeysAreNotConsistentlyOrdered { + get { + return ResourceManager.GetString("PartitionKeysAreNotConsistentlyOrdered", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The partition keys are not provided.. + /// + internal static string PartitionKeysNotProvided { + get { + return ResourceManager.GetString("PartitionKeysNotProvided", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Each partition needs at least {0} records for this sliding window computation.. + /// + internal static string PartitionTooSmallForSlidingWindow { + get { + return ResourceManager.GetString("PartitionTooSmallForSlidingWindow", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Position is not supported.. 
+ /// + internal static string PositionNotSupported { + get { + return ResourceManager.GetString("PositionNotSupported", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Prefix {0} has already been used for another provider.. + /// + internal static string PrefixAlreadyUsedForOtherProvider { + get { + return ResourceManager.GetString("PrefixAlreadyUsedForOtherProvider", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to RangePartition: partition keys and output channels mismatch. There were {0} keys and {1} channels.. + /// + internal static string RangePartitionInputOutputMismatch { + get { + return ResourceManager.GetString("RangePartitionInputOutputMismatch", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to RangePartition: partition keys are missing.. + /// + internal static string RangePartitionKeysMissing { + get { + return ResourceManager.GetString("RangePartitionKeysMissing", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Native channel error while reading from file. Win32 error code = {0}.. + /// + internal static string ReadFileError { + get { + return ResourceManager.GetString("ReadFileError", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Read is not supported.. + /// + internal static string ReadNotAllowed { + get { + return ResourceManager.GetString("ReadNotAllowed", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to ReadWrite access is not supported.. + /// + internal static string ReadWriteNotSupported { + get { + return ResourceManager.GetString("ReadWriteNotSupported", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The maximum record size is 2GB.. 
+ /// + internal static string RecordSizeMax2GB { + get { + return ResourceManager.GetString("RecordSizeMax2GB", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to A query instance appears more than once in 'sources'. A query instance should be submitted at most once. To submit a query more than once, create a new IQueryable<> instance.. + /// + internal static string SameQuerySubmittedMultipleTimesInMaterialize { + get { + return ResourceManager.GetString("SameQuerySubmittedMultipleTimesInMaterialize", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Seek is not supported.. + /// + internal static string SeekNotSupported { + get { + return ResourceManager.GetString("SeekNotSupported", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to SequenceEqual() is not supported.. + /// + internal static string SequenceEqualNotSupported { + get { + return ResourceManager.GetString("SequenceEqualNotSupported", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to CustomHpcSerializer attribute for type {0} does not define a SerializerType.. + /// + internal static string SerializerTypeMustBeNonNull { + get { + return ResourceManager.GetString("SerializerTypeMustBeNonNull", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Type {0} referred to by CustomHpcSerializer attribute must implement IHpcSerializer({1}).. + /// + internal static string SerializerTypeMustSupportIHpcSerializer { + get { + return ResourceManager.GetString("SerializerTypeMustSupportIHpcSerializer", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to SetCalcFP() is not implemented for this stream type.. + /// + internal static string SetCalcFPNotSupported { + get { + return ResourceManager.GetString("SetCalcFPNotSupported", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to SetLength is not supported.. 
+ /// + internal static string SetLengthNotSupported { + get { + return ResourceManager.GetString("SetLengthNotSupported", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Position_Set is not supported.. + /// + internal static string SettingPositionNotSupported { + get { + return ResourceManager.GetString("SettingPositionNotSupported", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Internal error: should not create vertex code for Concat.. + /// + internal static string ShouldNotCreateCodeForConcat { + get { + return ResourceManager.GetString("ShouldNotCreateCodeForConcat", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Internal error: should not create vertex code for a dummy node.. + /// + internal static string ShouldNotCreateCodeForDummyNode { + get { + return ResourceManager.GetString("ShouldNotCreateCodeForDummyNode", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Internal error: should not create vertex code for input.. + /// + internal static string ShouldNotCreateCodeForInput { + get { + return ResourceManager.GetString("ShouldNotCreateCodeForInput", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Internal error: should not create vertex code for output.. + /// + internal static string ShouldNotCreateCodeForOutput { + get { + return ResourceManager.GetString("ShouldNotCreateCodeForOutput", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Internal error: should not create vertex code for Tee.. + /// + internal static string ShouldNotCreateCodeForTee { + get { + return ResourceManager.GetString("ShouldNotCreateCodeForTee", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Internal error: Should never call Reset(). 
+ /// + internal static string ShouldNototCallReset { + get { + return ResourceManager.GetString("ShouldNototCallReset", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Single: More than one element.. + /// + internal static string SingleMoreThanOneElement { + get { + return ResourceManager.GetString("SingleMoreThanOneElement", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Single: No elements.. + /// + internal static string SingleNoElements { + get { + return ResourceManager.GetString("SingleNoElements", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Internal error: Sorted chunk cannot be empty.. + /// + internal static string SortedChunkCannotBeEmpty { + get { + return ResourceManager.GetString("SortedChunkCannotBeEmpty", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Internal error: source must be DryadVertexReader.. + /// + internal static string SourceMustBeDryadVertexReader { + get { + return ResourceManager.GetString("SourceMustBeDryadVertexReader", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Source is not ordered.. + /// + internal static string SourceNotOrdered { + get { + return ResourceManager.GetString("SourceNotOrdered", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The source for mergesort must be of type IMultiEnumerable.. + /// + internal static string SourceOfMergesortMustBeMultiEnumerable { + get { + return ResourceManager.GetString("SourceOfMergesortMustBeMultiEnumerable", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to DSC fileset already exists: {0}.. + /// + internal static string StreamAlreadyExists { + get { + return ResourceManager.GetString("StreamAlreadyExists", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The DSC fileset {0} doesn't exist.. 
+ /// + internal static string StreamDoesNotExist { + get { + return ResourceManager.GetString("StreamDoesNotExist", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Error submitting job to {0}. Refer to inner exception for more detail.. + /// + internal static string SubmissionFailure { + get { + return ResourceManager.GetString("SubmissionFailure", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Only DSC URIs are supported.. + /// + internal static string TargetMustBeDscUri { + get { + return ResourceManager.GetString("TargetMustBeDscUri", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to ThenBy is not supported.. + /// + internal static string ThenByNotSupported { + get { + return ResourceManager.GetString("ThenByNotSupported", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Queries ending in ToDsc() cannot be followed by operators other than Submit().. + /// + internal static string ToDscUsedIncorrectly { + get { + return ResourceManager.GetString("ToDscUsedIncorrectly", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Queries ending in ToHdfs() cannot be followed by operators other than Submit().. + /// + internal static string ToHdfsUsedIncorrectly { + get { + return ResourceManager.GetString("ToHdfsUsedIncorrectly", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Internal error: Too many elements before reduction.. + /// + internal static string TooManyElementsBeforeReduction { + get { + return ResourceManager.GetString("TooManyElementsBeforeReduction", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to A method should not be tagged with both HomomorphicAttribute and LeftHomomorphicAttribute.. 
+ /// + internal static string TooManyHomomorphicAttributes { + get { + return ResourceManager.GetString("TooManyHomomorphicAttributes", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Too many items in collection.. + /// + internal static string TooManyItems { + get { + return ResourceManager.GetString("TooManyItems", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Internal error: {0} doesn't contain field/property {1}. + /// + internal static string TypeDoesNotContainMember { + get { + return ResourceManager.GetString("TypeDoesNotContainMember", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The type {0} does not contain a field named {1}. + /// + internal static string TypeDoesNotContainRequestedField { + get { + return ResourceManager.GetString("TypeDoesNotContainRequestedField", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Auto-serialized types must have at least one data member. Type={0}.. + /// + internal static string TypeMustHaveDataMembers { + get { + return ResourceManager.GetString("TypeMustHaveDataMembers", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Cannot auto-serialize a type containing pointers. Type = {0}. . + /// + internal static string TypeNotSerializable { + get { + return ResourceManager.GetString("TypeNotSerializable", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Cannot auto-serialize a type that is not public. Type = {0}.. + /// + internal static string TypeRequiredToBePublic { + get { + return ResourceManager.GetString("TypeRequiredToBePublic", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Auto-serialization is not supported for type with private field. Type = {0}.. 
+ /// + internal static string UDTHasFieldOfNonPublicType { + get { + return ResourceManager.GetString("UDTHasFieldOfNonPublicType", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Auto-serialization is not supported for delegate type. Type = {0}.. + /// + internal static string UDTIsDelegateType { + get { + return ResourceManager.GetString("UDTIsDelegateType", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Auto-serialization is not supported for type that is an abstract type or has generic arguments which are abstract types. Type = {0}.. + /// + internal static string UDTMustBeConcreteType { + get { + return ResourceManager.GetString("UDTMustBeConcreteType", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Unexpected job status: {0}.. + /// + internal static string UnexpectedJobStatus { + get { + return ResourceManager.GetString("UnexpectedJobStatus", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Internal error: LINQ to HPC query expression cannot be handled : {0}.. + /// + internal static string UnhandledQuery { + get { + return ResourceManager.GetString("UnhandledQuery", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Unknown channel kind: {0}.. + /// + internal static string UnknownChannelType { + get { + return ResourceManager.GetString("UnknownChannelType", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Unknown channel of type: {0}.. + /// + internal static string UnknownChannelType2 { + get { + return ResourceManager.GetString("UnknownChannelType2", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Unknown compression scheme.. 
+ /// + internal static string UnknownCompressionScheme { + get { + return ResourceManager.GetString("UnknownCompressionScheme", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Unknown connection type: {0}.. + /// + internal static string UnknownConnectionType { + get { + return ResourceManager.GetString("UnknownConnectionType", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Unknown method in expression to summarize: {0}.. + /// + internal static string UnknownMethodInExpression { + get { + return ResourceManager.GetString("UnknownMethodInExpression", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Provider {0} is unknown. Register it using DataProvider.Register.. + /// + internal static string UnknownProvier { + get { + return ResourceManager.GetString("UnknownProvier", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Internal error: This parameter expression didn't have a name.. + /// + internal static string UnnamedParameterExpression { + get { + return ResourceManager.GetString("UnnamedParameterExpression", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Unrecognized data source: {0}.. + /// + internal static string UnrecognizedDataSource { + get { + return ResourceManager.GetString("UnrecognizedDataSource", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Unrecognized operator name: {0}.. + /// + internal static string UnrecognizedOperatorName { + get { + return ResourceManager.GetString("UnrecognizedOperatorName", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Unexpected execution kind.. + /// + internal static string UnsupportedExecutionKind { + get { + return ResourceManager.GetString("UnsupportedExecutionKind", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Expression of type {0} is not supported.. 
+ /// + internal static string UnsupportedExpressionsType { + get { + return ResourceManager.GetString("UnsupportedExpressionsType", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Expression of type {0} is not supported for expression-summarization.. + /// + internal static string UnsupportedExpressionType { + get { + return ResourceManager.GetString("UnsupportedExpressionType", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Scheduler type not supported: {0}.. + /// + internal static string UnsupportedSchedulerType { + get { + return ResourceManager.GetString("UnsupportedSchedulerType", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The non-generic methods CreateQuery() and Execute() are not supported. Use CreateQuery<T>() and Execute<T>() instead.. + /// + internal static string UntypedProviderMethodsNotSupported { + get { + return ResourceManager.GetString("UntypedProviderMethodsNotSupported", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to VertexBridge arguments are malformed. argsString={0}. + /// + internal static string VertexBridgeBadArgs { + get { + return ResourceManager.GetString("VertexBridgeBadArgs", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The window size must be greater than 1.. + /// + internal static string WindowSizeMustyBeGTOne { + get { + return ResourceManager.GetString("WindowSizeMustyBeGTOne", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to WriteByte is not supported.. + /// + internal static string WriteByteNotSupported { + get { + return ResourceManager.GetString("WriteByteNotSupported", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Error returned from WriteFile: {0}.. 
+ /// + internal static string WriteFileError { + get { + return ResourceManager.GetString("WriteFileError", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Write is not supported.. + /// + internal static string WriteNotSupported { + get { + return ResourceManager.GetString("WriteNotSupported", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Internal error: Wrong combination of flags.. + /// + internal static string WrongFlagCombination { + get { + return ResourceManager.GetString("WrongFlagCombination", resourceCulture); + } + } + } +} diff --git a/LinqToDryad/SR.resx b/LinqToDryad/SR.resx new file mode 100644 index 0000000..96c13a1 --- /dev/null +++ b/LinqToDryad/SR.resx @@ -0,0 +1,768 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + DistinctAttribute: Comparer {0} is not defined. + + + CustomHpcSerializer attribute for type {0} does not define a SerializerType. + + + Type {0} referred to by CustomHpcSerializer attribute must implement IHpcSerializer({1}). + + + The HpcLinqContext has been disposed. + + + Could not open FileSet and/or read its properties. + + + FileSet must be sealed before calling FromDsc(). + + + FileSet must have at least one file. + + + Could not get version information for HpcLinq client component. See inner exception. + + + Could not get version information for HpcLinq server component. See inner exception. + + + Internal error: Can't be used for value type. + + + The type {0} does not contain a field named {1} + + + Internal error: Can't be used for reference type. 
+ + + The URI is not well formed: {0} + + + Prefix {0} has already been used for another provider. + + + Provider {0} is unknown. Register it using DataProvider.Register. + + + The arguments in the table URI are not well formed. + + + Internal error: Can't call on PartitionInfo of type {0} + + + Only DSC URIs are supported. + + + Internal error: Can't create partition node based on random partition. + + + The partition keys are not provided. + + + The provided range-partition keys are not consistenly ascending or descending. + + + The arguments 'partitionKeys' and 'isDescending' are inconsistent. + + + Read failure: End of stream encountered while reading {0}. Data may be corrupt or does not match the file set compression scheme. + + + Fingerprint was disabled. + + + The maximum record size is 2GB. + + + Position_Set is not supported. + + + Read is not supported. + + + Seek is not supported. + + + SetLength is not supported. + + + Length of array {0} must be greater than or equal to {1} ({2} + {3}). + + + Cannot auto-serialize a type that is not public. Type = {0}. + + + Auto-serialized types must have at least one data member. Type={0}. + + + Custom serializer type must have a public default constructor. Type = {0}. + + + Custom serializer type must either be a class or a struct that implements IHpcSerializer({1}). Type = {0} + + + Cannot auto-serialize a type containing pointers. Type = {0}. + + + Auto-serialization is not supported for type that has subtypes or derives from a non-primitive type. Type = {0}. + + + Auto-serialization is not supported for types containing circular references. Type = {0}. + + + Auto-serialization is not supported for type {0} because it derives from the non-primitive type {1}. Consider using a custom serializer for {0}. Please note this auto-serialization rule may have exceptions for built-in types. Please see product documentation for details. 
+ + + Auto-serialization is not supported for type that is an abstract type or has generic arguments which are abstract types. Type = {0}. + + + Auto-serialization is not supported for type with private field. Type = {0}. + + + Auto-serialization is not supported for types containing fields of System.Object, System.Object[] or other collections of System.Object. Type = {0}. + + + Auto-serialization is not supported for delegate type. Type = {0}. + + + Internal error: AddVertexMethod on {0} not handled. + + + Internal error: cannot be empty + + + Internal error: Must specify the file name for the output assembly. + + + Failed to build {0}. See the client side log ({1}) for compilation error messages. + + + Internal error: The auto-generated LINQ to HPC vertex assembly was missing. + + + Key not found in collection. + + + Too many items in collection. + + + Internal error: Too many elements before reduction. + + + DryadLINQ requires the DRYAD_HOME environment variable to be set to the Dryad binary folder. + + + The ClusterName configuration must be set to the name of a Windows HPC Server head node. + + + The configuration object is read-only. + + + The collection is read-only. + + + Internal error: {0} doesn't contain field/property {1} + + + Unrecognized operator name: {0}. + + + Internal error processing anonymous type in query. + + + Expression of type {0} is not supported. + + + Internal error: This parameter expression didn't have a name. + + + Expression of type {0} is not supported for expression-summarization. + + + Source is not ordered. + + + The partition count must be greater than 0. + + + The window size must be greater than 1. + + + Each partition needs at least {0} records for this sliding window computation. + + + Failed to access path {0} + + + Error returned from GetFileSizeEx: {0}. + + + Native channel error while reading from file. Win32 error code = {0}. + + + Unknown compression scheme. + + + Error returned from WriteFile: {0}. 
+ + + Index overflowed range of Int32. + + + Key out of range. + + + Index out of range. + + + sources[{0}] is not a LINQ to HPC query. A LINQ to HPC IQueryable should be created via an HpcLinqContext object and use only LINQ to HPC operators. + + + A LINQ to HPC query that is submitted should involve at least one operator. + + + Queries ending in ToDsc() cannot be followed by operators other than Submit(). + + + Scheduler type not supported: {0}. + + + Unexpected execution kind. + + + Unexpected job status: {0}. + + + Repeated server error when querying job status. + + + JobOption({0}, {1}) not implemented. + + + JobMinNodes must be greater than 1. + + + Error submitting job to {0}. Refer to inner exception for more detail. + + + The job did not complete successfully. Refer to HPC Cluster Manager and/or HPC APIs for more detail. + + + Job cannot be submitted because either the client program or one of its dependencies is targeted for 32 bit execution ({0}). To correct this problem make sure your binaries are compiled as 'x64' or 'AnyCPU'. + + + Error reading metadata. + + + Error writing metadata. + + + HpcLinqQuery IQueryable objects cannot be added to object store. + + + Cannot serialize object store due to non-serializable object. Type = {0}. + + + Error serializing object store. See inner exception. + + + Failed to deserialize object from object store. + + + The expression must be a method call: [{0}]. + + + The query must be created from a HpcLinqContext object and only use LINQ to HPC operators. + + + The non-generic methods CreateQuery() and Execute() are not supported. Use CreateQuery<T>() and Execute<T>() instead. + + + SequenceEqual() is not supported. + + + This query instance has already been submitted. To submit a query more than once, create a new IQueryable<> instance. + + + A query instance has already been submitted. To submit a query more than once, create a new IQueryable<> instance. + + + A query instance appears more than once in 'sources'.
A query instance should be submitted at most once. To submit a query more than once, create a new IQueryable<> instance. + + + Position is not supported. + + + Write is not supported. + + + WriteByte is not supported. + + + Internal error: Negative length in memcopy. + + + The source for mergesort must be of type IMultiEnumerable. + + + ThenBy is not supported. + + + Internal error: Wrong combination of flags. + + + Aggregate: No elements. + + + First: No elements. + + + Single: More than one element. + + + Single: No elements. + + + Last: No elements. + + + Min: No elements + + + Max: No elements + + + Average: No elements + + + Internal error: source must be DryadVertexReader. + + + RangePartition: partition keys are missing. + + + The return value of partitionFunc exceeded the number of ports. + + + Fork: The number of keys must match the number of output ports. + + + The branch {0} of Fork is not used. + + + Internal error: The result and element selectors must be non-null + + + Internal error: Cannot reset this IEnumerator. + + + Failure during Except. + + + Failure during Intersect. + + + Failure during sorting. + + + Internal error: Sorted chunk cannot be empty. + + + RangePartition: partition keys and output channels mismatch. There were {0} keys and {1} channels. + + + Failure in hash based GroupBy. + + + Failure in sort based GroupBy. + + + Failure in hash based Join. + + + Failure in hash based GroupJoin. + + + Failure in Distinct. + + + Failure in {0}. + + + Failure in ordered GroupBy. + + + Apply: Failure in user-defined function. + + + VertexBridge arguments are malformed. argsString={0} + + + Unknown channel kind: {0}. + + + Cannot read query plan for job: {0}. + + + Unknown connection type: {0}. + + + Unknown channel of type: {0}. + + + Unknown method in expression to summarize: {0}. + + + The input expression must be a LINQ to HPC source. + + + Output data type cannot be an anonymous type. + + + Input data type cannot be an anonymous type. 
+ + + Decomposition class must be public. Class={0}. + + + Decomposition class must implement IDecomposable<,,> or IDecomposableRecursive<,,>. Class={0}. + + + Decomposition class types must match the function that it decorates. Method={0}. Class={1}. + + + Decomposition class should implement only one decomposable interface. Class={0}. + + + Decomposition class must have a public parameterless constructor. Class={0}. + + + A method tagged [Associative] should take two parameters of type T and return type T. Method={0}. + + + Associative class must be public. Class={0}. + + + Associative class must implement IDecomposable<,,> or IDecomposableRecursive<,,>. Class={0}. + + + Associative class types must match the function that it decorates. Method={0}. Class={1}. + + + Associative class should implement only one IAssociativeRecursive interface. Class={0}. + + + Associative class must have a public parameterless constructor. Class={0}. + + + Internal error: cannot rebuild optimized query expression for node. + + + Internal error: InputArity must equal Children.Length. + + + Internal Error: The Distinct operator can only take at most 2 arguments. + + + If a key-comparer is not provided, TKey must implement IComparable. TKey={0}. + + + If a key-comparer is not provided, TKey must override GetHashCode() and either implement IEquatable or override Equals(). TKey={0}. + + + If a comparer expression is not provided, TElement must implement IEquatable or override both Equals() and GetHashCode(). TElement={0}. + + + A method should not be tagged with both HomomorphicAttribute and LeftHomomorphicAttribute. + + + All inputs to homomorphic apply must have same number of input partitions + + + Unrecognized data source: {0}. + + + The operator '{0}' encountered in expression isn't a valid LINQ to HPC operator. + + + Aggregation operator '{0}' can only work on objects that implement IComparable. 
+ + + The number of separators passed to AssumeRangePartition must equal src.nPartitions - 1. nRangeSeparators={0}, Expected={1}. + + + Multiple query outputs are targeted to the same DSC fileset name {0} + + + The output DSC fileset name {0} is also used as the query source or as one of the referenced data sources. + + + Can't concat two datasets with different compression schemes. + + + Can't create multiple tables with different compression schemes. + + + Internal error: Failed to remove the Merge node. + + + Internal error: cannot attach a pipeline + + + Internal error: Can't add Tee to a node with more than one output. + + + Internal error: should not create vertex code for a dummy node. + + + Internal error: should not create vertex code for input. + + + Internal error: should not create vertex code for output. + + + Internal error: should not create vertex code for Concat. + + + Internal error: should not create vertex code for Tee. + + + Internal error: Illegal type of dynamic manager in dynamic node + + + The aggregate operator {0} is not supported. + + + Internal error: Cannot have dynamic manager of type {0}. + + + A LINQ to HPC channel cannot be read more than once. For example, the delegate in Apply() may only enumerate its input once. + + + Internal error: Should never call Reset() + + + Internal error: Can't have more than one output channel. + + + DSC fileset error: {0}. + + + The DSC fileset {0} doesn't exist. + + + DSC fileset already exists: {0}. + + + Internal error: OpenForWrite called when fs was not null. + + + Attempt to read from a stream that was opened for writing. + + + Failed to create DSC fileset: {0}. + + + ReadWrite access is not supported. + + + {0} cannot handle expression of type {1}. + + + Internal error: LINQ to HPC query expression cannot be handled : {0}. + + + The partition list of the table was empty. + + + Failed to access file {0}, which is part of DSC stream {1}.
If this file has other replicas, they were all attempted but could not be accessed either. + + + GetURI() is not implemented for this stream type. + + + SetCalcFP() is not implemented for this stream type. + + + GetFP() is not implemented for this stream type. + + + Failed to allocate a new native data block of size {0}. + + + Native channel failed to read from input channel at port {0}. Win32 error = {1}. + + + Native channel failed to write to output channel at port {0}. + + + Failed to get stream properties for file set {0}. + + + RecordType does not match with file set metadata. Use FromDsc<T> with matching T, or an overload that suppresses type-check. RecordType={0}, DscStream.recordType={1}. + + + KeySelector function must be non-null. + + + The number of keys must be equal to PartitionCount-1. + + + The job to create this HpcLinqQuery(T) failed with error: {0}. + + + The job to create this HpcLinqQuery(T) was canceled by the user. + + + Failed to get read paths for DSC fileset {0}. + + + Property only available for physical data. + + + Error creating DSC data from local debug mode. See inner exception for details. + + + Argument is not a LineRecord. + + + Queries ending in ToHdfs() cannot be followed by operators other than Submit(). + + \ No newline at end of file diff --git a/LinqToDryad/SimpleRewriter.cs b/LinqToDryad/SimpleRewriter.cs new file mode 100644 index 0000000..0651c8a --- /dev/null +++ b/LinqToDryad/SimpleRewriter.cs @@ -0,0 +1,287 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License.
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. +// +using System; +using System.Collections; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Text; +using System.IO; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.CodeDom; +using System.Xml; +using System.Diagnostics; +using Microsoft.Research.DryadLinq.Internal; + +namespace Microsoft.Research.DryadLinq +{ + internal class NodeInfoEdge + { + public QueryNodeInfo parent; + public QueryNodeInfo child; + + public NodeInfoEdge(QueryNodeInfo parent, QueryNodeInfo child) + { + this.parent = parent; + this.child = child; + } + + // Replace all occurences of oldEdge in edges by newEdge. + public static bool UpdateEdge(List edges, NodeInfoEdge oldEdge, NodeInfoEdge newEdge) + { + for (int i = 0; i < edges.Count; i++) + { + if (Object.ReferenceEquals(oldEdge, edges[i])) + { + edges[i] = newEdge; + return true; + } + } + return false; + } + + // Insert a node info on this edge. 
+ public void Insert(QueryNodeInfo nextInfo) + { + Debug.Assert(nextInfo.children.Count == 0 && nextInfo.parents.Count == 0); + NodeInfoEdge edge1 = new NodeInfoEdge(this.parent, nextInfo); + NodeInfoEdge edge2 = new NodeInfoEdge(nextInfo, this.child); + UpdateEdge(this.parent.children, this, edge1); + nextInfo.parents.Add(edge1); + UpdateEdge(this.child.parents, this, edge2); + nextInfo.children.Add(edge2); + } + } + + internal class QueryNodeInfo + { + public Expression queryExpression; + public List children; + public List parents; + public bool isQueryOperator; + public DryadQueryNode queryNode; + + public QueryNodeInfo(Expression queryExpression, + bool isQueryOperator, + params QueryNodeInfo[] children) + { + this.queryExpression = queryExpression; + this.isQueryOperator = isQueryOperator; + this.children = new List(children.Length); + foreach (QueryNodeInfo childInfo in children) + { + NodeInfoEdge edge = new NodeInfoEdge(this, childInfo); + this.children.Add(edge); + childInfo.parents.Add(edge); + } + this.parents = new List(); + this.queryNode = null; + } + + public string OperatorName + { + get { + if (!this.isQueryOperator) return null; + return ((MethodCallExpression)this.queryExpression).Method.Name; + } + } + + public Type Type + { + get { return this.queryExpression.Type; } + } + + public bool IsForked + { + get { return this.parents.Count > 1; } + } + + public QueryNodeInfo Clone() + { + return new QueryNodeInfo(this.queryExpression, this.isQueryOperator); + } + + // Delete this NodeInfo. 
+ // Precondition: this.children.Count < 2 + public void Delete() + { + Debug.Assert(this.children.Count < 2); + if (this.children.Count == 0) + { + foreach (NodeInfoEdge edge in this.parents) + { + edge.parent.children.Remove(edge); + } + } + else + { + QueryNodeInfo child = this.children[0].child; + child.parents.Remove(this.children[0]); + foreach (NodeInfoEdge edge in this.parents) + { + NodeInfoEdge newEdge = new NodeInfoEdge(edge.parent, child); + NodeInfoEdge.UpdateEdge(edge.parent.children, edge, newEdge); + child.parents.Add(newEdge); + } + } + + this.parents.Clear(); + this.children.Clear(); + } + + // Return true if this is not in the NodeInfo graph. + public bool IsOrphaned + { + get { return (this.children.Count == 0 && this.parents.Count == 0); } + } + + public void Swap(QueryNodeInfo other) + { + Debug.Assert(this.isQueryOperator && other.isQueryOperator); + Debug.Assert(this.queryNode == null && other.queryNode == null); + + Expression queryExpr = this.queryExpression; + this.queryExpression = other.queryExpression; + other.queryExpression = queryExpr; + } + } + + internal class SimpleRewriter + { + private List m_nodeInfos; + + public SimpleRewriter(List nodeInfos) + { + this.m_nodeInfos = nodeInfos; + } + + public void Rewrite() + { + bool isDone = false; + while (!isDone) + { + isDone = true; + int idx = 0; + while (idx < this.m_nodeInfos.Count) + { + if (this.m_nodeInfos[idx].IsOrphaned) + { + this.m_nodeInfos[idx] = this.m_nodeInfos[this.m_nodeInfos.Count - 1]; + this.m_nodeInfos.RemoveAt(this.m_nodeInfos.Count - 1); + } + else + { + bool changed = this.RewriteOne(idx); + isDone = isDone && !changed; + idx++; + } + } + } + } + + // Return true iff EPG is modified by this method. 
+ public bool RewriteOne(int idx) + { + QueryNodeInfo curNode = this.m_nodeInfos[idx]; + if (curNode.OperatorName == "Where" && !curNode.children[0].child.IsForked) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(((MethodCallExpression)curNode.queryExpression).Arguments[1]); + if (lambda.Type.GetGenericArguments().Length == 2) + { + QueryNodeInfo child = curNode.children[0].child; + string[] names = new string[] { "OrderBy", "Distinct", "RangePartition", "HashPartition" }; + if (names.Contains(child.OperatorName)) + { + curNode.Swap(child); + return true; + } + if (child.OperatorName == "Concat") + { + curNode.Delete(); + for (int i = 0; i < child.children.Count; i++) + { + NodeInfoEdge edge = child.children[i]; + QueryNodeInfo node = curNode.Clone(); + this.m_nodeInfos.Add(node); + edge.Insert(node); + } + return true; + } + } + } + else if ((curNode.OperatorName == "Select" || curNode.OperatorName == "SelectMany") && + !curNode.children[0].child.IsForked) + { + LambdaExpression lambda = HpcLinqExpression.GetLambda(((MethodCallExpression)curNode.queryExpression).Arguments[1]); + if (lambda.Type.GetGenericArguments().Length == 2) + { + QueryNodeInfo child = curNode.children[0].child; + if (child.OperatorName == "Concat") + { + curNode.Delete(); + for (int i = 0; i < child.children.Count; i++) + { + NodeInfoEdge edge = child.children[i]; + QueryNodeInfo node = curNode.Clone(); + this.m_nodeInfos.Add(node); + edge.Insert(node); + } + return true; + } + } + } + else if (curNode.OperatorName == "Take" && !curNode.children[0].child.IsForked) + { + QueryNodeInfo child = curNode.children[0].child; + if (child.OperatorName == "Select") + { + QueryNodeInfo cchild = child.children[0].child; + if (cchild.OperatorName != "GroupBy") + { + curNode.Swap(child); + return true; + } + } + } + else if ((curNode.OperatorName == "Contains" || + curNode.OperatorName == "ContainsAsQuery" || + curNode.OperatorName == "All" || + curNode.OperatorName == "AllAsQuery" || + 
curNode.OperatorName == "Any" || + curNode.OperatorName == "AnyAsQuery") && + !curNode.children[0].child.IsForked) + { + QueryNodeInfo child = curNode.children[0].child; + string[] names = new string[] { "OrderBy", "Distinct", "RangePartition", "HashPartition" }; + if (names.Contains(child.OperatorName)) + { + child.Delete(); + return true; + } + } + return false; + } + } +} diff --git a/LinqToDryad/TypeSystem.cs b/LinqToDryad/TypeSystem.cs new file mode 100644 index 0000000..35ba7ce --- /dev/null +++ b/LinqToDryad/TypeSystem.cs @@ -0,0 +1,1393 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. 
+// +using System; +using System.Collections; +using System.Collections.Generic; +using System.Text; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Reflection.Emit; +using System.Runtime.Serialization; +using System.Runtime.CompilerServices; +using System.Data.SqlTypes; +using System.Diagnostics; +using Microsoft.Research.DryadLinq.Internal; + +namespace Microsoft.Research.DryadLinq +{ + internal class TypeHelper + { + private static T instanceHandle = (T)System.Runtime.Serialization.FormatterServices.GetUninitializedObject(typeof(T)); + + // Return the instance of type T created as a handle + internal static T InstanceHandle + { + get { return instanceHandle; } + } + + // Create a new fresh instance of type T + internal static T Instance + { + get { return (T)System.Runtime.Serialization.FormatterServices.GetUninitializedObject(typeof(T)); } + } + } + + internal class IdentityFunction + { + internal static LambdaExpression Instance(Type type, string paramName) + { + ParameterExpression param = Expression.Parameter(type, paramName); + Type delegateType = typeof(Func<,>).MakeGenericType(type, type); + return Expression.Lambda(delegateType, param, param); + } + + internal static LambdaExpression Instance(Type type) + { + return IdentityFunction.Instance(type, "x"); + } + + internal static bool IsIdentity(LambdaExpression expr) + { + return (expr.Parameters.Count == 1 && + expr.Parameters[0] == expr.Body); + } + } + + internal static class TypeSystem + { + private static Dictionary s_sizeOfKnownTypes; + private static HashSet s_systemAssemblies; + + static TypeSystem() + { + // + // Add the sizes of built in types into s_sizeOfKnownTypes. Entries in this dictionary will be + // used by TypeSystem.GetSize() to compute the fixed size of a TRecord. + // TypeSystem.GetSize() returns -1 if the type's size isn't fixed due to a variable size field it contains. 
+ // To support that, we add -1 entries in s_sizeOfKnownTypes for string and object. + // + s_sizeOfKnownTypes = new Dictionary(20); + s_sizeOfKnownTypes.Add(typeof(bool), sizeof(bool)); + s_sizeOfKnownTypes.Add(typeof(char), sizeof(char)); + s_sizeOfKnownTypes.Add(typeof(sbyte), sizeof(sbyte)); + s_sizeOfKnownTypes.Add(typeof(byte), sizeof(byte)); + s_sizeOfKnownTypes.Add(typeof(short), sizeof(Int16)); + s_sizeOfKnownTypes.Add(typeof(ushort), sizeof(UInt16)); + s_sizeOfKnownTypes.Add(typeof(int), sizeof(Int32)); + s_sizeOfKnownTypes.Add(typeof(uint), sizeof(UInt32)); + s_sizeOfKnownTypes.Add(typeof(long), sizeof(Int64)); + s_sizeOfKnownTypes.Add(typeof(ulong),sizeof(UInt64)); + s_sizeOfKnownTypes.Add(typeof(float), sizeof(float)); + s_sizeOfKnownTypes.Add(typeof(double), sizeof(double)); + s_sizeOfKnownTypes.Add(typeof(decimal), sizeof(decimal)); + s_sizeOfKnownTypes.Add(typeof(DateTime), sizeof(Int64)); + s_sizeOfKnownTypes.Add(typeof(SqlDateTime), sizeof(Int64)); + s_sizeOfKnownTypes.Add(typeof(string), -1); + s_sizeOfKnownTypes.Add(typeof(object), -1); + + s_systemAssemblies = new HashSet(); + s_systemAssemblies.Add("mscorlib"); + s_systemAssemblies.Add("System"); + s_systemAssemblies.Add("Accessibility"); + s_systemAssemblies.Add("SMDiagnostics"); + } + + internal static IEnumerable GetLoadedNonSystemAssemblyPaths() + { + List names = new List(); + foreach (Assembly asm in TypeSystem.GetAllAssemblies()) + { + if (!TypeSystem.IsSystemAssembly(asm)) + { + names.Add(asm.Location); + } + } + return names.ToArray(); + } + + internal static bool IsSystemAssembly(Assembly asm) + { + string name = asm.GetName().Name; + return (s_systemAssemblies.Contains(name) || + name.StartsWith("Microsoft.", StringComparison.Ordinal) || + name.StartsWith("System.", StringComparison.Ordinal) || + name == "WindowsBase"); + } + + internal static Type FindGenericType(Type definition, Type type) + { + while (type != null && type != typeof(object)) + { + if (type.IsGenericType && 
type.GetGenericTypeDefinition() == definition) + { + return type; + } + if (definition.IsInterface) + { + foreach (Type itype in type.GetInterfaces()) + { + Type found = FindGenericType(definition, itype); + if (found != null) + { + return found; + } + } + } + type = type.BaseType; + } + return null; + } + + internal static Type GetElementType(Type seqType) + { + Type ienumType = FindGenericType(typeof(IEnumerable<>), seqType); + if (ienumType == null) + { + return seqType; + } + return ienumType.GetGenericArguments()[0]; + } + + internal static bool IsSameOrSubclass(Type type, Type subType) + { + return (type == subType) || subType.IsSubclassOf(type); + } + + // This should be the only method used to find out the real type. + internal static bool IsRealType(Type type) + { + if (type.IsAbstract) + { + if (type.IsGenericType) + { + Type typedef = type.GetGenericTypeDefinition(); + if (typedef == typeof(IGrouping<,>) || + typedef == typeof(IEnumerable<>) || + typedef == typeof(IList<>)) + { + return type.GetGenericArguments().All(x => IsRealType(x)); + } + } + return false; + } + + return (!type.IsGenericType || + type.GetGenericArguments().All(x => IsRealType(x))); + + } + + internal static int GetInMemSize(Type type) + { + if (!type.IsValueType) return 8; + + int size = 0; + bool found = s_sizeOfKnownTypes.TryGetValue(type, out size); + if (found) return size; + + FieldInfo[] fields = type.GetFields(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); + foreach (FieldInfo field in fields) + { + size += GetInMemSize(field.FieldType); + } + return size; + } + + // Returns the size of the type if it has a fixed layout. 
+ // Returns -1 if the type has a variable size (due to fields of string, object or arrays) + // Note that this is the number of bytes used by DryadLINQ serialization + internal static int GetSize(Type type) + { + return GetSize(type, new Dictionary()); + } + + private static int GetSize(Type type, Dictionary typeSizeMap) + { + int size = 0; + bool found = s_sizeOfKnownTypes.TryGetValue(type, out size); + if (found) return size; + found = typeSizeMap.TryGetValue(type, out size); + if (found) return size; + + if (type.IsArray) return -1; + + typeSizeMap[type] = -1; + size = 0; + FieldInfo[] fields = type.GetFields(BindingFlags.Instance|BindingFlags.Public|BindingFlags.NonPublic); + foreach (FieldInfo field in fields) + { + int fsize = GetSize(field.FieldType, typeSizeMap); + if (fsize == -1) return -1; + size += fsize; + } + typeSizeMap[type] = size; + return size; + } + + internal static bool StructContainsNoReference(Type type) + { + if (!type.IsValueType) return false; + + int size = 0; + bool found = s_sizeOfKnownTypes.TryGetValue(type, out size); + if (found) return true; + + FieldInfo[] fields = type.GetFields(BindingFlags.Instance|BindingFlags.Public|BindingFlags.NonPublic); + foreach (FieldInfo finfo in fields) + { + if (!StructContainsNoReference(finfo.FieldType)) return false; + } + return true; + } + + internal static bool ContainsNoLazyValue(Type type) + { + return ContainsNoLazyValue(type, new HashSet()); + } + + internal static bool ContainsNoLazyValue(Type type, HashSet seen) + { + while (type.IsArray) + { + type = type.GetElementType(); + } + if (type.IsPrimitive || + seen.Contains(type) || + type == typeof(System.Reflection.Pointer)) + { + return true; + } + + if (typeof(IEnumerable).IsAssignableFrom(type) && + !typeof(IList).IsAssignableFrom(type)) + { + return false; + } + + seen.Add(type); + FieldInfo[] finfos = TypeSystem.GetAllFields(type); + foreach (FieldInfo finfo in finfos) + { + if (!ContainsNoLazyValue(finfo.FieldType, seen)) + { + return 
false; + } + } + return true; + } + + // Returns true iff the method is a property. + internal static bool IsProperty(MethodInfo minfo) + { + if (minfo.Name.StartsWith("get_", StringComparison.Ordinal)) + { + Type declType = minfo.DeclaringType; + ParameterInfo[] paramInfos = minfo.GetParameters(); + + PropertyInfo[] pinfos = declType.GetProperties(); + foreach (var pinfo in pinfos) + { + MethodInfo[] accessors = pinfo.GetAccessors(); + foreach (MethodInfo ac in accessors) + { + if (ac.Name == minfo.Name) + { + ParameterInfo[] paramInfos1 = ac.GetParameters(); + if (paramInfos.Length != paramInfos1.Length) break; + for (int i = 0; i < paramInfos.Length; i++) + { + if (paramInfos[i].ParameterType != paramInfos1[i].ParameterType) break; + } + return true; + } + } + } + } + return false; + } + + /// + /// Compare two assemblies for equality. + /// + internal class AssemblyComparer : IEqualityComparer + { + public bool Equals(Assembly x, Assembly y) + { + // some Assembly objects loaded as reflection only may have different pointers + // what matters is that their fully qualified names match + + return ReferenceEquals(x, y) || x.FullName == y.FullName; + } + + public int GetHashCode(Assembly obj) + { + return obj.FullName.GetHashCode(); + } + } + + private static bool IsDynamicAssembly(Assembly asm) + { + if (asm is AssemblyBuilder) return true; + + try + { + if (asm.Location != null) return false; + } + catch + { + // if we get an exception from asm.Location it's a dynamic assembly. + return true; + } + + return false; + } + + private static HashSet s_allReferencedAssemblies = null; + + /// + /// Compute all referenced assemblies (transitive closure) and cache them. + /// + /// List of all referenced assemblies. 
+ internal static HashSet GetAllAssemblies() + { + if (s_allReferencedAssemblies == null) + { + // compute transitive closure + HashSet referencedAssemblies = new HashSet(new AssemblyComparer()); + Queue toscan = new Queue(10); + + Assembly[] assemblies = AppDomain.CurrentDomain.GetAssemblies(); + foreach (Assembly asm in assemblies) + { + if (!referencedAssemblies.Contains(asm) && !IsDynamicAssembly(asm)) + { + toscan.Enqueue(asm); + referencedAssemblies.Add(asm); + } + } + + while (toscan.Count > 0) + { + Assembly asm = toscan.Dequeue(); + AssemblyName[] names = asm.GetReferencedAssemblies(); + foreach (AssemblyName asmName in names) + { + try + { + Assembly refAssembly = Assembly.ReflectionOnlyLoad(asmName.FullName); + if (!referencedAssemblies.Contains(refAssembly)) + { + toscan.Enqueue(refAssembly); + referencedAssemblies.Add(refAssembly); + } + } + catch (Exception) + { + // Console.WriteLine("Warning: Could not load referenced assembly " + asmName); + } + } + } + + // Due to the use of FullName in the ReflectionOnlyLoad call, we may end up with multiple versions of the same assembly in the list + // if the client is running against a newer .NET version than what Microsoft.Research.DryadLinq.DLL is compiled against + // We need to filter the list by selecting the newest version of each assembly + var newestAssemblies = referencedAssemblies.GroupBy(asm => asm.GetName().Name).Select(grp => grp.OrderByDescending(asm => asm.GetName().Version).First()); + s_allReferencedAssemblies = new HashSet(new AssemblyComparer()); + foreach (var asm in newestAssemblies) + { + s_allReferencedAssemblies.Add(asm); + } + } + return s_allReferencedAssemblies; + } + + private static Dictionary s_typeMap = null; + + internal static Dictionary BuildTypeHierarchy() + { + if (s_typeMap != null) return s_typeMap; + + Dictionary typeMap = new Dictionary(); + HashSet assemblies = TypeSystem.GetAllAssemblies(); + foreach (Assembly asm in assemblies) + { + foreach (Type type in 
asm.GetTypes()) + { + if (!type.IsInterface) + { + Type[] baseTypes = type.GetInterfaces(); + if (type.BaseType != null && type.BaseType != typeof(object)) + { + Type[] newBaseTypes = new Type[baseTypes.Length + 1]; + Array.Copy(baseTypes, newBaseTypes, baseTypes.Length); + newBaseTypes[baseTypes.Length] = type.BaseType; + baseTypes = newBaseTypes; + } + for (int i = 0; i < baseTypes.Length; i++) + { + Type baseType = baseTypes[i]; + if (baseType.IsGenericType) + { + baseType = baseType.GetGenericTypeDefinition(); + baseTypes[i] = baseType; + } + + bool isNew = true; + for (int j = 0; j < i; j++) + { + if (baseTypes[j] == baseType) + { + isNew = false; + break; + } + } + if (isNew) + { + Type[] deriveds = null; + if (typeMap.TryGetValue(baseType, out deriveds)) + { + Type[] newDeriveds = new Type[deriveds.Length + 1]; + Array.Copy(deriveds, newDeriveds, deriveds.Length); + newDeriveds[deriveds.Length] = type; + deriveds = newDeriveds; + } + else + { + deriveds = new Type[1] { type }; + } + typeMap[baseType] = deriveds; + } + } + } + } + } + s_typeMap = typeMap; + return s_typeMap; + } + + internal static bool HasSubtypes(Type type) + { + Type typeDef = type; + if (type.IsGenericType) + { + typeDef = type.GetGenericTypeDefinition(); + } + Dictionary typeMap = TypeSystem.BuildTypeHierarchy(); + return typeMap.ContainsKey(typeDef); + } + + internal static bool IsASubType(Type type) + { + return (type.BaseType != null && + type.BaseType != typeof(object) && + !(typeof(System.ValueType).IsAssignableFrom(type.BaseType)) && + !type.IsArray); + } + + internal static Type GetType(string name) + { + Type type = null; + Assembly callingAssembly = Assembly.GetCallingAssembly(); + if (callingAssembly != null) + { + type = callingAssembly.GetType(name); + if (type != null) return type; + } + Assembly executingAssembly = Assembly.GetExecutingAssembly(); + if (executingAssembly != null) + { + type = executingAssembly.GetType(name); + if (type != null) return type; + } + Assembly 
entryAssembly = Assembly.GetEntryAssembly(); + if (entryAssembly != null) + { + type = entryAssembly.GetType(name); + if (type != null) return type; + } + + foreach (Assembly asm in GetAllAssemblies()) + { + type = asm.GetType(name); + if (type != null) return type; + } + return type; + } + + // Get the current value of a static field + internal static object GetFieldValue(string name) + { + int idx = name.LastIndexOf('.'); + if (idx <= 0) + { + throw new ArgumentException("Internal: The argument is not a reference to static field"); + } + string tname = name.Substring(0, idx); + string fname = name.Substring(idx+1); + + Type type = TypeSystem.GetType(tname); + if (type == null) + { + throw new ArgumentException("Internal: The argument is not a reference to static field"); + } + FieldInfo finfo = type.GetField(fname); + if (finfo == null || !finfo.IsStatic) + { + throw new ArgumentException("Internal: The argument is not a reference to static field"); + } + return finfo.GetValue(null); + } + + internal static LambdaExpression GetExpression(string name) + { + if (name == null) return null; + object val = TypeSystem.GetFieldValue(name); + if (val == null) + { + throw new ArgumentException("Internal: The argument is not defined"); + } + if (!(val is LambdaExpression)) + { + throw new ArgumentException("Internal: The argument is not a lambda expression"); + } + return (LambdaExpression)val; + } + + internal static MethodInfo FindStaticMethod(Type type, string name, Type[] paramTypes, params Type[] genericTypeArgs) + { + MethodInfo[] methods = type.GetMethods(BindingFlags.Static | BindingFlags.Public); + foreach (MethodInfo minfo in methods) + { + if (minfo.Name == name) + { + MethodInfo matchedInfo = MatchArgs(minfo, paramTypes, genericTypeArgs); + if (matchedInfo != null) + { + return matchedInfo; + } + } + } + return null; + } + + internal static MethodInfo FindStaticMethod(string methodName, Type[] paramTypes) + { + int index = methodName.LastIndexOf('.'); + if 
(index == -1) return null; + + string className = methodName.Substring(0, index); + methodName = methodName.Substring(index + 1); + Type classType = TypeSystem.GetType(className); + if (classType == null) return null; + return TypeSystem.FindStaticMethod(classType, methodName, paramTypes); + } + + private static MethodInfo MatchArgs(MethodInfo minfo, Type[] paramTypes, Type[] genericTypeArgs) + { + ParameterInfo[] mParams = minfo.GetParameters(); + if (mParams.Length != paramTypes.Length) + { + return null; + } + if (!minfo.IsGenericMethodDefinition && minfo.IsGenericMethod && minfo.ContainsGenericParameters) + { + minfo = minfo.GetGenericMethodDefinition(); + } + if (minfo.IsGenericMethodDefinition) + { + if (genericTypeArgs == null || + genericTypeArgs.Length == 0 || + minfo.GetGenericArguments().Length != genericTypeArgs.Length) + { + return null; + } + minfo = minfo.MakeGenericMethod(genericTypeArgs); + mParams = minfo.GetParameters(); + } + else if (genericTypeArgs != null && genericTypeArgs.Length > 0) + { + return null; + } + for (int i = 0; i < paramTypes.Length; i++) + { + Type parameterType = mParams[i].ParameterType; + if (parameterType == null || + (paramTypes[i].IsByRef && parameterType != paramTypes[i]) || + !parameterType.IsAssignableFrom(paramTypes[i])) + { + return null; + } + } + return minfo; + } + + internal static string TypeName(Type type) + { + return TypeName(type, new Dictionary()); + } + + internal static string TypeName(Type type, Dictionary typeToName) + { + string name = null; + if (typeToName.TryGetValue(type, out name)) + { + return name; + } + + if (type.IsGenericParameter) + { + return type.Name; + } + + if (type.IsArray) + { + Type baseType = type.GetElementType(); + while (baseType.IsArray) + { + baseType = baseType.GetElementType(); + } + string tname = TypeName(baseType, typeToName); + + Type elemType = type; + do + { + string ranks = new string(',', (elemType.GetArrayRank() - 1)); + tname += "[" + ranks + "]"; + elemType = 
elemType.GetElementType(); + } + while (elemType.IsArray); + return tname; + } + + List nestedTypes = new List(); + nestedTypes.Add(type); + Type declaringType = type; + while (declaringType.IsNested) + { + declaringType = declaringType.DeclaringType; + nestedTypes.Add(declaringType); + } + + StringBuilder processedName = new StringBuilder(); + Type[] typeArgs = type.GetGenericArguments(); + int typeArgIndex = 0; + bool isFirst = true; + for (int i = nestedTypes.Count - 1; i >= 0; i--) + { + Type curType = nestedTypes[i]; + name = (curType.IsNested) ? curType.Name : curType.FullName; + + int lastIndex = name.IndexOf('`'); + if (lastIndex != -1) + { + name = name.Substring(0, lastIndex); + } + if (isFirst) + { + isFirst = false; + } + else + { + processedName.Append('.'); + } + processedName.Append(name); + + if (curType.IsGenericType) + { + int len = curType.GetGenericArguments().Length; + if (typeArgIndex < len) + { + processedName.Append('<'); + processedName.Append(TypeName(typeArgs[typeArgIndex], typeToName)); + typeArgIndex++; + while (typeArgIndex < len) + { + processedName.Append(','); + processedName.Append(TypeName(typeArgs[typeArgIndex], typeToName)); + typeArgIndex++; + } + processedName.Append('>'); + } + } + } + + return processedName.ToString(); + } + + internal static IComparer GetComparer(object comparer) + { + IComparer res = comparer as IComparer; + if (res == null && HasDefaultComparer(typeof(T))) + { + if (typeof(T) == typeof(string)) + { + res = (IComparer)StringComparer.Ordinal; + } + else + { + res = Comparer.Default; + } + } + return res; + } + + internal static bool IsComparer(object comparer, Type type) + { + Type comparerType = comparer.GetType(); + return typeof(IComparer<>).MakeGenericType(type).IsAssignableFrom(comparerType); + } + + internal static IEqualityComparer GetEqualityComparer(object comparer) + { + IEqualityComparer res = comparer as IEqualityComparer; + if (res == null && HasDefaultEqualityComparer(typeof(T))) + { + res = 
EqualityComparer.Default; + } + return res; + } + + internal static bool IsEqualityComparer(object comparer, Type type) + { + Type comparerType = comparer.GetType(); + return typeof(IEqualityComparer<>).MakeGenericType(type).IsAssignableFrom(comparerType); + } + + internal static bool HasDefaultComparer(Type type) + { + // true if T implements IComparable + if (typeof(IComparable<>).MakeGenericType(type).IsAssignableFrom(type)) + { + return true; + } + // true if T is a Nullable where U implements IComparable + if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) + { + Type u = type.GetGenericArguments()[0]; + if (typeof(IComparable<>).MakeGenericType(u).IsAssignableFrom(u)) + { + return true; + } + } + // true if T implements IComparable + return typeof(IComparable).IsAssignableFrom(type); + } + + internal static bool HasDefaultEqualityComparer(Type type) + { + // true if T implements IEquatable + if (typeof(IEquatable<>).MakeGenericType(type).IsAssignableFrom(type)) + { + return HasOverrideHashCode(type); + } + // true if T is a Nullable where U implements IEquatable + if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) + { + Type u = type.GetGenericArguments()[0]; + if (typeof(IEquatable<>).MakeGenericType(u).IsAssignableFrom(u)) + { + return HasOverrideHashCode(type); + } + } + // true if Equals is overridden in T + if (type.IsAbstract) + { + Dictionary typeMap = BuildTypeHierarchy(); + Type[] subtypes; + if (typeMap.TryGetValue(type, out subtypes)) + { + foreach (Type subtype in subtypes) + { + if (!HasDefaultEqualityComparer(subtype)) + { + return false; + } + } + } + return true; + } + MethodInfo minfo = type.GetMethod("Equals", new Type[] { typeof(object) }); + if (minfo.DeclaringType != type) + { + return false; + } + minfo = type.GetMethod("GetHashCode"); + if (minfo.DeclaringType != type) + { + return false; + } + if (type.IsGenericType) + { + Type[] typeArgs = type.GetGenericArguments(); + foreach 
(Type targ in typeArgs) + { + if (!HasDefaultEqualityComparer(targ)) + { + return false; + } + } + } + return true; + } + + private static bool HasOverrideHashCode(Type type) + { + if (type.IsAbstract) + { + Dictionary typeMap = BuildTypeHierarchy(); + Type[] subtypes; + if (typeMap.TryGetValue(type, out subtypes)) + { + foreach (Type subtype in subtypes) + { + if (!HasOverrideHashCode(subtype)) + { + return false; + } + } + } + } + MethodInfo minfo = type.GetMethod("GetHashCode"); + if (minfo.DeclaringType != type) + { + return false; + } + if (type.IsGenericType) + { + Type[] typeArgs = type.GetGenericArguments(); + foreach (Type targ in typeArgs) + { + if (!HasOverrideHashCode(targ)) + { + return false; + } + } + } + return true; + } + + internal static bool IsTypeSerializable(Type type) + { + return (!type.IsPointer && + type != typeof(IntPtr) && + !typeof(System.Delegate).IsAssignableFrom(type)); + } + + internal static bool IsFieldSerialized(FieldInfo finfo) + { + return (finfo.Attributes & FieldAttributes.NotSerialized) != FieldAttributes.NotSerialized; + } + + internal static Type GetNonserializable(object obj) + { + return GetNonserializable(obj, new HashSet(new ReferenceEqualityComparer())); + } + + internal static Type GetNonserializable(object obj, HashSet seen) + { + if (obj == null || seen.Contains(obj)) + { + return null; + } + + Type type = obj.GetType(); + if (type.IsPrimitive) + { + return null; + } + if (!type.IsSerializable) + { + return type; + } + Type[] argTypes = new Type[] { typeof(SerializationInfo), typeof(StreamingContext) }; + if (type.GetMethod("GetObjectData", argTypes) != null) + { + return null; + } + if (type.IsArray) + { + return GetTypeNonserializable(type); + } + seen.Add(obj); + FieldInfo[] fields = GetAllFields(type); + foreach (FieldInfo finfo in fields) + { + if ((finfo.Attributes & FieldAttributes.NotSerialized) != FieldAttributes.NotSerialized) + { + object fval = finfo.GetValue(obj); + if (!(fval is Pointer)) + { + type = 
GetNonserializable(fval, seen); + if (type != null) + { + return type; + } + } + } + } + return null; + } + + internal static Type GetTypeNonserializable(Type type) + { + return GetTypeNonserializable(type, new HashSet()); + } + + internal static Type GetTypeNonserializable(Type type, HashSet seen) + { + while (type.IsArray) + { + type = type.GetElementType(); + } + if (type.IsPrimitive || seen.Contains(type)) + { + return null; + } + if (!type.IsSerializable) + { + return type; + } + Type[] argTypes = new Type[] { typeof(SerializationInfo), typeof(StreamingContext) }; + if (type.GetMethod("GetObjectData", argTypes) != null) + { + return null; + } + + seen.Add(type); + FieldInfo[] fields = GetAllFields(type); + foreach (FieldInfo finfo in fields) + { + if ((finfo.Attributes & FieldAttributes.NotSerialized) != FieldAttributes.NotSerialized) + { + if (finfo.FieldType != typeof(System.Reflection.Pointer)) + { + type = GetTypeNonserializable(finfo.FieldType, seen); + if (type != null) + { + return type; + } + } + } + } + return null; + } + + internal static bool IsQueryOperatorCall(MethodCallExpression expression) + { + Type declType = expression.Method.DeclaringType; + return (declType == typeof(System.Linq.Enumerable) || + declType == typeof(System.Linq.Queryable) || + declType == typeof(Microsoft.Research.DryadLinq.HpcLinqQueryable)); + } + + internal static FieldInfo[] GetAllFields(Type type) + { + List res = new List(); + while (type != typeof(object) && type != null) + { + FieldInfo[] fields = type.GetFields(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); + foreach (FieldInfo finfo in fields) + { + if (finfo.DeclaringType == type) + { + res.Add(finfo); + } + } + type = type.BaseType; + } + return res.ToArray(); + } + + internal static bool HasFieldOfNonPublicType(Type type) + { + FieldInfo[] fields = GetAllFields(type); + foreach (FieldInfo finfo in fields) + { + if (!finfo.FieldType.IsPublic && !finfo.FieldType.IsNestedPublic) + { + return 
true; + } + } + return false; + } + + + // Checks whether the given type contains a field that refers to itself. Fields will be traversed recursively, and any field that referes back to + // the original type (either as a direct reference, or as an array reference) will cause the check to return true. Static fields are skipped. + // + // However types of the fields traversed will not be checked for circularity themselves. This is meant to allow the user to handle a circular UDT + // with a custom serializer, and still rely on autoserialization for types that contain it. + // See last example below for this case.. B is circular, and A contains B. If B has a custom serializer, A can still be autoserialized. + // + // But such "contained circularity" (e.g. B or C) doesn't escape the check in case there are no custom serializers involved, + // because codegen has to call IsCircular for each of those contained circular types before generating autoserializaton code for them. + // + // Here are some examples: + // + // class A { A m_f1;} => returns TRUE for class A + // class A { A[] m_f1;} => returns TRUE for class A + // class A { B m_f1;} class B { A m_f2;} => returns TRUE for both class A and class B (both are second level circular) + // class A { B m_f1;} class B { C m_f2;} class C { B m_f3;} => returns FALSE for class A (as neither B nor C contain references to A) + // => but returns TRUE for both class B and class C (both are second level circular) + internal static bool IsCircularType(Type type) + { + return DoCircularTypeCheck(type, type, null); + } + + // Returns true if "typeToCheck" contains a direct reference to "parentType", or an array of it + // The hashset visitedTypes is meant to prevent multiple visits to the same type within this recursion (allocated on demand) + private static bool DoCircularTypeCheck(Type typeToCheck, Type parentType, HashSet visitedTypes) + { + if (visitedTypes == null) + { + visitedTypes = new HashSet(); + } + + // if typeToCheck 
is in visitedTypes it means either we visited this type and haven't encountered the parent, or a visit is in progress + // In both cases it's OK to return false. If a visit is already in progress and typeToCheck is actually an offender, it will be caught in that callstack level which is performing the visit. + if (visitedTypes.Contains(typeToCheck)) + { + return false; + } + + // Before starting to visit all fields of "typeToCheck", we need to mark it as already visited to guard against infinite recursion. + visitedTypes.Add(typeToCheck); + + foreach (FieldInfo fi in typeToCheck.GetFields(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly)) + { + Type fieldType = fi.FieldType; + + // if this field is a reference to the parentType it means parentType is circular. Return true and unwind everyone. + if (fieldType == parentType || + (fieldType.IsArray && fieldType.GetElementType() == parentType) || + DoCircularTypeCheck(fieldType, parentType, visitedTypes) ) // Otherwise recurse for this field's type, keeping the same parentType because + // we are only interested in whether fieldType is circular wrt parentType (we don't want to check whether fieldType is circular wrt itself) + { + return true; + } + } + + return false; + } + + // Assume: t1.GetGenericTypeDefinition() == t2.GetGenericTypeDefinition() + internal static bool TypeMatch(Type t1, Type t2, Dictionary map) + { + Type[] typeArgs1 = t1.GetGenericArguments(); + Type[] typeArgs2 = t2.GetGenericArguments(); + + if (typeArgs1.Length != typeArgs2.Length) + { + return false; + } + for (int i = 0; i < typeArgs1.Length; i++) + { + if (typeArgs1[i].IsGenericParameter) + { + Type t; + if (map.TryGetValue(typeArgs1[i], out t)) + { + if (t != typeArgs2[i]) return false; + } + else + { + map[typeArgs1[i]] = typeArgs2[i]; + } + } + else if (typeArgs1[i].ContainsGenericParameters) + { + bool matched = TypeMatch(typeArgs1[i], typeArgs2[i], map); + if (!matched) return false; + } + 
else if (typeArgs1[i] != typeArgs2[i]) + { + return false; + } + } + return true; + } + + // Assume: type.IsGenericType + internal static Type MakeGenericInstance(Type type, Dictionary map) + { + Type[] typeArgs = type.GetGenericArguments(); + Type[] args = new Type[typeArgs.Length]; + for (int i = 0; i < args.Length; i++) + { + if (typeArgs[i].IsGenericParameter) + { + Type t; + if (map.TryGetValue(typeArgs[i], out t)) + { + args[i] = t; + } + else + { + args[i] = typeArgs[i]; + } + } + else if (typeArgs[i].ContainsGenericParameters) + { + args[i] = MakeGenericInstance(typeArgs[i], map); + } + else + { + args[i] = typeArgs[i]; + } + } + + // kluggy, need to rewrite. + try + { + return type.GetGenericTypeDefinition().MakeGenericType(args); + } + catch (ArgumentException) + { + return null; + } + } + + internal static MethodBase GetBaseMethod(Type baseType, MethodBase method) + { + if (method.IsStatic || method.IsConstructor) + { + return (method.DeclaringType == baseType) ? method : null; + } + Type[] argTypes = new Type[method.GetParameters().Length]; + for (int i = 0; i < argTypes.Length; i++) + { + argTypes[i] = method.GetParameters()[i].ParameterType; + } + + MethodInfo[] minfos = baseType.GetMethods(BindingFlags.Public|BindingFlags.NonPublic|BindingFlags.Instance); + foreach (MethodInfo minfo in minfos) + { + ParameterInfo[] mparams = minfo.GetParameters(); + if (minfo.Name == method.Name && mparams.Length == argTypes.Length) + { + bool found = true; + for (int i = 0; i < argTypes.Length; i++) + { + if (argTypes[i].IsGenericParameter) + { + if (!mparams[i].ParameterType.IsGenericParameter) return null; + } + else if (argTypes[i] != mparams[i].ParameterType) + { + found = false; + break; + } + } + if (found) return minfo; + } + } + return null; + } + + internal static List GetAllOverrides(Type baseType, MethodBase method, Dictionary typeMap) + { + List methods = new List(); + GetAllOverrides(baseType, method, typeMap, methods); + return methods; + } + + private 
static void GetAllOverrides(Type baseType, MethodBase method, Dictionary typeMap, List methods) + { + if (method.IsStatic || method.IsFinal || method.IsConstructor) + { + return; + } + + Type type = baseType; + if (type.IsGenericType) + { + type = type.GetGenericTypeDefinition(); + } + Type[] subtypes = null; + if (typeMap.TryGetValue(type, out subtypes)) + { + Type[] argTypes = new Type[method.GetParameters().Length]; + for (int i = 0; i < argTypes.Length; i++) + { + argTypes[i] = method.GetParameters()[i].ParameterType; + } + Type[] genericArgTypes = method.GetGenericArguments(); + foreach (Type subtype in subtypes) + { + // TBD: We could make subtype more precise by unifying with baseType + MethodInfo[] minfos = subtype.GetMethods(BindingFlags.Public|BindingFlags.NonPublic|BindingFlags.Instance|BindingFlags.DeclaredOnly); + foreach (MethodInfo minfo in minfos) + { + if (minfo.Name == method.Name && minfo.IsVirtual) + { + MethodInfo minfo1 = MatchAndInstantiate(minfo, argTypes, genericArgTypes); + if (minfo1 != null) + { + methods.Add(minfo1); + GetAllOverrides(subtype, minfo1, typeMap, methods); + } + } + } + } + } + } + + private static MethodInfo MatchAndInstantiate(MethodInfo minfo, Type[] argTypes, Type[] genericArgTypes) + { + ParameterInfo[] mParams = minfo.GetParameters(); + if (mParams.Length != argTypes.Length) + { + return null; + } + if (minfo.IsGenericMethod) + { + Type[] genericArgs = minfo.GetGenericArguments(); + if (genericArgs.Length != genericArgTypes.Length) + { + return null; + } + if (!minfo.IsGenericMethodDefinition) + { + bool hasGenericParameter = false; + foreach (Type arg in genericArgs) + { + if (arg.IsGenericParameter) + { + hasGenericParameter = true; + break; + } + } + if (hasGenericParameter) + { + minfo = minfo.GetGenericMethodDefinition(); + } + } + + // kluggy, need to rewrite. 
+ try + { + minfo = minfo.MakeGenericMethod(genericArgTypes); + } + catch (ArgumentException) + { + return null; + } + mParams = minfo.GetParameters(); + } + for (int i = 0; i < argTypes.Length; i++) + { + Type paramType = mParams[i].ParameterType; + if (paramType.IsGenericParameter) + { + if (!argTypes[i].IsGenericParameter) + { + return null; + } + } + else if (paramType != argTypes[i]) + { + return null; + } + } + return minfo; + } + + internal static Type FindConstrainedType(Type type) + { + if (type.IsInterface) return null; + if (type.IsGenericParameter) + { + Type[] constraints = type.GetGenericParameterConstraints(); + foreach (Type ctype in constraints) + { + Type constrainedType = FindConstrainedType(ctype); + if (constrainedType != null) return constrainedType; + } + } + return type; + } + + // A hack to detect anonymous types + internal static bool IsAnonymousType(Type type) + { + return (Attribute.IsDefined(type, typeof(CompilerGeneratedAttribute), false) && + type.IsGenericType && type.Name.Contains("AnonymousType") && + (type.Name.StartsWith("<>", StringComparison.Ordinal) || type.Name.StartsWith("VB$", StringComparison.Ordinal)) && + (type.Attributes & TypeAttributes.NotPublic) == TypeAttributes.NotPublic); + } + + // A hack to detect transparent identifiers + internal static bool IsTransparentIdentifier(string name) + { + return name.StartsWith("<>h__TransparentIdentifier", StringComparison.Ordinal); + } + + // A hack to detect backing field names + internal static bool IsBackingField(string fieldName) + { + return fieldName.StartsWith("<", StringComparison.Ordinal) && fieldName.Contains("BackingField"); + } + + internal static bool ContainsAnonymousType(IEnumerable types) + { + foreach (Type type in types) + { + if (IsAnonymousType(type)) return true; + if (type.IsGenericType && + ContainsAnonymousType(type.GetGenericArguments())) + { + return true; + } + } + return false; + } + + internal static bool IsTypeOrAnyGenericParamsAnonymous(Type type) + 
{ + if (IsAnonymousType(type)) + { + return true; + } + if (type.IsGenericType) + { + foreach (Type typeArg in type.GetGenericArguments()) + { + if (IsTypeOrAnyGenericParamsAnonymous(typeArg)) return true; + } + } + return false; + } + + internal static string FieldName(string fieldName) + { + if (!IsBackingField(fieldName)) + { + return fieldName; + } + int idx = fieldName.IndexOf(">", StringComparison.Ordinal); + if (idx == -1) idx = fieldName.Length; + return fieldName.Substring(1, idx - 1); + } + } +} diff --git a/LinqToDryad/VertexCodeGen.cs b/LinqToDryad/VertexCodeGen.cs new file mode 100644 index 0000000..b03ae16 --- /dev/null +++ b/LinqToDryad/VertexCodeGen.cs @@ -0,0 +1,356 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// +// � Microsoft Corporation. All rights reserved. 
+//
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Text;
+using System.IO;
+using System.Reflection;
+using System.Linq;
+using System.Linq.Expressions;
+using System.CodeDom;
+using System.Diagnostics;
+using System.Xml;
+using System.Data.Linq.Mapping;
+using System.Data.Linq;
+using Microsoft.Research.DryadLinq.Internal;
+
+namespace Microsoft.Research.DryadLinq
+{
+    // Adds vertex code for query nodes. AddVertexCode selects a Visit overload from the
+    // node's NodeType; every Visit overload is virtual and, in this base class, simply
+    // forwards to the node's own AddVertexCode with the same arguments.
+    internal class VertexCodeGen
+    {
+        // Base implementation: returns an empty string array.
+        internal virtual IEnumerable GetResources()
+        {
+            return new string[0];
+        }
+
+        // Casts the node to the concrete node class that corresponds to node.NodeType and
+        // calls the matching Visit overload, returning its result.
+        // Throws DryadLinqException for a NodeType that has no case below.
+        internal virtual string AddVertexCode(DryadQueryNode node,
+                                              CodeMemberMethod vertexMethod,
+                                              string[] readerNames,
+                                              string[] writerNames)
+        {
+            switch (node.NodeType)
+            {
+                case QueryNodeType.InputTable:
+                    return this.Visit((DryadInputNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.OutputTable:
+                    return this.Visit((DryadOutputNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.Aggregate:
+                    return this.Visit((DryadAggregateNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.Select:
+                case QueryNodeType.SelectMany:
+                    return this.Visit((DryadSelectNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.Where:
+                    return this.Visit((DryadWhereNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.Distinct:
+                    return this.Visit((DryadDistinctNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.BasicAggregate:
+                    return this.Visit((DryadBasicAggregateNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.GroupBy:
+                    return this.Visit((DryadGroupByNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.OrderBy:
+                    return this.Visit((DryadOrderByNode)node, vertexMethod, readerNames, writerNames);
+
+                // All four partitioning operators share one node class.
+                case QueryNodeType.Skip:
+                case QueryNodeType.SkipWhile:
+                case QueryNodeType.Take:
+                case QueryNodeType.TakeWhile:
+                    return this.Visit((DryadPartitionOpNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.Contains:
+                    return this.Visit((DryadContainsNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.Join:
+                case QueryNodeType.GroupJoin:
+                    return this.Visit((DryadJoinNode)node, vertexMethod, readerNames, writerNames);
+
+                // The three set operators share one node class.
+                case QueryNodeType.Union:
+                case QueryNodeType.Intersect:
+                case QueryNodeType.Except:
+                    return this.Visit((DryadSetOperationNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.Concat:
+                    return this.Visit((DryadConcatNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.Zip:
+                    return this.Visit((DryadZipNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.Super:
+                    return this.Visit((DryadSuperNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.RangePartition:
+                    return this.Visit((DryadRangePartitionNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.HashPartition:
+                    return this.Visit((DryadHashPartitionNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.Merge:
+                    return this.Visit((DryadMergeNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.Apply:
+                    return this.Visit((DryadApplyNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.Fork:
+                    return this.Visit((DryadForkNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.Tee:
+                    return this.Visit((DryadTeeNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.Dynamic:
+                    return this.Visit((DryadDynamicNode)node, vertexMethod, readerNames, writerNames);
+
+                case QueryNodeType.Dummy:
+                    return this.Visit((DryadDummyNode)node, vertexMethod, readerNames, writerNames);
+
+                default:
+                    throw new DryadLinqException("Internal error: unhandled node type " + node.NodeType);
+            }
+        }
+
+        // Per-node-class hooks. Each one forwards to node.AddVertexCode unchanged; they are
+        // virtual so a subclass can replace the handling of an individual node class.
+
+        internal virtual string Visit(DryadInputNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadOutputNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadWhereNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadSelectNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadZipNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadOrderByNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadGroupByNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadPartitionOpNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadJoinNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadDistinctNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadContainsNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadBasicAggregateNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadAggregateNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadConcatNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadSetOperationNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadMergeNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadHashPartitionNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadRangePartitionNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadSuperNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadApplyNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadForkNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadTeeNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadDynamicNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+        internal virtual string Visit(DryadDummyNode node, CodeMemberMethod vertexMethod, string[] readerNames, string[] writerNames)
+        {
+            return node.AddVertexCode(vertexMethod, readerNames, writerNames);
+        }
+
+    }
+}
diff --git a/LinqToDryad/WebHdfsClient.cs b/LinqToDryad/WebHdfsClient.cs
new file mode 100644
index 0000000..b711c3d
--- /dev/null
+++ b/LinqToDryad/WebHdfsClient.cs
@@ -0,0 +1,93 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+using System;
+using System.Collections.Generic;
+
+using System.Linq;
+using System.Net;
+using System.Text.RegularExpressions;
+
+namespace Microsoft.Research.DryadLinq
+{
+    // Minimal client for the WebHDFS REST interface. Requests are sent over HTTP to the
+    // host named in the hdfs URI that the caller passes in.
+    internal class WebHdfsClient
+    {
+        // The incoming URIs carry the IPC port (9000), not the HTTP port, so the HTTP
+        // port is fixed here. TODO: make this configurable.
+        private const int NameNodeHttpPort = 50070;
+
+        // Builds http://<host>:50070/webhdfs/v1/<path>?op=<operation> for the given URI.
+        private static Uri BuildWebHdfsUri(Uri hdfsUri, string operation)
+        {
+            var builder = new UriBuilder();
+            builder.Host = hdfsUri.DnsSafeHost;
+            builder.Port = NameNodeHttpPort;
+            builder.Path = "webhdfs/v1/" + hdfsUri.AbsolutePath.TrimStart('/');
+            builder.Query = "op=" + operation;
+            return builder.Uri;
+        }
+
+        // Downloads <hdfsDir>/<fileName> (op=OPEN) into a local file called fileName,
+        // relative to the current directory.
+        //   hdfsDir  - absolute URI of the directory; a trailing '/' is added if missing.
+        //   fileName - name of the file inside hdfsDir, also used as the local file name.
+        internal void GetHdfsFile(string hdfsDir, string fileName)
+        {
+            // Without the trailing slash the Uri(base, relative) constructor would replace
+            // the last path segment instead of appending to it.
+            if (!hdfsDir.EndsWith("/", StringComparison.Ordinal))
+            {
+                hdfsDir = hdfsDir + "/";
+            }
+            var hdfsDirUri = new Uri(hdfsDir, UriKind.Absolute);
+            var hdfsFileUri = new Uri(hdfsDirUri, fileName);
+            Uri requestUri = BuildWebHdfsUri(hdfsFileUri, "OPEN");
+            Console.WriteLine(requestUri);
+
+            // WebClient is IDisposable; release it as soon as the download finishes.
+            using (var wc = new WebClient())
+            {
+                wc.DownloadFile(requestUri, fileName);
+            }
+        }
+
+        // Queries op=GETCONTENTSUMMARY for path and stores the "length" value in estSize
+        // and the "fileCount" value in parCount.
+        // Throws DryadLinqException when either value is missing from the response.
+        internal static void GetContentSummary(string path, ref long estSize, ref int parCount)
+        {
+            // TODO: Move this to a sensible JSON parser.
+            var pathUri = new Uri(path, UriKind.Absolute);
+            Uri requestUri = BuildWebHdfsUri(pathUri, "GETCONTENTSUMMARY");
+
+            string data;
+            using (var wc = new WebClient())
+            {
+                data = wc.DownloadString(requestUri);
+            }
+
+            bool foundParCount = false;
+            bool foundEstSize = false;
+
+            // Match only the two keys of interest and capture just the digits, so that a
+            // closing brace or whitespace after the number cannot end up in the parsed text.
+            foreach (Match match in Regex.Matches(data, "\"(fileCount|length)\"\\s*:\\s*(\\d+)"))
+            {
+                string value = match.Groups[2].Value;
+                if (match.Groups[1].Value == "fileCount")
+                {
+                    parCount = int.Parse(value);
+                    foundParCount = true;
+                }
+                else
+                {
+                    estSize = long.Parse(value);
+                    foundEstSize = true;
+                }
+            }
+
+            if (!foundParCount || !foundEstSize)
+            {
+                throw new DryadLinqException("Unable to parse WebHdfs response.");
+            }
+        }
+    }
+}
diff --git a/LinqToDryad/YarnJobSubmission.cs b/LinqToDryad/YarnJobSubmission.cs
new file mode 100644
index 0000000..55af965
--- /dev/null
+++ b/LinqToDryad/YarnJobSubmission.cs
@@ -0,0 +1,516 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+//
+// © Microsoft Corporation. All rights reserved.
+//
+#if REMOVE_FOR_YARN
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.IO;
+using System.IO.Compression;
+using System.Linq;
+using System.Net;
+using System.Security.Principal;
+using System.Text;
+using System.Xml;
+using Microsoft.Hpc.Scheduler;
+using Microsoft.Hpc.Scheduler.Properties;
+using Microsoft.Hpc.Dryad;
+using Microsoft.Research.DryadLinq.Internal;
+using System.Collections.Specialized;
+
+namespace Microsoft.Research.DryadLinq
+{
+    // Job submission through DryadJobSubmission and the HPC scheduler API.
+    // Only compiled when REMOVE_FOR_YARN is defined; otherwise the YARN implementation
+    // in the #else branch below is used.
+    internal class HpcJobSubmission : IHpcLinqJobSubmission
+    {
+        private HpcLinqContext m_context;
+        private DryadJobSubmission m_job;
+        private JobStatus m_status;
+
+        internal void Initialize()
+        {
+            this.m_job.FriendlyName = m_context.Configuration.JobFriendlyName;
+
+            // if the user specified MinNodes and it is less than 2, return an error. Otherwise let job run with job template which
+            // must specify a value of 2 or higher
+            if (m_context.Configuration.JobMinNodes.HasValue && m_context.Configuration.JobMinNodes < 2)
+            {
+                throw new HpcLinqException(HpcLinqErrorCode.HpcLinqJobMinMustBe2OrMore,
+                                           SR.HpcLinqJobMinMustBe2OrMore);
+            }
+
+            this.m_job.DryadJobMinNodes = m_context.Configuration.JobMinNodes;
+            this.m_job.DryadJobMaxNodes = m_context.Configuration.JobMaxNodes;
+            this.m_job.DryadNodeGroup = m_context.Configuration.NodeGroup;
+            this.m_job.DryadUserName = m_context.Configuration.JobUsername;
+            this.m_job.DryadPassword = m_context.Configuration.JobPassword;
+            this.m_job.DryadRuntime = m_context.Configuration.JobRuntimeLimit;
+            this.m_job.EnableSpeculativeDuplication = m_context.Configuration.EnableSpeculativeDuplication;
+            this.m_job.RuntimeTraceLevel = (int)m_context.Configuration.RuntimeTraceLevel;
+            this.m_job.GraphManagerNode = m_context.Configuration.GraphManagerNode;
+
+            System.Collections.Specialized.NameValueCollection collection = new System.Collections.Specialized.NameValueCollection();
+
+            foreach (var keyValuePair in m_context.Configuration.JobEnvironmentVariables)
+            {
+                collection.Add(keyValuePair.Key, keyValuePair.Value);
+            }
+
+            this.m_job.JobEnvironmentVariables = collection;
+        }
+
+        internal bool LocalJM
+        {
+            get
+            {
+                return m_job.Type == DryadJobSubmission.JobType.Local;
+            }
+            set
+            {
+                if (value == true)
+                {
+                    m_job.Type = DryadJobSubmission.JobType.Local;
+                }
+                else
+                {
+                    m_job.Type = DryadJobSubmission.JobType.Cluster;
+                }
+            }
+        }
+
+        internal string CommandLine
+        {
+            get
+            {
+                return m_job.CommandLine;
+            }
+            set
+            {
+                m_job.CommandLine = value;
+            }
+        }
+
+        public string ErrorMsg
+        {
+            get
+            {
+                return m_job.ErrorMessage;
+            }
+            private set
+            {
+                m_job.ErrorMessage = value;
+            }
+        }
+
+        internal HpcJobSubmission(HpcLinqContext context)
+        {
+            this.m_context = context;
+            this.m_status = JobStatus.NotSubmitted;
+
+            //@@TODO[P0] pass the runtime to the DryadJobSubmission so that it can use the scheduler instance.
+            //@@TODO: Merge DryadJobSubmission into Ms.Hpc.Linq. Until then make sure Context is not disposed before DryadJobSubmission.
+            this.m_job = new DryadJobSubmission(m_context.GetIScheduler());
+        }
+
+        public void AddJobOption(string fieldName, string fieldVal)
+        {
+            if (fieldName == "cmdline")
+            {
+                m_job.CommandLine = fieldVal;
+            }
+            else
+            {
+                throw new HpcLinqException(HpcLinqErrorCode.JobOptionNotImplemented,
+                                           String.Format(SR.JobOptionNotImplemented, fieldName, fieldVal));
+            }
+        }
+
+        public void AddLocalFile(string fileName)
+        {
+            m_job.AddFileToJob(fileName);
+        }
+
+        public void AddRemoteFile(string fileName)
+        {
+            // NOTE(review): the message is built but never thrown or logged, so this call is
+            // a silent no-op. Confirm whether an exception was intended here.
+            string msg = String.Format("HpcJobSubmission.AddRemoteFile({0}) not implemented", fileName);
+        }
+
+        public JobStatus GetStatus()
+        {
+            if (this.m_status == JobStatus.Success ||
+                this.m_status == JobStatus.Failure )
+            {
+                return this.m_status;
+            }
+
+            if (this.m_job == null)
+            {
+                return JobStatus.NotSubmitted;
+            }
+
+            switch (this.m_job.State)
+            {
+                case JobState.ExternalValidation:
+                case JobState.Queued:
+                case JobState.Submitted:
+                case JobState.Validating:
+                    {
+                        this.m_status = JobStatus.Waiting;
+                        break;
+                    }
+                case JobState.Configuring:
+                case JobState.Running:
+                case JobState.Canceling:
+                case JobState.Finishing:
+                    {
+                        this.m_status = JobStatus.Running;
+                        break;
+                    }
+                case JobState.Failed:
+                    // a job only fails if the job manager fails.
+                    {
+                        ISchedulerCollection tasks = this.m_job.Job.GetTaskList(null, null, false);
+                        if (tasks.Count < 1)
+                        {
+                            this.ErrorMsg = this.m_job.ErrorMessage;
+                            this.m_status = JobStatus.Failure;
+                        }
+                        else
+                        {
+                            ISchedulerTask jm = tasks[0] as ISchedulerTask;
+                            switch (jm.State)
+                            {
+                                case TaskState.Finished:
+                                    this.m_status = JobStatus.Success;
+                                    break;
+                                default:
+                                    this.m_status = JobStatus.Failure;
+                                    this.ErrorMsg = "JM error: " + jm.ErrorMessage;
+                                    break;
+                            }
+                        }
+                        break;
+                    }
+                case JobState.Canceled:
+                    {
+                        this.ErrorMsg = this.m_job.ErrorMessage;
+                        this.m_status = JobStatus.Failure;
+                        break;
+                    }
+                case JobState.Finished:
+                    {
+                        this.m_status = JobStatus.Success;
+                        break;
+                    }
+            }
+
+            return this.m_status;
+        }
+
+        public void SubmitJob()
+        {
+            // Verify that the head node is set
+            if (m_context.Configuration.HeadNode == null)
+            {
+                throw new HpcLinqException(HpcLinqErrorCode.ClusterNameMustBeSpecified,
+                                           SR.ClusterNameMustBeSpecified);
+            }
+
+            try
+            {
+                this.m_job.SubmitJob();
+            }
+            catch (Exception e)
+            {
+                throw new HpcLinqException(HpcLinqErrorCode.SubmissionFailure,
+                                           String.Format(SR.SubmissionFailure, m_context.Configuration.HeadNode), e);
+            }
+        }
+
+        public JobStatus TerminateJob()
+        {
+            JobStatus status = GetStatus();
+            switch (status)
+            {
+                case JobStatus.Failure:
+                case JobStatus.NotSubmitted:
+                case JobStatus.Success:
+                case JobStatus.Cancelled:
+                    // Nothing to do.
+                    return status;
+                default:
+                    break;
+            }
+
+            this.m_job.CancelJob();
+            return JobStatus.Cancelled;
+        }
+
+        public int GetJobId()
+        {
+            if (m_job == null || m_job.Job == null)
+            {
+                throw new InvalidOperationException("(internal) GetDryadJobInfo called when no job is available");
+            }
+            return m_job.Job.Id;
+        }
+    }
+}
+#else
+namespace Microsoft.Research.DryadLinq
+{
+    using System;
+    using System.Diagnostics;
+    using System.IO;
+    using System.Net;
+    using System.Xml.Linq;
+    using System.Linq;
+    using System.Text;
+
+    // Submits a query plan by launching "yarn.cmd jar DryadYarnBridge.jar <queryPlan>" and
+    // then tracks the resulting application through the cluster's REST service.
+    internal class YarnJobSubmission : IHpcLinqJobSubmission
+    {
+        private HpcLinqContext m_context;
+        private JobStatus m_status;
+        private string m_applicationId;
+        private WebClient m_wc;
+        private string m_cmdLine;
+        private string m_queryPlan;
+        private string m_errorMsg;
+
+
+        public YarnJobSubmission(HpcLinqContext context)
+        {
+            m_context = context;
+            m_status = JobStatus.NotSubmitted;
+            m_wc = new WebClient();
+        }
+
+        // Only the "cmdline" option is understood; its last whitespace-separated token is
+        // taken to be the query plan file. Other options are ignored.
+        public void AddJobOption(string fieldName, string fieldVal)
+        {
+            if(fieldName == "cmdline")
+            {
+                m_cmdLine = fieldVal;
+                var fields = m_cmdLine.Split();
+                m_queryPlan = fields[fields.Length - 1].Trim();
+                Console.WriteLine("QueryPlan: {0}", m_queryPlan);
+            }
+        }
+
+        public void AddLocalFile(string fileName)
+        {
+            // do nothing for now
+        }
+
+        public void AddRemoteFile(string fileName)
+        {
+            throw new System.NotImplementedException();
+        }
+
+        public string ErrorMsg
+        {
+            get
+            {
+                if (!String.IsNullOrEmpty(m_errorMsg))
+                {
+                    return m_errorMsg;
+                }
+                else
+                {
+                    return "Unknown error running YARN query.";
+                }
+            }
+        }
+
+        // Returns the cached status, refreshing it from the REST service first while the
+        // application is still waiting or running.
+        public JobStatus GetStatus()
+        {
+            if (m_status == JobStatus.Waiting || m_status == JobStatus.Running)
+            {
+                // Assign rather than Add so that polling repeatedly can never accumulate
+                // duplicate Accept values on the shared WebClient.
+                m_wc.Headers[HttpRequestHeader.Accept] = "application/xml";
+                var xmlData = m_wc.DownloadString(GetRestServiceUri());
+                ProcessXmlData(xmlData);
+            }
+            return m_status;
+        }
+
+        // Builds the value for JNI_CLASSPATH: every *.jar in the Dryad home directory,
+        // followed by the entries of the "classpath" environment variable with each
+        // trailing-'*' entry expanded to the jar files in that directory.
+        private string BuildExpandedClasspath(HpcLinqConfiguration config)
+        {
+            // The variable may be unset; treat that as an empty classpath instead of
+            // failing with a NullReferenceException.
+            String classPathString = System.Environment.GetEnvironmentVariable("classpath") ?? String.Empty;
+            //Console.WriteLine(classPathString);
+
+            var fields = classPathString.Split(new char[] { ';' }, StringSplitOptions.RemoveEmptyEntries);
+
+            StringBuilder sb = new StringBuilder(16384);
+
+            var jarFiles = Directory.GetFiles(config.DryadHomeDirectory, "*.jar", System.IO.SearchOption.TopDirectoryOnly);
+            foreach (String file in jarFiles)
+            {
+                //Console.WriteLine("\t{0}", file);
+                sb.Append(file);
+                sb.Append(";");
+            }
+
+            foreach (String field in fields)
+            {
+                //Console.WriteLine(field);
+                if (!field.EndsWith("*"))
+                {
+                    sb.Append(field);
+                    sb.Append(";");
+                }
+                else
+                {
+                    var dirField = field.Substring(0, field.Length - 1); // trim the trailing *
+                    // Skip wildcard entries whose directory is missing rather than letting
+                    // Directory.GetFiles throw for a stale classpath entry.
+                    if (dirField.Length == 0 || !Directory.Exists(dirField))
+                    {
+                        continue;
+                    }
+                    jarFiles = Directory.GetFiles(dirField, "*.jar", System.IO.SearchOption.TopDirectoryOnly);
+                    foreach (String file in jarFiles)
+                    {
+                        //Console.WriteLine(file);
+                        sb.Append(file);
+                        sb.Append(";");
+                    }
+                }
+            }
+            return sb.ToString();
+        }
+
+        // Launches the java submission process and records the application id it prints.
+        // Throws DryadLinqException when the process prints no application id.
+        public void SubmitJob()
+        {
+            // find the xml file, then invoke the java submission process
+            ProcessStartInfo psi = new ProcessStartInfo();
+
+            psi.FileName = Path.Combine(m_context.Configuration.YarnHomeDirectory, "bin", "yarn.cmd");
+            string jarPath = Path.Combine(m_context.Configuration.DryadHomeDirectory, "DryadYarnBridge.jar");
+            psi.Arguments = string.Format(@"jar {0} {1}", jarPath, m_queryPlan);
+            // Indexer assignment overwrites an inherited JNI_CLASSPATH; Add() would throw.
+            psi.EnvironmentVariables["JNI_CLASSPATH"] = BuildExpandedClasspath(m_context.Configuration);
+
+            if (!psi.EnvironmentVariables.ContainsKey("DRYAD_HOME"))
+            {
+                //Console.WriteLine("Adding DRYAD_HOME env variable");
+                psi.EnvironmentVariables.Add("DRYAD_HOME", m_context.Configuration.DryadHomeDirectory);
+            }
+
+            psi.UseShellExecute = false;
+            psi.RedirectStandardOutput = true;
+
+            string lastLine = null;
+            using (var process = Process.Start(psi))
+            {
+                // Drain stdout BEFORE waiting for exit: once the redirected pipe fills up the
+                // child blocks on write, and waiting first would then never return.
+                // the java submission process will return the application id as the last line
+                string line;
+                while ((line = process.StandardOutput.ReadLine()) != null)
+                {
+                    line = line.Trim();
+                    if (line.Length > 0)
+                    {
+                        lastLine = line;
+                    }
+                }
+                process.WaitForExit();
+            }
+
+            if (lastLine == null)
+            {
+                m_errorMsg = "The YARN submission process did not return an application id.";
+                throw new DryadLinqException(m_errorMsg);
+            }
+
+            m_applicationId = lastLine;
+            Console.WriteLine("Application ID: {0}", m_applicationId);
+            m_status = JobStatus.Waiting;
+        }
+
+        public JobStatus TerminateJob()
+        {
+            throw new System.NotImplementedException();
+        }
+
+        // Returns the number after the last '_' of the application id, or -1 when no
+        // application id is available.
+        public int GetJobId()
+        {
+            if (m_status == JobStatus.NotSubmitted || String.IsNullOrEmpty(m_applicationId))
+            {
+                return -1;
+            }
+            var appNumberString = m_applicationId.Substring(m_applicationId.LastIndexOf('_') + 1);
+            return int.Parse(appNumberString);
+        }
+
+        internal void Initialize()
+        {
+            // nothing needed for now
+        }
+
+        // URI of the application resource: http://<HeadNode>:<port>/ws/v1/cluster/apps/<id>.
+        internal Uri GetRestServiceUri()
+        {
+            UriBuilder builder = new UriBuilder();
+            builder.Host = m_context.Configuration.HeadNode;
+            // NOTE(review): the port comes from a setting named HdfsNameNodeHttpPort while the
+            // path is the cluster apps service. Confirm this is the intended port.
+            builder.Port = m_context.Configuration.HdfsNameNodeHttpPort;
+            builder.Path = "/ws/v1/cluster/apps/" + m_applicationId;
+            return builder.Uri;
+        }
+
+        // Updates m_status (and m_errorMsg on failure) from the XML application report.
+        // Throws DryadLinqException for a state or finalStatus value that is not handled.
+        private void ProcessXmlData(string xmlData)
+        {
+            // for now, just pull state and finalStatus out of xml response
+
+            //State: The application state according to the ResourceManager -
+            //valid values are: NEW, SUBMITTED, ACCEPTED, RUNNING, FINISHED, FAILED, KILLED
+            //
+            //finalStatus: The final status of the application if finished -
+            //reported by the application itself - valid values are: UNDEFINED, SUCCEEDED, FAILED, KILLED
+
+            XDocument xdoc = XDocument.Parse(xmlData);
+            var stateString = xdoc.Descendants("state").Single().Value;
+
+            switch (stateString)
+            {
+                // This method only runs after SubmitJob, so an application the cluster still
+                // reports as NEW is waiting. Mapping it to NotSubmitted would make GetStatus
+                // stop polling for good.
+                case "NEW":
+                case "NEW_SAVING":
+                case "SUBMITTED":
+                case "ACCEPTED":
+                    m_status = JobStatus.Waiting;
+                    break;
+                case "RUNNING":
+                    m_status = JobStatus.Running;
+                    break;
+                case "FINISHED":
+                    var finalStatusString = xdoc.Descendants("finalStatus").Single().Value;
+                    switch (finalStatusString)
+                    {
+                        // NOTE(review): UNDEFINED is reported as success here. Confirm intended.
+                        case "UNDEFINED":
+                            m_status = JobStatus.Success;
+                            break;
+                        case "SUCCEEDED":
+                            m_status = JobStatus.Success;
+                            break;
+                        case "FAILED":
+                            m_status = JobStatus.Failure;
+                            break;
+                        case "KILLED":
+                            m_status = JobStatus.Cancelled;
+                            break;
+                        default:
+                            throw new DryadLinqException("Unexpected finalStatus from Resource Manager");
+                    }
+                    break;
+                case "FAILED":
+                    m_status = JobStatus.Failure;
+                    if (String.IsNullOrEmpty(m_errorMsg))
+                    {
+                        m_errorMsg = xdoc.Descendants("diagnostics").Single().Value.Trim();
+                    }
+                    break;
+                case "KILLED":
+                    m_status = JobStatus.Cancelled;
+                    break;
+                default:
+                    throw new DryadLinqException("Unexpected status from Resource Manager");
+            }
+        }
+
+    }
+}
+#endif
diff --git a/LinqToDryad/YarnScheduler.cs b/LinqToDryad/YarnScheduler.cs
new file mode 100644
index 0000000..cbadc4e
--- /dev/null
+++ b/LinqToDryad/YarnScheduler.cs
@@ -0,0 +1,48 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Microsoft.Research.DryadLinq
+{
+    internal class YarnScheduler : IScheduler
+    {
+        public string m_headNode;
+
+        public void Connect(string headNode)
+        {
+            m_headNode = headNode;
+        }
+
+        public void Dispose()
+        {
+            m_headNode = null;
+        }
+
+        public IServerVersion GetServerVersion()
+        {
+            throw new NotImplementedException();
+        }
+
+    }
+}
diff --git a/LinqToDryad/sr.txt b/LinqToDryad/sr.txt
new file mode 100644
index 0000000..1cae8c4
--- /dev/null
+++ b/LinqToDryad/sr.txt
@@ -0,0 +1,253 @@
+DistinctAttributeComparerNotDefined=DistinctAttribute: Comparer {0} is not defined.
+SerializerTypeMustBeNonNull=CustomHpcSerializer attribute for type {0} does not define a SerializerType.
+SerializerTypeMustSupportIHpcSerializer=Type {0} referred to by CustomHpcSerializer attribute must implement IHpcSerializer({1}). + +ContextDisposed=The HpcLinqContext has been disposed. +FileSetCouldNotBeOpened=Could not open FileSet and/or read its properties. +FileSetMustBeSealed=FileSet must be sealed before calling FromDsc(). +FileSetMustHaveAtLeastOneFile=FileSet must have at least one file. +CouldNotGetClientVersion=Could not get version information for HpcLinq client component. See inner exception. +CouldNotGetServerVersion=Could not get version information for HpcLinq server component. See inner exception. + +Internal_CannotBeUsedForValueType=Internal error: Can't be used for value type. +TypeDoesNotContainRequestedField=The type {0} does not contain a field named {1} +CannotBeUsedForReferenceType=Internal error: Can't be used for reference type. + +IllFormedUri=The URI is not well formed: {0} + +PrefixAlreadyUsedForOtherProvider=Prefix {0} has already been used for another provider. +UnknownProvier=Provider {0} is unknown. Register it using DataProvider.Register. +IllFormedUriArguments=The arguments in the table URI are not well formed. +CannotCallPartitionInfoOnType=Internal error: Can't call on PartitionInfo of type {0} +TargetMustBeDscUri=Only DSC URIs are supported. + +CannotCreatePartitionNodeRandom=Internal error: Can't create partition node based on random partition. +PartitionKeysNotProvided=The partition keys are not provided. +PartitionKeysAreNotConsistentlyOrdered=The provided range-partition keys are not consistently ascending or descending. +IsDescendingIsInconsistent=The arguments 'partitionKeys' and 'isDescending' are inconsistent. + +EndOfStreamEncountered=Read failure: End of stream encountered while reading {0}. Data may be corrupt or does not match the file set compression scheme. + +FingerprintDisabled=Fingerprint was disabled. +RecordSizeMax2GB=The maximum record size is 2GB. +SettingPositionNotSupported=Position_Set is not supported. 
+ReadNotAllowed=Read is not supported. +SeekNotSupported=Seek is not supported. +SetLengthNotSupported=SetLength is not supported. + +ArrayLengthVsCountAndOffset=Length of array {0} must be greater than or equal to {1} ({2} + {3}). + +TypeRequiredToBePublic=Cannot auto-serialize a type that is not public. Type = {0}. +TypeMustHaveDataMembers=Auto-serialized types must have at least one data member. Type={0}. +CustomSerializerMustSupportDefaultCtor=Custom serializer type must have a public default constructor. Type = {0}. +CustomSerializerMustBeClassOrStruct=Custom serializer type must either be a class or a struct that implements IHpcSerializer({1}). Type = {0} +TypeNotSerializable=Cannot auto-serialize a type containing pointers. Type = {0}. +CannotHandleSubtypes=Auto-serialization is not supported for type that has subtypes or derives from a non-primitive type. Type = {0}. +CannotHandleCircularTypes=Auto-serialization is not supported for types containing circular references. Type = {0}. +CannotHandleDerivedtypes=Auto-serialization is not supported for type {0} because it derives from the non-primitive type {1}. Consider using a custom serializer for {0}. Please note this auto-serialization rule may have exceptions for built-in types. Please see product documentation for details. +UDTMustBeConcreteType=Auto-serialization is not supported for type that is an abstract type or has generic arguments which are abstract types. Type = {0}. +UDTHasFieldOfNonPublicType=Auto-serialization is not supported for type with private field. Type = {0}. +CannotHandleObjectFields=Auto-serialization is not supported for types containing fields of System.Object, System.Object[] or other collections of System.Object. Type = {0}. +UDTIsDelegateType=Auto-serialization is not supported for delegate type. Type = {0}. +AddVertexNotHandled=Internal error: AddVertexMethod on {0} not handled. 
+CannotBeEmpty=Internal error: cannot be empty +MustSpecifyOutputAssemblyFileName=Internal error: Must specify the file name for the output assembly. +FailedToBuild=Failed to build {0}. See the client side log ({1}) for compilation error messages. +AutogeneratedAssemblyMissing=Internal error: The auto-generated LINQ to HPC vertex assembly was missing. + +KeyNotFound=Key not found in collection. +TooManyItems=Too many items in collection. +TooManyElementsBeforeReduction=Internal error: Too many elements before reduction. + +CcpHomeMustBeSpecified=LINQ to HPC requires the CCP_HOME environment variable to be set to the Windows HPC Server product installation folder. +ClusterNameMustBeSpecified=The ClusterName configuration must be set to the name of a Windows HPC Server head node. +ConfigReadonly=The configuration object is read-only. +HpcLinqStringDictionaryReadonly=The collection is read-only. + +TypeDoesNotContainMember=Internal error: {0} doesn't contain field/property {1} +UnrecognizedOperatorName=Unrecognized operator name: {0}. +BugInHandlingAnonymousClass=Internal error processing anonymous type in query. +UnsupportedExpressionsType=Expression of type {0} is not supported. +UnnamedParameterExpression=Internal error: This parameter expression didn't have a name. +UnsupportedExpressionType=Expression of type {0} is not supported for expression-summarization. + +SourceNotOrdered=Source is not ordered. +PartitionCountMustBePositive=The partition count must be greater than 0. +WindowSizeMustyBeGTOne=The window size must be greater than 1. +PartitionTooSmallForSlidingWindow=Each partition needs at least {0} records for this sliding window computation. + +CannotAccesFilePath=Failed to access path {0} +GetFileSizeError=Error returned from GetFileSizeEx: {0}. +ReadFileError=Native channel error while reading from file. Win32 error code = {0}. +UnknownCompressionScheme=Unknown compression scheme. +WriteFileError=Error returned from WriteFile: {0}. 
+ +IndexTooSmall=Index overflowed range of Int32. +MultiQueryableKeyOutOfRange=Key out of range. +IndexOutOfRange=Index out of range. +NotAHpcLinqQuery=sources[{0}] is not a LINQ to HPC query. A LINQ to HPC IQueryable should be created via an HpcLinqContext object and use only LINQ to HPC operators. +AtLeastOneOperatorRequired=A LINQ to HPC query that is submitted should involve at least one operator. +ToDscUsedIncorrectly=Queries ending in ToDsc() cannot be followed by operators other than Submit(). +ToHdfsUsedIncorrectly=Queries ending in ToHdfs() cannot be followed by operators other than Submit(). + +UnsupportedSchedulerType=Scheduler type not supported: {0}. +UnsupportedExecutionKind=Unexpected execution kind. +UnexpectedJobStatus=Unexpected job status: {0}. +JobStatusQueryError=Repeated server error when querying job status. +JobOptionNotImplemented=JobOption({0}, {1}) not implemented. +HpcLinqJobMinMustBe2OrMore=JobMinNodes must be greater than 1. +SubmissionFailure=Error submitting job to {0}. Refer to inner exception for more detail. +DidNotCompleteSuccessfully=The job did not complete successfully. Refer to HPC Cluster Manager and/or HPC APIs for more detail. +Binaries32BitNotSupported=Job cannot be submitted because either the client program or one of its dependencies is targeted for 32 bit execution ({0}). To correct this problem make sure your binaries are compiled as 'x64' or 'AnyCPU'. + +ErrorReadingMetadata=Error reading metadata. +ErrorWritingMetadata=Error writing metadata. + +CannotSerializeHpcLinqQuery=HpcLinqQuery IQueryable objects cannot be added to object store. +CannotSerializeObject=Cannot serialize object store due to non-serializable object. Type = {0}. +GeneralSerializeFailure=Error serializing object store. See inner exception. +FailedToDeserialize=Failed to deserialize object from object store. + +ExpressionMustBeMethodCall=The expression must be a method call: [{0}]. 
+MustStartFromContext=The query must be created from a HpcLinqContext object and only use LINQ to HPC operators. +UntypedProviderMethodsNotSupported=The non-generic methods CreateQuery() and Execute() are not supported. Use CreateQuery<T>() and Execute<T>() instead. +SequenceEqualNotSupported=SequenceEqual() is not supported. +AlreadySubmitted=This query instance has already been submitted. To submit a query more than once, create a new IQueryable<> instance. +AlreadySubmittedInMaterialize=A query instance has already been submitted. To submit a query more than once, create a new IQueryable<> instance. +SameQuerySubmittedMultipleTimesInMaterialize=A query instance appears more than once in 'sources'. A query instance should be submitted at most once. To submit a query more than once, create a new IQueryable<> instance. + +PositionNotSupported=Position is not supported. +WriteNotSupported=Write is not supported. +WriteByteNotSupported=WriteByte is not supported. + +NegativeLengthInMemcopy=Internal error: Negative length in memcopy. + +SourceOfMergesortMustBeMultiEnumerable=The source for mergesort must be of type IMultiEnumerable. +ThenByNotSupported=ThenBy is not supported. +WrongFlagCombination=Internal error: Wrong combination of flags. +AggregateNoElements=Aggregate: No elements. +FirstNoElementsFirst=First: No elements. +SingleMoreThanOneElement=Single: More than one element. +SingleNoElements=Single: No elements. +LastNoElements=Last: No elements. +MinNoElements=Min: No elements +MaxNoElements=Max: No elements +AverageNoElements=Average: No elements +SourceMustBeDryadVertexReader=Internal error: source must be DryadVertexReader. +RangePartitionKeysMissing=RangePartition: partition keys are missing. +PartitionFuncReturnValueExceedsNumPorts=The return value of partitionFunc exceeded the number of ports. +NumberOfKeysMustEqualNumOutputPorts=Fork: The number of keys must match the number of output ports. +BranchOfForkNotUsed=The branch {0} of Fork is not used. 
+NullSelector=Internal error: The result and element selectors must be non-null +CannotResetIEnumerator=Internal error: Cannot reset this IEnumerator. +FailureInExcept=Failure during Except. +FailureInIntersect=Failure during Intersect. +FailureInSort=Failure during sorting. +SortedChunkCannotBeEmpty=Internal error: Sorted chunk cannot be empty. +RangePartitionInputOutputMismatch=RangePartition: partition keys and output channels mismatch. There were {0} keys and {1} channels. +FailureInHashGroupBy=Failure in hash based GroupBy. +FailureInSortGroupBy=Failure in sort based GroupBy. +FailureInHashJoin=Failure in hash based Join. +FailureInHashGroupJoin=Failure in hash based GroupJoin. +FailureInDistinct=Failure in Distinct. +FailureInOperator=Failure in {0}. +FailureInOrderedGroupBy=Failure in ordered GroupBy. +FailureInUserApplyFunction=Apply: Failure in user-defined function. +VertexBridgeBadArgs=VertexBridge arguments are malformed. argsString={0} + +UnknownChannelType=Unknown channel kind: {0}. +CannotReadQueryPlan=Cannot read query plan for job: {0}. +UnknownConnectionType=Unknown connection type: {0}. +UnknownChannelType2=Unknown channel of type: {0}. +UnknownMethodInExpression=Unknown method in expression to summarize: {0}. + +InputMustBeHpcLinqSource=The input expression must be a LINQ to HPC source. +OutputTypeCannotBeAnonymous=Output data type cannot be an anonymous type. +InputTypeCannotBeAnonymous=Input data type cannot be an anonymous type. + +DecomposerTypeMustBePublic=Decomposition class must be public. Class={0}. +DecomposerTypeDoesNotImplementInterface=Decomposition class must implement IDecomposable<,,> or IDecomposableRecursive<,,>. Class={0}. +DecomposerTypesDoNotMatch=Decomposition class types must match the function that it decorates. Method={0}. Class={1}. +DecomposerTypeImplementsTooManyInterfaces=Decomposition class should implement only one decomposable interface. Class={0}. 
+DecomposerTypeDoesNotHavePublicDefaultCtor=Decomposition class must have a public parameterless constructor. Class={0}. + +AssociativeMethodHasWrongForm=A method tagged [Associative] should take two parameters of type T and return type T. Method={0}. + +AssociativeTypeMustBePublic=Associative class must be public. Class={0}. +AssociativeTypeDoesNotImplementInterface=Associative class must implement IDecomposable<,,> or IDecomposableRecursive<,,>. Class={0}. +AssociativeTypesDoNotMatch=Associative class types must match the function that it decorates. Method={0}. Class={1}. +AssociativeTypeImplementsTooManyInterfaces=Associative class should implement only one IAssociativeRecursive interface. Class={0}. +AssociativeTypeDoesNotHavePublicDefaultCtor=Associative class must have a public parameterless constructor. Class={0}. + + +CannotRebuildOptimizedQueryExpression=Internal error: cannot rebuild optimized query expression for node. +InputArityMustEqualChildren=Internal error: InputArity must equal Children.Length. +DistinctOnlyTakesTwoArgs=Internal Error: The Distinct operator can only take at most 2 arguments. + +ComparerMustBeSpecifiedOrKeyTypeMustBeIComparable=If a key-comparer is not provided, TKey must implement IComparable. TKey={0}. +ComparerMustBeSpecifiedOrKeyTypeMustBeIEquatable=If a key-comparer is not provided, TKey must override GetHashCode() and either implement IEquatable or override Equals(). TKey={0}. +ComparerExpressionMustBeSpecifiedOrElementTypeMustBeIEquatable=If a comparer expression is not provided, TElement must implement IEquatable or override both Equals() and GetHashCode(). TElement={0}. + +TooManyHomomorphicAttributes=A method should not be tagged with both HomomorphicAttribute and LeftHomomorphicAttribute. +HomomorphicApplyNeedsSamePartitionCount=All inputs to homomorphic apply must have same number of input partitions +UnrecognizedDataSource=Unrecognized data source: {0}. 
+OperatorNotSupported=The operator '{0}' encountered in expression isn't a valid LINQ to HPC operator. +AggregationOperatorRequiresIComparable=Aggregation operator '{0}' can only work on objects that implement IComparable. +BadSeparatorCount=The number of separators passed to AssumeRangePartition must equal src.nPartitions - 1. nRangeSeparators={0}, Expected={1}. +MultipleOutputsWithSameDscUri=Multiple query outputs are targeted to the same DSC fileset name {0} +OutputUriAlsoQueryInput=The output DSC fileset name {0} is also used as the query source or as one of the referenced data sources. + + +CannotConcatDatasetsWithDifferentCompression=Can't concat two datasets with different compression schemes. +CannotCreateTablesWithDifferentCompression=Can't create multiple tables with different compression schemes. +FailedToRemoveMergeNode=Internal error: Failed to remove the Merge node. +CannotAttach=Internal error: cannot attach a pipeline +CannotAddTeeToNode=Internal error: Can't add Tee to a node with more than one outputs. +ShouldNotCreateCodeForDummyNode=Internal error: should not create vertex code for a dummy node. +ShouldNotCreateCodeForInput=Internal error: should not create vertex code for input. +ShouldNotCreateCodeForOutput=Internal error: should not create vertex code for output. +ShouldNotCreateCodeForConcat=Internal error: should not create vertex code for Concat. +ShouldNotCreateCodeForTee=Internal error: should not create vertex code for Tee. +IllegalDynamicManagerType=Internal error: Illegal type of dynamic manager in dynamic node +AggregateOperatorNotSupported=The aggregate operator {0} is not supported. +DynamicManagerType=Internal error: Cannot have dynamic manager of type {0}. + +ChannelCannotBeReadMoreThanOnce=A LINQ to HPC channel cannot be read more than once. For example, the delegate in Apply() may only enumerate its input once. 
+ShouldNototCallReset=Internal error: Should never call Reset() + +CannotHaveMoreThanOneOutput=Internal error: Can't have more than one output channel. + +DSCStreamError=DSC fileset error: {0}. +StreamDoesNotExist=The DSC fileset {0} doesn't exist. +StreamAlreadyExists=DSC fileset already exists: {0}. +OpenForWriteError=Internal error: OpenForWrite called when fs was not null. +AttemptToReadFromAWriteStream=Attempt to read from a stream that was opened for writing. +FailedToCreateStream=Failed to create DSC fileset: {0}. +ReadWriteNotSupported=ReadWrite access is not supported. + +ExpressionTypeNotHandled={0} cannot handle expression of type {1}. +UnhandledQuery=Internal error: LINQ to HPC query expression cannot be handled: {0}. + +MultiBlockEmptyPartitionList=The partition list of the table was empty. +MultiBlockCannotAccesFilePath= Failed to access file {0}, which is part of DSC stream {1}. If this file has other replicas, they were all attempted but could not be accessed either. + + +GetURINotSupported=GetURI() is not implemented for this stream type. +SetCalcFPNotSupported=SetCalcFP() is not implemented for this stream type. +GetFPNotSupported=GetFP() is not implemented for this stream type. +FailedToAllocateNewNativeBuffer=Failed to allocate a new native data block of size {0}. +FailedToReadFromInputChannel=Native channel failed to read from input channel at port {0}. Win32 error = {1}. +FailedToWriteToOutputChannel=Native channel failed to write to output channel at port {0}. + +FailedToGetStreamProps=Failed to get stream properties for file set {0}. +MetadataRecordType=RecordType does not match with file set metadata. Use FromDsc with matching T, or an overload that suppresses type-check. RecordType={0}, DscStream.recordType={1}. +NullKeySelector=KeySelector function must be non-null. +NumKeys=The number of keys must be equal to PartitionCount-1. +JobToCreateTableFailed=The job to create this HpcLinqQuery(T) failed with error: {0}.
+JobToCreateTableWasCanceled=The job to create this HpcLinqQuery(T) was canceled by the user. +FailedToGetReadPathsForStream=Failed to get read paths for DSC fileset {0}. + +OnlyAvailableForPhysicalData=Property only available for physical data. +CreatingDscDataFromLocalDebugFailed=Error creating DSC data from local debug mode. See inner exception for details. + +CompareArgIncorrect=Argument is not a LineRecord. diff --git a/README.txt b/README.txt new file mode 100644 index 0000000..f254492 --- /dev/null +++ b/README.txt @@ -0,0 +1,47 @@ +Dryad + +This is a research prototype of the Dryad and DryadLINQ data-parallel +processing frameworks running on Hadoop YARN. Dryad utilizes cluster +services provided as part of Hadoop YARN to reliably execute +distributed computations on a cluster of computers. DryadLINQ +leverages Dryad to reliably execute a distributed computation on a +cluster of computers. + +This is a research prototype of Dryad and DryadLINQ running on YARN, +which is still in active development. As a result, you should expect +some fragility. + +Requirements + +A version of YARN built for Windows + The BUILDING.txt file in the Hadoop YARN repository contains + instructions on building YARN for Windows. +Visual Studio 2010 or 2012 +Java Development Kit 1.6 +A Windows YARN cluster composed of x64 machines + +Building Dryad + +1) Clone the Dryad git repository. +2) Ensure that the YARN_HOME environment variable is set. +3) Set the DRYAD_HOME environment variable to the binary path + (bin\Debug or bin\Release) under the directory Dryad was cloned to. +4) Use Visual Studio to open the Dryad solution file (Dryad.sln) located + in the root of the repository and build the solution. +5) Run Build.bat in the Java directory at the top-level of the repository. + The CLASSPATH will need to be set to the output of the 'yarn classpath' + command. + +Cluster setup +1) Set up your YARN cluster as you normally would.
+2) Copy the contents of the DRYAD_HOME directory to the location set by + DRYAD_HOME on each compute node in the cluster. + +Notes + +The YARN interfaces used are current as of commit dfb83b8 in trunk. + +If you are running debug builds of Dryad, also copy the files msvcp100d.dll +and msvcr100d.dll to the DRYAD_HOME directory on each compute node. + +The HDFS implementation in Dryad currently only supports text files. \ No newline at end of file diff --git a/linqtodryadjm_managed_yarn/DryadLinqApplication.cs b/linqtodryadjm_managed_yarn/DryadLinqApplication.cs new file mode 100644 index 0000000..0218b5b --- /dev/null +++ b/linqtodryadjm_managed_yarn/DryadLinqApplication.cs @@ -0,0 +1,410 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License.
+ +*/ + +using System; +using System.Collections.Generic; +using System.IO; +using System.Reflection; +using System.Data.Linq; +using System.Data.Linq.Mapping; +using System.Xml; +using System.Xml.Serialization; +using Microsoft.Research.Dryad; + +namespace linqtodryadjm_managed +{ + public class DryadLINQApp + { + public class OptionDescription + { + public enum OptionCategory + { + OC_General = 0, + OC_Inputs, + OC_Outputs, + OC_VertexBehavior, + OC_LAST + }; + public OptionCategory m_optionCategory; + + public int m_optionIndex; + public string m_shortName; + public string m_longName; + public string m_arguments; + public string m_descriptionText; + + public OptionDescription() + { + } + + public OptionDescription(OptionCategory optionCatagory, int index, string shortName, string longName, string args, string desc) + { + m_optionCategory = optionCatagory; + m_optionIndex = index; + m_shortName = shortName; + m_longName = longName; + m_arguments = args; + m_descriptionText = desc; + } + }; + + enum DryadLINQAppOptions + { + BDAO_EmbeddedResource, + BDAO_ReferenceResource, + BDJAO_AMD64, + BDJAO_I386, + BDJAO_Retail, + BDJAO_Debug, + DNAO_MaxAggregateInputs, + DNAO_MaxAggregateFilterInputs, + DNAO_AggregateThreshold, + DNAO_NoClusterAffinity, + DNAO_LAST + }; + + static OptionDescription []s_dryadLinqOptionArray = new OptionDescription[] + { + new OptionDescription + ( + OptionDescription.OptionCategory.OC_VertexBehavior, + (int)DryadLINQAppOptions.BDJAO_AMD64, + "amd64", + "useamd64binary", + "", + "Dummy argument for legacy reasons" + ), + new OptionDescription + ( + OptionDescription.OptionCategory.OC_VertexBehavior, + (int)DryadLINQAppOptions.BDJAO_I386, + "i386", + "usei386binary", + "", + "Dummy argument for legacy reasons" + ), + new OptionDescription + ( + OptionDescription.OptionCategory.OC_VertexBehavior, + (int)DryadLINQAppOptions.BDJAO_Retail, + "retail", + "useretailbinary", + "", + "Dummy argument for legacy reasons" + ), + new OptionDescription + ( 
+ OptionDescription.OptionCategory.OC_VertexBehavior, + (int)DryadLINQAppOptions.BDJAO_Debug, + "debug", + "usedebugbinary", + "", + "Dummy argument for legacy reasons" + ), + + new OptionDescription + ( + OptionDescription.OptionCategory.OC_General, + (int)DryadLINQAppOptions.DNAO_MaxAggregateInputs, + "mai", + "maxaggregateinputs", + "", + "Only allow aggregate vertices to use up to maxInputs inputs." + ), + new OptionDescription + ( + OptionDescription.OptionCategory.OC_General, + (int)DryadLINQAppOptions.DNAO_MaxAggregateFilterInputs, + "mafi", + "maxaggregatefilterinputs", + "", + "Only allow aggregate filter vertices to use up to maxInputs inputs." + ), + new OptionDescription + ( + OptionDescription.OptionCategory.OC_General, + (int)DryadLINQAppOptions.DNAO_AggregateThreshold, + "at", + "aggregatethreshold", + "", + "Only allow aggregate and aggregate filter vertices to use inputs up to a total of size dataSize. dataSize can be specified as a number of bytes or with the (case-insensitive) suffix KB, MB, GB, TB, PB, i.e. 12KB. If a suffix is present then fractions are allowed, e.g. 20.5MB." + ), + new OptionDescription + ( + OptionDescription.OptionCategory.OC_General, + (int)DryadLINQAppOptions.DNAO_NoClusterAffinity, + "nca", + "noclusteraffinity", + "", + "By default, LINQToHPC does not process DSC filesets from a different cluster. 
Specifying this flag overrides that" + ) + }; + + SortedDictionary m_optionMap; + private DrGraph m_graph; + private bool m_clusterAffinity; + private int m_maxAggregateInputs; + private int m_maxAggregateFilterInputs; + private UInt64 m_aggregateThreshold; + private UInt64 m_startTime; + private FileStream m_identityMapFile; + + public DryadLINQApp(DrGraph graph) + { + m_graph = graph; + m_clusterAffinity = true; + m_maxAggregateInputs = 150; + m_maxAggregateFilterInputs = 32; + m_aggregateThreshold = 1024*1024*1024; // 1GB + m_startTime = graph.GetXCompute().GetCurrentTimeStamp(); + m_identityMapFile = null; + m_optionMap = new SortedDictionary(); + AddOptionsToMap(s_dryadLinqOptionArray); + } + + public void AddOptionsToMap(OptionDescription[] optionArray) + { + foreach (OptionDescription option in optionArray) + { + m_optionMap.Add(option.m_shortName, option); + m_optionMap.Add(option.m_longName, option); + } + } + + public void PrintOptionUsage(OptionDescription option) + { + Console.WriteLine(" {{-{0}|-{1}}} {2}", option.m_shortName, option.m_longName, option.m_arguments); // prints "{-short|-long} args"; literal braces must be doubled in a composite format string + Console.WriteLine(" {0}", option.m_descriptionText); + } + + static string[] s_optionCategoryName = new string[] + { + "General options", + "Options to specify job inputs", + "Options to specify job outputs", + "Options to control vertex behavior" + }; + + public void PrintCategoryUsage(OptionDescription.OptionCategory category) + { + Console.WriteLine(s_optionCategoryName[(int)category]); + foreach (KeyValuePair kvp in m_optionMap) + { + OptionDescription option = kvp.Value; + if (option.m_optionCategory == category && kvp.Key == option.m_shortName) + { + PrintOptionUsage(option); + } + } + } + + public void PrintUsage(string exeName) + { + string leafName = Path.GetFileName(exeName); + Console.WriteLine("usage: {0} [--debugbreak] [--popup] \n fall into the following categories:", leafName); + for (OptionDescription.OptionCategory i=0; i= args.Length) + { + DryadLogger.LogCritical(0, null,
"The argument for option '{0}' was missing.\n", args[index]); + retVal = false; + } + else + { + int maxInputs; + if (!Int32.TryParse(args[index+1], out maxInputs)) + { + DryadLogger.LogCritical(0, null, "The argument '{0}' for option '{1}' could not be parsed as an integer.\n", args[index + 1], args[index]); + retVal = false; + } + else + { + m_maxAggregateInputs = maxInputs; + index++; + } + } + break; + + case (int)DryadLINQAppOptions.DNAO_MaxAggregateFilterInputs: + if ((index + 1) >= args.Length) + { + DryadLogger.LogCritical(0, null, "The argument for option '{0}' was missing.\n", args[index]); + retVal = false; + } + else + { + int maxInputs; + if (!Int32.TryParse(args[index+1], out maxInputs)) + { + DryadLogger.LogCritical(0, null, "The argument '{0}' for option '{1}' could not be parsed as an integer.\n", args[index + 1], args[index]); + retVal = false; + } + else + { + m_maxAggregateFilterInputs = maxInputs; + index++; + } + } + break; + + + case (int)DryadLINQAppOptions.DNAO_AggregateThreshold: + if ((index + 1) >= args.Length) + { + DryadLogger.LogCritical(0, null, "The argument for option '{0}' was missing.\n", args[index]); + retVal = false; + } + else + { + UInt64 threshold; + if (!UInt64.TryParse(args[index+1], out threshold)) + { + DryadLogger.LogCritical(0, null, "The argument '{0}' for option '{1}' could not be parsed as a UIN64.\n", args[index + 1], args[index]); + retVal = false; + } + else + { + m_aggregateThreshold = threshold; + index++; + } + } + break; + + case (int)DryadLINQAppOptions.DNAO_NoClusterAffinity: + m_clusterAffinity = false; + break; + + default: + DryadLogger.LogCritical(0, null, "Unknown command-line option {0}\n", optionIndex); + retVal = false; + break; + } + + } + + if (!retVal) + { + PrintUsage(args[0]); + } + return retVal; + + } + + internal bool ExtractReferenceResourceNameAndUri(string resourceSpec, ref string resourceUri, ref string resourceName) + { + // reference resourceSpec is of type name@uri + string name = 
resourceSpec.Substring(0, resourceSpec.IndexOf('@')); + string uri = resourceSpec.Substring(resourceSpec.IndexOf('@') + 1); + + // name is the text before the '@' separator and uri the text after it; the separator belongs to neither part + resourceUri = uri; + resourceName = name; + + return true; + } + + public DrGraph GetGraph() + { + return m_graph; + } + + public DrUniverse GetUniverse() + { + return m_graph.GetXCompute().GetUniverse(); + } + + public void SetClusterAffinity(bool flag) + { + m_clusterAffinity = flag; + } + + public bool GetClusterAffinity() + { + return m_clusterAffinity; + } + + public int GetMaxAggregateInputs() + { + return m_maxAggregateInputs; + } + + public UInt64 GetAggregateThreshold() + { + return m_aggregateThreshold; + } + + public int GetMaxAggregateFilterInputs() + { + return m_maxAggregateFilterInputs; + } + + public UInt64 GetStartTime() + { + return m_startTime; + } + + public void SetXmlFileName(string xml) + { + if (xml.Length >= 3) + { + string dup = xml.Substring(0, xml.Length-3); + dup += "map"; + m_identityMapFile = File.Open(dup, FileMode.Create); + } + } + + public FileStream GetIdentityMapFile() + { + return m_identityMapFile; + } + + } + +} // namespace linqtodryadjm_managed + \ No newline at end of file diff --git a/linqtodryadjm_managed_yarn/GraphBuilder.cs b/linqtodryadjm_managed_yarn/GraphBuilder.cs new file mode 100644 index 0000000..4223b96 --- /dev/null +++ b/linqtodryadjm_managed_yarn/GraphBuilder.cs @@ -0,0 +1,748 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License.
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +using System; +using System.Collections.Generic; +using System.IO; +using System.Reflection; +using System.Data.Linq; +using System.Data.Linq.Mapping; +using System.Xml; +using System.Xml.Serialization; +using Microsoft.Research.Dryad; + +namespace linqtodryadjm_managed +{ + public class GraphBuilder + { + enum ChannelConnectorType + { + DCT_File, + DCT_Output, + DCT_Pipe, + DCT_Fifo, + DCT_FifoNonBlocking, + DCT_Tombstone + }; + + public class DrVertexSet : List + { + }; + + public class GraphStageInfo + { + public DrVertexSet members; + public Vertex vertex; + public DrStageManager stageManager; + + public GraphStageInfo() + { + stageManager = null; + } + + public GraphStageInfo(DrVertexSet m, Vertex v, DrStageManager mgr) + { + members = m; + vertex = v; + stageManager = mgr; + } + } + + // GraphStageMap is used to store the vertices as they are created, grouped by query plan stage. + // Once all the stages are added, the map is iterated over to connect the appropriate vertices + // and produce the full graph. 
+ + internal bool SUCCEEDED(int err) + { + if (err == 0) + { + return true; + } + return false; + } + + private DrInputStreamManager CreateInputNode(DryadLINQApp app, VertexInfo info, string inputName) + { + DrInputStreamManager s; + int err = 0; + + DryadLogger.LogMethodEntry(inputName); + + if (info.ioType == VertexInfo.IOType.PARTITIONEDFILE ) + { + DrPartitionInputStream input = new DrPartitionInputStream(); + + err = input.Open(app.GetUniverse(), info.sources[0]); + if (!SUCCEEDED(err)) + { + string msg = String.Format("Could not read DSC input file {0}", info.sources[0]); + throw new LinqToDryadException(msg, err); + } + + DrManagerBase inputStage = new DrManagerBase(app.GetGraph(), inputName); + DrInputStreamManager inputManager = new DrInputStreamManager(input, inputStage); + + s = inputManager; + } + //else if ( info.ioType == VertexInfo.IOType.STREAM ) + //{ + // DrDscInputStream input = new DrDscInputStream(); + + // DryadLogger.LogInformation("Create input node", "Opening DSC input fileset"); + + // err = input.Open(app.GetUniverse(), info.sources[0]); + // if (!SUCCEEDED(err)) + // { + // string msg = String.Format("Could not read DSC input fileset {0}", info.sources[0]); + // throw new LinqToDryadException(msg, err); + // } + + // DryadLogger.LogInformation("Create input node", "Opened DSC input fileset"); + + // DrManagerBase inputStage = new DrManagerBase(app.GetGraph(), inputName); + // DrInputStreamManager inputManager = new DrInputStreamManager(input, inputStage); + + // s = inputManager; + //} + else if (info.ioType == VertexInfo.IOType.HDFS_STREAM) + { + DrHdfsInputStream input = new DrHdfsInputStream(); + + DryadLogger.LogInformation("Create input node", "Opening HDFS input fileset"); + + err = input.Open(app.GetUniverse(), info.sources[0]); + if (!SUCCEEDED(err)) + { + string msg = String.Format("Could not read HDFS input fileset {0}", info.sources[0]); + throw new LinqToDryadException(msg, err); + } + + DryadLogger.LogInformation("Create 
input node", "Opened HDFS input fileset"); + + DrManagerBase inputStage = new DrManagerBase(app.GetGraph(), inputName); + DrInputStreamManager inputManager = new DrInputStreamManager(input, inputStage); + + s = inputManager; + } + else + { + string msg = String.Format("Unknown input type {0}", info.ioType); + throw new LinqToDryadException(msg); + } + + DryadLogger.LogMethodExit(); + return s; + } + + private DrOutputStreamManager CreateOutputNode(DryadLINQApp app, VertexInfo info, string outputName) + { + DrOutputStreamManager s; + if ( info.ioType == VertexInfo.IOType.PARTITIONEDFILE ) + { + DrPartitionOutputStream output = new DrPartitionOutputStream(); + int err = output.Open(info.sources[0], info.partitionUncPath); + if (!SUCCEEDED(err)) + { + string msg = String.Format("Could not open DSC output fileset {0}", info.sources[0]); + throw new LinqToDryadException(msg, err); + } + + DrManagerBase outputStage = new DrManagerBase(app.GetGraph(), outputName); + DrOutputStreamManager outputManager = new DrOutputStreamManager(output, outputStage); + app.GetGraph().AddPartitionGenerator(outputManager); + + s = outputManager; + } + //else if ( info.ioType == VertexInfo.IOType.STREAM ) + //{ + // DrDscOutputStream output = new DrDscOutputStream(info.compressionScheme, info.isTemporary); + // int err = 0; + // if (info.recordType == "") + // { + // err = output.Open(info.sources[0], info.partitionUncPath); + // } + // else + // { + // err = output.OpenWithRecordType(info.sources[0], info.partitionUncPath, info.recordType); + // } + + // if (!SUCCEEDED(err)) + // { + // string msg = String.Format("Could not open DSC output fileset {0}", info.sources[0]); + // throw new LinqToDryadException(msg, err); + // } + + // DrManagerBase outputStage = new DrManagerBase(app.GetGraph(), outputName); + // DrOutputStreamManager outputManager = new DrOutputStreamManager(output, outputStage); + // app.GetGraph().AddPartitionGenerator(outputManager); + + // s = outputManager; + //} + else 
if (info.ioType == VertexInfo.IOType.HDFS_STREAM) + { + DrHdfsOutputStream output = new DrHdfsOutputStream(); + int err = output.Open(info.sources[0]); + + if (!SUCCEEDED(err)) + { + string msg = String.Format("Could not open HDFS output fileset {0}", info.sources[0]); + throw new LinqToDryadException(msg, err); + } + + DrManagerBase outputStage = new DrManagerBase(app.GetGraph(), outputName); + DrOutputStreamManager outputManager = new DrOutputStreamManager(output, outputStage); + app.GetGraph().AddPartitionGenerator(outputManager); + + s = outputManager; + } + else + { + string msg = String.Format("Unknown output type {0}", info.ioType); + throw new LinqToDryadException(msg); + } + + return s; + } + + private DrVertexSet CreateVertexSet(DrGraph graph, DrVertex prototype, int copies) + { + DrVertexSet result = new DrVertexSet(); + + for (int i = 0; i < copies; i++) + { + DrVertex v = prototype.MakeCopy(i); + result.Add(v); + } + + return result; + } + + private DrVertexSet CreateVertexSet(DrGraph graph, DrInputStreamManager inputStream) + { + DrVertexSet result = new DrVertexSet(); + + List vertices = inputStream.GetVertices(); + for (int i = 0; i < vertices.Count; i++) + { + DrVertex v = vertices[i]; + result.Add(v); + } + + return result; + } + + private DrVertexSet CreateVertexSet(DrGraph graph, DrOutputStreamManager outputStream, int partitions) + { + DrVertexSet result = new DrVertexSet(); + + outputStream.SetNumberOfPartitions(partitions); + + List vertices = outputStream.GetVertices(); + for (int i = 0; i < vertices.Count; i++) + { + DrVertex v = vertices[i]; + result.Add(v); + } + + return result; + } + + private void CreateVertexSet(Vertex v, DryadLINQApp app, Query query, Dictionary graphStageMap) + { + SortedDictionary queryPlan = query.queryPlan; + DrVertexSet nodes = null; + DrStageManager newManager = null; + + app.GetGraph().GetParameters(); + + DrGraphParameters parameters = app.GetGraph().GetParameters(); + string stdVertexName = "MW"; + + if ( 
v.type == Vertex.Type.INPUTTABLE) + { + DrInputStreamManager input = CreateInputNode(app, v.info, v.name); + newManager = input.GetStageManager(); + nodes = CreateVertexSet(app.GetGraph(), input); + } + else if ( v.type == Vertex.Type.OUTPUTTABLE ) + { + DrOutputStreamManager output = CreateOutputNode(app, v.info, v.name); + newManager = output.GetStageManager(); + nodes = CreateVertexSet(app.GetGraph(), output, v.partitions); + } + else if ( v.type == Vertex.Type.CONCAT ) + { + newManager = new DrManagerBase(app.GetGraph(), v.name); + + // the set of nodes in a concat is just the set of nodes in the predecessor stages concatenated + nodes = new DrVertexSet(); + foreach (Predecessor p in v.info.predecessors) + { + GraphStageInfo value = null; + if (graphStageMap.TryGetValue(p.uniqueId, out value)) + { + nodes.InsertRange(nodes.Count, value.members); + } + else + { + throw new LinqToDryadException(String.Format("Concat: Failed to find predecessor {0} in graph stage map", p.uniqueId)); + } + } + } + else + { + newManager = new DrManagerBase(app.GetGraph(), v.name); + + DrVertex vertex; + if (v.type == Vertex.Type.TEE) + { + DrTeeVertex teeVertex = new DrTeeVertex(newManager); + vertex = teeVertex; + } + else + { + DrActiveVertex activeVertex = new DrActiveVertex(newManager, parameters.m_defaultProcessTemplate, parameters.m_defaultVertexTemplate); + + activeVertex.AddArgument(stdVertexName); + activeVertex.AddArgument(v.info.assemblyName); + activeVertex.AddArgument(v.info.className); + activeVertex.AddArgument(v.info.methodName); + + vertex = activeVertex; + } + + nodes = CreateVertexSet(app.GetGraph(), vertex, v.partitions); + + if (v.machines != null && v.machines.Length != 0 && v.type != Vertex.Type.TEE) + { + for (int i=0; i graphStageMap) + { + Vertex destVertex = destInfo.vertex; + int destId = destVertex.uniqueId; + + DrVertexSet destNodes = graphStageMap[ destId ].members; + foreach (Predecessor p in destVertex.info.predecessors) + { + GraphStageInfo info = 
graphStageMap[ p.uniqueId ]; + Vertex sourceVertex = info.vertex; + + if (destVertex.type != Vertex.Type.CONCAT) + { + ChannelConnectorType channelType = ChannelConnectorType.DCT_File; + if ( p.channelType == Predecessor.ChannelType.DISKFILE ) + { + if (destVertex.type == Vertex.Type.OUTPUTTABLE) + { + channelType = ChannelConnectorType.DCT_Output; + } + else + { + channelType = ChannelConnectorType.DCT_File; + } + } + else if ( p.channelType == Predecessor.ChannelType.MEMORYFIFO ) + { + channelType = ChannelConnectorType.DCT_Fifo; + } + else if ( p.channelType == Predecessor.ChannelType.TCPPIPE ) + { + channelType = ChannelConnectorType.DCT_Pipe; + } + else + { + string msg = String.Format("Unknown channel type {0}", p.channelType); + throw new LinqToDryadException(msg); + } + + DrVertexSet sourceNodes = info.members; + switch (p.connectionOperator) + { + case Predecessor.ConnectionOperator.CROSSPRODUCT: + ConnectCrossProduct(sourceNodes, destNodes, channelType); + break; + case Predecessor.ConnectionOperator.POINTWISE: + ConnectPointwise(sourceNodes, destNodes, channelType); + break; + default: + break; + } + + } + } + } + + public void BuildGraphFromQuery(DryadLINQApp app, Query query) + { + // set configurable properties + int highThreshold = app.GetMaxAggregateInputs(); + int lowThreshold = 16; + UInt64 highDataThreshold = (UInt64)app.GetAggregateThreshold(); + UInt64 lowDataThreshold = (3*highDataThreshold)/4; + UInt64 maxSingleDataThreshold = highDataThreshold/2; + int aggFilterThreshold = app.GetMaxAggregateFilterInputs(); + + // use a graph stage map to store the vertices as they are created, grouped by stage. 
+ Dictionary graphStageMap = new Dictionary(); + + DryadLogger.LogInformation("Build Graph From Query", "Building graph"); + + // + // Create a set of vertices for each vertex (stage) in the query plan + // + DryadLogger.LogInformation("Build Graph From Query", "Adding vertices"); + foreach (KeyValuePair kvp in query.queryPlan) + { + Vertex v = kvp.Value; + GraphStageInfo value = null; + + if (!graphStageMap.TryGetValue(v.uniqueId, out value)) + { + DryadLogger.LogInformation("Build Graph From Query", "Adding vertices for stage {0}", v.name); + CreateVertexSet(v, app, query, graphStageMap); + } + } + + // + // Add dynamic stage managers + // + DryadLogger.LogInformation("Build Graph From Query", "Adding stage managers"); + foreach (KeyValuePair kvp in graphStageMap) + { + Vertex v = kvp.Value.vertex; + + // + //There are no dynamic managers + // + if (v.dynamicManager == null) + { + continue; + } + + DrStageManager newManager = kvp.Value.stageManager; // newManager + + DrGraphParameters parameters = app.GetGraph().GetParameters(); + + string stdVertexName = "MW"; + string cpyVertexName = "CP"; + + if (v.type != Vertex.Type.INPUTTABLE && v.type != Vertex.Type.CONCAT) + { + if (v.dynamicManager.type == DynamicManager.Type.SPLITTER) + { + if (v.info.predecessors.Length == 1) + { + DrPipelineSplitManager splitter = new DrPipelineSplitManager(); + newManager.AddDynamicConnectionManager(graphStageMap[v.info.predecessors[0].uniqueId].stageManager, splitter); + } + else + { + DrSemiPipelineSplitManager splitter = new DrSemiPipelineSplitManager(); + newManager.AddDynamicConnectionManager(graphStageMap[v.info.predecessors[0].uniqueId].stageManager, splitter); + } + } + else if (v.dynamicManager.type == DynamicManager.Type.PARTIALAGGR) + { + DrDynamicAggregateManager dynamicMerge = new DrDynamicAggregateManager(); + + dynamicMerge.SetGroupingSettings(0, 0); + dynamicMerge.SetMachineGroupingSettings(2, aggFilterThreshold); + 
dynamicMerge.SetDataGroupingSettings(lowDataThreshold, highDataThreshold, maxSingleDataThreshold); + dynamicMerge.SetSplitAfterGrouping(true); + + foreach (Predecessor p in v.info.predecessors) + { + newManager.AddDynamicConnectionManager(graphStageMap[p.uniqueId].stageManager, dynamicMerge); + } + } + else if (v.dynamicManager.type == DynamicManager.Type.FULLAGGR || + v.dynamicManager.type == DynamicManager.Type.HASHDISTRIBUTOR) + { + int idx = 0; + int sz = v.dynamicManager.assemblyNames == null ? 0 : v.dynamicManager.assemblyNames.Length; + DrDynamicAggregateManager dynamicMerge = new DrDynamicAggregateManager(); + + if (v.dynamicManager.type == DynamicManager.Type.FULLAGGR || sz > 1) + { + dynamicMerge = new DrDynamicAggregateManager(); + + string name = v.dynamicManager.methodNames[idx]; + DrManagerBase newStage = new DrManagerBase(app.GetGraph(), name); + + DrActiveVertex mergeVertex = new DrActiveVertex(newStage, parameters.m_defaultProcessTemplate, parameters.m_defaultVertexTemplate); + mergeVertex.AddArgument(stdVertexName); + + mergeVertex.AddArgument(v.dynamicManager.assemblyNames[idx]); + mergeVertex.AddArgument(v.dynamicManager.classNames[idx]); + mergeVertex.AddArgument(v.dynamicManager.methodNames[idx]); + + idx++; + dynamicMerge.SetInternalVertex(mergeVertex); + + dynamicMerge.SetGroupingSettings(0, 0); + dynamicMerge.SetPodGroupingSettings(lowThreshold, highThreshold); + dynamicMerge.SetDataGroupingSettings(lowDataThreshold, + highDataThreshold, + maxSingleDataThreshold); + dynamicMerge.SetMaxAggregationLevel(v.dynamicManager.aggregationLevels); + } + + if (v.dynamicManager.type == DynamicManager.Type.FULLAGGR) + { + newManager.AddDynamicConnectionManager(graphStageMap[v.info.predecessors[0].uniqueId].stageManager, dynamicMerge); + } + else + { + string name = v.dynamicManager.methodNames[idx]; + DrManagerBase newStage = new DrManagerBase(app.GetGraph(), name); + + DrActiveVertex distributeVertex = new DrActiveVertex(newStage, 
parameters.m_defaultProcessTemplate, parameters.m_defaultVertexTemplate); + distributeVertex.AddArgument(stdVertexName); + + distributeVertex.AddArgument(v.dynamicManager.assemblyNames[idx]); + distributeVertex.AddArgument(v.dynamicManager.classNames[idx]); + distributeVertex.AddArgument(v.dynamicManager.methodNames[idx]); + + idx++; + + DrDynamicDistributionManager dynamicHashDistribute = + new DrDynamicDistributionManager(distributeVertex, dynamicMerge); + dynamicHashDistribute.SetDataPerVertex(highDataThreshold*2); // 2GB + + newManager.AddDynamicConnectionManager(graphStageMap[v.info.predecessors[0].uniqueId].stageManager, dynamicHashDistribute); + } + } + else if (v.dynamicManager.type == DynamicManager.Type.RANGEDISTRIBUTOR) + { + DrStageManager splitManager = graphStageMap[v.dynamicManager.splitVertexId].stageManager; + DrDynamicRangeDistributionManager drdm = new DrDynamicRangeDistributionManager(splitManager, v.dynamicManager.sampleRate); + drdm.SetDataPerVertex(highDataThreshold*2); // 2GB + newManager.AddDynamicConnectionManager(graphStageMap[v.info.predecessors[0].uniqueId].stageManager,drdm); + } + else if (v.dynamicManager.type == DynamicManager.Type.BROADCAST) + { + // the copy vertex + int bcastNumber = 0; + string nameString = String.Format("CP__{0}", bcastNumber++); + DrManagerBase newStage = new DrManagerBase(app.GetGraph(), nameString); + + DrActiveVertex copyVertex = + new DrActiveVertex(newStage, + parameters.m_defaultProcessTemplate, + parameters.m_defaultVertexTemplate); + copyVertex.AddArgument(cpyVertexName); + + DrDynamicBroadcastManager bcast = new DrDynamicBroadcastManager(copyVertex); + newManager.AddDynamicConnectionManager(graphStageMap[v.info.predecessors[0].uniqueId].stageManager, bcast); + } + else if (v.dynamicManager.type != DynamicManager.Type.NONE) + { + DryadLogger.LogWarning("Build Graph From Query", "Dynamic manager type {0} not supported yet", v.dynamicManager.type); + } + } + } + + + // + // Add all the edges + // + 
DryadLogger.LogInformation("Build Graph From Query", "Adding edges"); + foreach (KeyValuePair kvp in graphStageMap) + { + AddEdges(kvp.Value, graphStageMap); + } + + // + // Register the actual created vertices with the graph + // + MaterializeToManagers(graphStageMap); + } + + public void PrintGraph(DrGraph graph) + { + using (StreamWriter sw = new StreamWriter(new FileStream("toplogy.txt", FileMode.Create))) + { + List stages = graph.GetStages(); + + foreach (DrStageManager s in stages) + { + List vertices = s.GetVertexVector(); + foreach (DrVertex v in vertices) + { + sw.Write("{0} <= ", v.GetId()); + + DrEdgeHolder eh = v.GetInputs(); + int n = eh.GetNumberOfEdges(); + + if (n == 0) + { + sw.Write("DSC"); + } + else + { + for (int i = 0; i < n; i++) + { + DrVertex vin = eh.GetEdge(i).m_remoteVertex; + sw.Write("{0} ", vin.GetId()); + } + } + sw.WriteLine(); + } + } + } + } + + public void MaterializeToManagers(Dictionary graphStageMap) + { + foreach (KeyValuePair kvp in graphStageMap) + { + // Skip CONCAT - it is just a collection of inputs from previous stages + // and those vertices have already been registered + if (kvp.Value.vertex.type != Vertex.Type.CONCAT) + { + foreach (DrVertex v in kvp.Value.members) + { + v.GetStageManager().RegisterVertex(v); + } + } + } + } + } + +} // namespace DryadLINQ diff --git a/linqtodryadjm_managed_yarn/LinqToDryadException.cs b/linqtodryadjm_managed_yarn/LinqToDryadException.cs new file mode 100644 index 0000000..41c5089 --- /dev/null +++ b/linqtodryadjm_managed_yarn/LinqToDryadException.cs @@ -0,0 +1,41 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +using System; + +namespace linqtodryadjm_managed +{ + internal class LinqToDryadException : Exception + { + public static readonly uint E_FAIL = 0x80004005; + + public LinqToDryadException(string message) + : base(message) + { + HResult = unchecked((int)E_FAIL); + } + + public LinqToDryadException(string message, int hresult) + : base(message) + { + HResult = hresult; + } + } +} diff --git a/linqtodryadjm_managed_yarn/LinqToDryadJM.cs b/linqtodryadjm_managed_yarn/LinqToDryadJM.cs new file mode 100644 index 0000000..3371365 --- /dev/null +++ b/linqtodryadjm_managed_yarn/LinqToDryadJM.cs @@ -0,0 +1,297 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +using System; +using System.Collections.Generic; +using System.IO; +using System.Reflection; +using System.Data.Linq; +using System.Data.Linq.Mapping; +using System.Diagnostics; +using System.Threading; +using System.Xml; +using System.Xml.Serialization; +using Microsoft.Research.Dryad; + +namespace linqtodryadjm_managed +{ + internal class DebugHelper + { + private static bool brokeInDebugger = false; + private static bool loggingInitialized = false; + private static object syncRoot = new object(); + + public static void WaitForDebugger() + { + if (!brokeInDebugger) + { + Console.Out.WriteLine("Waiting for debugger..."); + while (!Debugger.IsAttached) + { + Thread.Sleep(1000); + } + Debugger.Break(); + brokeInDebugger = true; + } + } + + public static void DebugBreakOnFileExisting(string breakFileName) + { + if (File.Exists(breakFileName)) + { + WaitForDebugger(); + } + } + + public static void SetLogType() + { + DrLogging.SetLoggingLevel((DrLogType)DryadLogger.TraceLevel); + } + + public static void InitializeLogging(DateTime startTime) + { + if (!loggingInitialized) + { + lock (syncRoot) + { + if (!loggingInitialized) + { + // Initialize text-based tracing + string traceFile = Path.Combine(Directory.GetCurrentDirectory(), "GraphManagerTrace.txt"); + DryadLogger.Start(traceFile); + + // Initialize Graph Manager's internal logging + DrLogging.Initialize(); + DebugHelper.SetLogType(); + + // Report start time to Artemis - must come after + // DrLogging is initialized so stdout is redirected + DrArtemisLegacyReporter.ReportStart((ulong) startTime.Ticks); + + loggingInitialized = true; + } + } + } + } + + public static void StopLogging(int retCode) + { + if (loggingInitialized) + { + lock (syncRoot) + { + if (loggingInitialized) + { + // Report stop time to Artemis + DrArtemisLegacyReporter.ReportStop(unchecked((uint)retCode)); + + // Shutdown Graph Manager's internal logging + DrLogging.ShutDown(unchecked((uint)retCode)); + + // Shutdown text-based tracing 
+ DryadLogger.Stop(); + + loggingInitialized = false; + + } + } + } + } + } + + public class LinqToDryadJM + { + internal void FinalizeExecution(Query query, DrGraph graph) + { + SortedDictionary queryPlan = query.queryPlan; + foreach (KeyValuePair kvp in query.queryPlan) + { + /* used to do CSStream expiration time stuff here */ + } + } + + internal bool ConsumeSingleArgument(string arg, ref string[] args) + { + List temp = new List(); + bool found = false; + + for (int index=0; index internalArgs = new List(); + + // + // add the XmlExecHost args to the internal app arguments + // + foreach (string xmlExecHostArg in query.xmlExecHostArgs) + + { + if (xmlExecHostArg == "--break") + { + DebugHelper.WaitForDebugger(); + } + else + { + internalArgs.Add(xmlExecHostArg); + } + } + + // + // combine internal arguments with any additional arguments received on the command line + // don't include argv[0] and argv[1] (program name and query XML file name) + // + + int internalArgc = (int)internalArgs.Count; + int externalArgc = args.Length - 2; // don't include argv[0] and argv[1] + int combinedArgc = internalArgc + externalArgc; + string[] combinedArgv = new string[combinedArgc]; + + string msg = ""; + // internal arguments first + for (int i=0; i tempArgs = new List(); + + // Record start time so we can report it once logging has been initialized + DateTime startTime = DateTime.Now.ToLocalTime(); + + // Set unhandled exception handler to catch anything thrown from + // Microsoft.Hpc.Query.GraphManager.dll + AppDomain currentDomain = AppDomain.CurrentDomain; + currentDomain.UnhandledException += new UnhandledExceptionEventHandler(ExceptionHandler); + + + // + // Add executable name to beginning of args + // to make graph manager libraries happy + // + tempArgs.Add("YarnQueryGraphManager"); + foreach (string arg in args) + { + if (String.Compare(arg, "--break", StringComparison.OrdinalIgnoreCase) == 0) + { + waitForDebugger = true; + } + else + { + tempArgs.Add(arg); + } + 
} + args = tempArgs.ToArray(); + + if (waitForDebugger) + { + DebugHelper.WaitForDebugger(); + } + /* Yarn removes the need for this + + // + // Create job directory and copy resources + // + string resources = Environment.GetEnvironmentVariable("XC_RESOURCEFILES"); + if (ExecutionHelper.InitializeForJobExecution(resources) == false) + { + return 1; + } + + // + // Set current directory to working directory + // + Directory.SetCurrentDirectory(ProcessPathHelper.JobPath); + */ + // + // Configure tracing + // + DebugHelper.InitializeLogging(startTime); + + // Ensure that there is a jvm.dll on the path. + string pathString = Environment.GetEnvironmentVariable("PATH"); + var pathDirs = pathString.Split(';'); + bool found = false; + foreach (var dir in pathDirs) + { + string targetFile = Path.Combine(dir, "jvm.dll"); + if (File.Exists(targetFile)) + { + found = true; + break; + } + } + if (!found) + { + string javaHome = Environment.GetEnvironmentVariable("JAVA_HOME"); + if (String.IsNullOrEmpty(javaHome)) + { + throw new ApplicationException("DryadLINQ requires the JAVA_HOME environment variable to be set or the jvm.dll to be on the path."); + } + var jvmPath = ";" + Path.Combine(javaHome, "jre", "bin", "server"); + pathString += jvmPath; + Environment.SetEnvironmentVariable("PATH", pathString); + } + + + // + // Run the Graph Manager + // + int retCode = 0; + try + { + LinqToDryadJM jm = new LinqToDryadJM(); + retCode = jm.Run(args); + } + catch (Exception e) + { + retCode = System.Runtime.InteropServices.Marshal.GetHRForException(e); + if (retCode == 0) + { + DryadLogger.LogCritical(0, e); + retCode = unchecked((int)LinqToDryadException.E_FAIL); + } + else + { + DryadLogger.LogCritical(retCode, e); + } + System.Threading.Thread.Sleep(10 * 60 * 1000); + } + + DebugHelper.StopLogging(retCode); + + // NOTE: We don't want to log critical errors twice, so here we're assuming that + // if the GM exited "gracefully" and returned an error code instead of throwing + // an 
exception, that it has already logged the error. + if (retCode != 0) + { + // If the Graph Manager already started executing, we need to exit the process. + // Exiting the thread (returning from Main) will not necessarily cause the GM's + // worker threads to exit + + // TODO: Consider deleting temp output stream from DSC + // requires that we have access to the URI at this point, though + Environment.Exit(retCode); + } + + // + // Cleanup all vertex tasks in case any became + // unreachable. + // + + return retCode; + } + } +} diff --git a/linqtodryadjm_managed_yarn/Properties/AssemblyInfo.cs b/linqtodryadjm_managed_yarn/Properties/AssemblyInfo.cs new file mode 100644 index 0000000..eb8466d --- /dev/null +++ b/linqtodryadjm_managed_yarn/Properties/AssemblyInfo.cs @@ -0,0 +1,54 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. 
+[assembly: AssemblyTitle("linqtodryadjm_managed")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyProduct("linqtodryadjm_managed")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("55c7f4b4-02ab-4309-b276-3c9ab5a927e0")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("1.0.*")] +[assembly: AssemblyVersion("1.0.0.0")] +[assembly: AssemblyFileVersion("1.0.0.0")] diff --git a/linqtodryadjm_managed_yarn/Query.cs b/linqtodryadjm_managed_yarn/Query.cs new file mode 100644 index 0000000..9271a8a --- /dev/null +++ b/linqtodryadjm_managed_yarn/Query.cs @@ -0,0 +1,162 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +using System; +using System.Collections.Generic; +using System.IO; +using System.Reflection; +using System.Data.Linq; +using System.Data.Linq.Mapping; +using System.Xml; +using System.Xml.Serialization; + +namespace linqtodryadjm_managed +{ + public class DynamicManager + { + public enum Type + { + NONE = 0, + SPLITTER, + PARTIALAGGR, + FULLAGGR, + HASHDISTRIBUTOR, + RANGEDISTRIBUTOR, + BROADCAST, + }; + public Type type; + + public string[] assemblyNames; // dll file name of vertex entry code + public string[] classNames; // class name of vertex entry code + public string[] methodNames; // method name of vertex entry code + + public double sampleRate; // For range distributor only + public int splitVertexId; // For range distributor only + public int aggregationLevels; // For aggregators + }; + + + public class Predecessor + { + public enum ConnectionOperator + { + POINTWISE = 0, + CROSSPRODUCT + }; + + public enum ChannelType + { + DISKFILE = 0, + TCPPIPE, + MEMORYFIFO + }; + + public enum AffinityConstraint + { + UseDefault = 0, + HardConstraint, + OptimizationConstraint, + Preference, + DontCare + }; + + public int uniqueId; + public ConnectionOperator connectionOperator; + public ChannelType channelType; + public AffinityConstraint constraint; + + }; + + + public class VertexInfo + { + public enum IOType + { + FILELIST = 0, // always just one file + FILEDIRECTORY, // only for input + FILEWILDCARD, // only for input + STREAM, + HDFS_STREAM, + PARTITIONEDFILE, + FILEPREFIX // only for output + }; + + public IOType ioType; + public Predecessor[] predecessors = new Predecessor[0]; + + // for tables-type vertices only + public string[] sources; // fully qualified URI of output + + // for partitioned output table type vertices only + public string partitionUncPath; // fully-qualified network URI path + + // True iff the output is a temp dataset + public bool isTemporary; + + // for general vertices + public string assemblyName; // dll file name of vertex 
entry code + public string className; // class name of vertex entry code + public string methodName; // method name of vertex entry code + + // for OUTPUTTABLE storage set, this is the record type + public string recordType; + // for OUTPUTTABLE storage set, this is the compresssion mode + public int compressionScheme; + + }; + + + public class Vertex + { + public enum Type + { + UNKNOWN = -1, + UNUSED = 0, + INPUTTABLE, + OUTPUTTABLE, + WHERE, + JOIN, + FORK, + TEE, + CONCAT, + SUPER + }; + public Type type; + + public int uniqueId; // id number (and position in plan - starts from zero) + public string name; // pretty-printing name + public int partitions; // partitions + public VertexInfo info; // (vertex-type specific)additional info about this vertex + public DynamicManager dynamicManager; + public string[] machines; + }; + + public class Query + { + public string compilerversion = ""; // version of DryadLinq compiler + public string clusterName = ""; // name of dryad cluster to run on + public string[] xmlExecHostArgs = new string[0]; // app defined command-line args for XmlExecHost + public string visualization = ""; // control of visualization + public int intermediateDataCompression = 0; //YARN // compression scheme for intermediate data + public SortedDictionary queryPlan = new SortedDictionary(); // DAG of numbered vertices + public bool enableSpeculativeDuplication = true; + }; + +} // namespace DryadLINQ diff --git a/linqtodryadjm_managed_yarn/QueryParser.cs b/linqtodryadjm_managed_yarn/QueryParser.cs new file mode 100644 index 0000000..6a59e9c --- /dev/null +++ b/linqtodryadjm_managed_yarn/QueryParser.cs @@ -0,0 +1,428 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +using System; +using System.Collections.Generic; +using System.IO; +using System.Reflection; +using System.Data.Linq; +using System.Data.Linq.Mapping; +using System.Xml; +using System.Xml.Serialization; +using Microsoft.Research.Dryad; + +namespace linqtodryadjm_managed +{ + public class QueryPlanParser + { + public VertexInfo.IOType GetIoType(string type) + { + if (type == "File") return VertexInfo.IOType.FILELIST; + if (type == "PartitionedFile") return VertexInfo.IOType.PARTITIONEDFILE; + if (type == "FileDirectory") return VertexInfo.IOType.FILEDIRECTORY; + if (type == "FileWildcard") return VertexInfo.IOType.FILEWILDCARD; + if (type == "TidyFS") return VertexInfo.IOType.STREAM; + if (type == "Dsc") return VertexInfo.IOType.STREAM; + if (type == "Hdfs") return VertexInfo.IOType.HDFS_STREAM; + if (type == "FilePrefix") return VertexInfo.IOType.FILEPREFIX; + throw new LinqToDryadException(String.Format("Unknown IoType: {0}", type)); + } + + public Vertex.Type GetVertexType(string type) + { + if (type == "InputTable") return Vertex.Type.INPUTTABLE; + if (type == "OutputTable") return Vertex.Type.OUTPUTTABLE; + if (type == "Where") return Vertex.Type.WHERE; + if (type == "Join") return Vertex.Type.JOIN; + if (type == "Fork") return Vertex.Type.FORK; + if (type == "Tee") return Vertex.Type.TEE; + if (type == "Concat") return Vertex.Type.CONCAT; + if (type == "Super") return Vertex.Type.SUPER; + if (type == "Apply") return Vertex.Type.SUPER; + return Vertex.Type.UNKNOWN; + } + + public 
Predecessor.ChannelType GetChannelType(string type) + { + if (type == "DiskFile") return Predecessor.ChannelType.DISKFILE; + if (type == "TCPPipe") return Predecessor.ChannelType.TCPPIPE; + if (type == "MemoryFIFO") return Predecessor.ChannelType.MEMORYFIFO; + throw new LinqToDryadException(String.Format("Unknown ChannelType: {0}", type)); + } + + public Predecessor.ConnectionOperator GetConnectionOperator(string type) + { + if (type == "Pointwise") return Predecessor.ConnectionOperator.POINTWISE; + if (type == "CrossProduct") return Predecessor.ConnectionOperator.CROSSPRODUCT; + throw new LinqToDryadException(String.Format("Unknown ConnectionOperator: {0}", type)); + } + + public Predecessor.AffinityConstraint GetAffinityConstraint(string type) + { + if (type == "UseDefault") return Predecessor.AffinityConstraint.UseDefault; + if (type == "HardConstraint") return Predecessor.AffinityConstraint.HardConstraint; + if (type == "OptimizationConstraint") return Predecessor.AffinityConstraint.OptimizationConstraint; + if (type == "Preference") return Predecessor.AffinityConstraint.Preference; + if (type == "DontCare") return Predecessor.AffinityConstraint.DontCare; + throw new LinqToDryadException(String.Format("Unknown AffinityConstraint: {0}", type)); + } + + public DynamicManager.Type GetDynamicManagerType(string type) + { + if (type == "None") return DynamicManager.Type.NONE; + if (type == "Splitter") return DynamicManager.Type.SPLITTER; + if (type == "PartialAggregator") return DynamicManager.Type.PARTIALAGGR; + if (type == "FullAggregator") return DynamicManager.Type.FULLAGGR; + if (type == "HashDistributor") return DynamicManager.Type.HASHDISTRIBUTOR; + if (type == "RangeDistributor") return DynamicManager.Type.RANGEDISTRIBUTOR; + if (type == "Broadcast") return DynamicManager.Type.BROADCAST; + throw new LinqToDryadException(String.Format("Unknown DynamicManager: {0}", type)); + } + + public static bool SplitEntryIntoAssemblyClassMethod(string entry, out string 
_assembly, out string _class, out string _method) + { + _assembly = ""; + _class = ""; + _method = ""; + + int indexBang = entry.IndexOf("!"); + int indexPeriod = entry.LastIndexOf("."); + + if (indexBang == -1 || indexPeriod == -1 || indexPeriod <= indexBang) + { + return false; + } + + _assembly = entry.Substring(0, indexBang); + _class = entry.Substring(indexBang + 1, indexPeriod - indexBang - 1); + _method = entry.Substring(indexPeriod + 1); + return true; + + } + + private void ParseQueryXmlLinqToDryad(XmlDocument queryPlanDoc, Query query) + { + XmlElement root = queryPlanDoc.DocumentElement; + + // + // Query globals + // + query.queryPlan = new SortedDictionary(); + query.compilerversion = root.SelectSingleNode("DryadLinqVersion").InnerText; + query.clusterName = root.SelectSingleNode("ClusterName").InnerText; + query.visualization = root.SelectSingleNode("Visualization").InnerText; + + // Compression scheme for intermediate data + XmlNode compressionNode = root.SelectSingleNode("IntermediateDataCompression"); + if (compressionNode != null) + { + query.intermediateDataCompression = Convert.ToInt32(compressionNode.InnerText); + } + + // + // XmlExecHost arguments + // + XmlNodeList nodes = root.SelectSingleNode("XmlExecHostArgs").ChildNodes; + query.xmlExecHostArgs = new string[nodes.Count]; + for (int index=0; index + + diff --git a/linqtodryadjm_managed_yarn/linqtodryadjm_managed.csproj b/linqtodryadjm_managed_yarn/linqtodryadjm_managed.csproj new file mode 100644 index 0000000..1f28ad0 --- /dev/null +++ b/linqtodryadjm_managed_yarn/linqtodryadjm_managed.csproj @@ -0,0 +1,127 @@ + + + + Debug + AnyCPU + 9.0.30729 + 2.0 + {1311809B-306E-44A4-9D69-8A7BD15123C5} + Exe + Properties + linqtodryadjm_managed + linqtodryadjm_managed + v4.0 + 512 + + + 3.5 + + false + + publish\ + true + Disk + false + Foreground + 7 + Days + false + false + true + 0 + 1.0.0.%2a + false + true + + + true + full + false + ..\bin\Debug\ + DEBUG;TRACE + prompt + 4 + AllRules.ruleset + 
x64 + + + pdbonly + true + bin\Release\ + TRACE + prompt + 4 + AllRules.ruleset + x64 + + + + + 3.5 + + + 3.5 + + + 3.5 + + + 3.5 + + + + + + + Constants.cs + + + DryadTracing.cs + + + QueryUtility.cs + + + + + + + + + + + + + False + .NET Framework 3.5 SP1 Client Profile + false + + + False + .NET Framework 3.5 SP1 + true + + + False + Windows Installer 3.1 + true + + + + + + + + {8E30F4A4-603B-4799-A473-6EF5388661BA} + GraphManager + + + + + \ No newline at end of file diff --git a/xcompute_managed/Dispatcher.cs b/xcompute_managed/Dispatcher.cs new file mode 100644 index 0000000..2674a5d --- /dev/null +++ b/xcompute_managed/Dispatcher.cs @@ -0,0 +1,772 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +namespace Microsoft.Research.Dryad +{ + using System; + using System.Globalization; + using System.Collections.Generic; + using System.Collections.Specialized; + using System.Diagnostics; + using System.Net.Security; + using System.ServiceModel; + using System.ServiceModel.Channels; + using System.Text; + using System.Xml; + using System.Threading; + using Microsoft.Research.Dryad; + + internal delegate void DispatcherFaultedEventHandler(object sender, EventArgs e); + + internal enum SchedulingResult + { + Success = 0, + Pending = 1, + Failure = 2, + CommunicationError = 3 + }; + + internal sealed class Dispatcher : IDisposable + { + public event DispatcherFaultedEventHandler FaultedEvent; + + public static readonly int InvalidProcessId = -1; + private static readonly int MaxRetries = 3; + private static readonly int RetryDelayInMilliseconds = 60 * 1000; // Retry on a faulted dispatcher every 1 minute + + private NetTcpBinding m_backendBinding = null; + + private VertexServiceClient m_client = null; + + private ISchedulerHelper m_schedulerHelper; + + private int m_connectionAttempts = 0; + private ScheduleProcessRequest m_currentProcess = null; + private string m_currentReplyUri = null; + private AsyncCallback m_currentAsyncCallback = null; + private bool m_disposed = false; + private string m_endpointAddress = String.Empty; + private bool m_faulted = false; + private bool m_idle = true; + private string m_nodeName = String.Empty; + private object m_syncRoot = new object(); + private int m_taskId = 0; + private int m_schedulingAttempts = 0; + private bool m_taskFailed = false; + private Timer m_retryTimer = null; + + #region Constructors + + /// + /// Constructor used by the vertex host + /// + /// + /// + public Dispatcher(string name, string endpointAddress) + { + m_nodeName = name; + m_endpointAddress = endpointAddress; + SafeOpenConnection(); + } + + /// + /// Constructor used by the Graph Manager + /// + /// + /// + public Dispatcher(ISchedulerHelper 
schedulerHelper, VertexComputeNode computeNode) + { + m_schedulerHelper = schedulerHelper; + m_taskId = computeNode.instanceId; + m_nodeName = computeNode.ComputeNode; + m_backendBinding = m_schedulerHelper.GetVertexServiceBinding(); + m_endpointAddress = m_schedulerHelper.GetVertexServiceBaseAddress(m_nodeName, m_taskId) + Constants.vertexServiceName; + SafeOpenConnection(); + } + + #endregion + + #region Public Methods + + public void SetRetryTimer(TimerCallback cb) + { + lock (SyncRoot) + { + // Guard against SetRetryTimer and Dispose getting called at + // the same time + if (!m_disposed) + { + if (m_retryTimer != null) + { + m_retryTimer.Dispose(); + m_retryTimer = null; + } + m_retryTimer = new Timer(cb, this, RetryDelayInMilliseconds, Timeout.Infinite); + } + } + } + + public void CancelScheduleProcess(int processId) + { + bool faultDispatcher = true; + + for (int numRetries = 0; numRetries < MaxRetries; numRetries++) + { + try + { + if (!Faulted) + { + this.m_client.CancelScheduleProcess(processId); + } + return; + } + // CancelScheduleProcess is one-way + catch (TimeoutException te) + { + DryadLogger.LogWarning("Cancel Process", "Timeout communicating with vertex service on node {0}: {1}", this.m_nodeName, te.ToString()); + if (!SafeOpenConnection()) + { + faultDispatcher = true; + break; + } + } + catch (CommunicationException ce) + { + DryadLogger.LogWarning("Cancel Process", "Error communicating with vertex service on node {0}: {1}", this.m_nodeName, ce.ToString()); + if (!SafeOpenConnection()) + { + faultDispatcher = true; + break; + } + } + catch (Exception e) + { + DryadLogger.LogError(0, e, "Error calling CancelScheduleProcess for node {0}, process {1}", m_nodeName, processId); + faultDispatcher = false; + break; + } + } + + if (faultDispatcher) + { + RaiseFaultedEvent(); + } + } + + public VertexStatus CheckStatus() + { + for (int index = 0; index < MaxRetries; index++) + { + try + { + if (!Faulted) + { + return this.m_client.CheckStatus(); + } + 
break; + } + catch (Exception e) + { + DryadLogger.LogError(0, e, "node '{0}'", m_nodeName); + if (!SafeOpenConnection()) + { + break; + } + } + } + + RaiseFaultedEvent(); + + VertexStatus s = new VertexStatus(); + s.serviceIsAlive = false; + return s; + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + public void Initialize(StringDictionary vertexEndpointAddresses) + { + bool faultDispatcher = true; + + for (int numRetries = 0; numRetries < MaxRetries; numRetries++) + { + try + { + if (!Faulted) + { + this.m_client.Initialize(vertexEndpointAddresses); + } + return; + } + // Initialize is one-way + catch (TimeoutException te) + { + DryadLogger.LogWarning("Initialize", "Timeout communicating with vertex service on node {0}: {1}", this.m_nodeName, te.ToString()); + if (!SafeOpenConnection()) + { + faultDispatcher = true; + break; + } + } + catch (CommunicationException ce) + { + DryadLogger.LogWarning("Initialize", "Error communicating with vertex service on node {0}: {1}", this.m_nodeName, ce.ToString()); + if (!SafeOpenConnection()) + { + faultDispatcher = true; + break; + } + } + catch (Exception e) + { + DryadLogger.LogError(0, e, "Error calling Initialize for node {0}", m_nodeName); + faultDispatcher = false; + break; + } + } + + if (faultDispatcher) + { + RaiseFaultedEvent(); + } + + } + + private void RaiseFaultedEvent() + { + RaiseFaultedEvent(false); + } + + public void RaiseFaultedEvent(bool taskFailed) + { + bool raiseEvent = false; + + // For SP3, we need to crash if this happens in the vertex host + if (String.Compare(Process.GetCurrentProcess().ProcessName, "HpcQueryVertexHost", StringComparison.OrdinalIgnoreCase) == 0) + { + DryadLogger.LogCritical(0, null, "Vertex Host lost communication with Vertex Service while updating vertex status: Exiting vertex. 
Graph Manager will rerun a failed vertex up to six times."); + Environment.Exit(unchecked((int)Constants.DrError_VertexHostLostCommunication)); + } + + lock (SyncRoot) + { + // We always want to raise the faulted event if the + // task failed, so that the dispatcher is disposed. + + // If the task did not fail, we want to ensure that + // the event is only raised once for a given fault. + raiseEvent = taskFailed || (!Faulted); + + + // We never want to reset m_taskFailed once it's been set + // to true, because the task isn't coming back. + m_taskFailed = m_taskFailed || taskFailed; + + m_faulted = true; + } + + if (raiseEvent) + { + DryadLogger.LogError(0, null, "Dispatcher for task {0} has faulted on node {1}, current process: {2}", m_taskId, m_nodeName, CurrentProcess == InvalidProcessId ? "" : CurrentProcess.ToString()); + + // Notice that this will keep any locks that are currently held, so refrain from calling this while enumerating the dispatchers + // Invoke through a local copy: the event is null when nothing is subscribed, and calling it directly would throw NullReferenceException + DispatcherFaultedEventHandler handler = FaultedEvent; + if (handler != null) + { + handler(this, null); + } + } + } + + public void Release() + { + if (!Idle) + { + lock (SyncRoot) + { + if (!Idle) + { + this.m_idle = true; + // Reset the number of scheduling attempts, since they are per-process + m_schedulingAttempts = 0; + } + } + } + } + + /// + /// Notify vertex service that the Graph Manager is done + /// with vertex process processId + /// + /// Process Id of the process to release + public void ReleaseProcess(int processId) + { + bool faultDispatcher = true; + + for (int numRetries = 0; numRetries < MaxRetries; numRetries++) + { + try + { + if (CurrentProcess == processId) + { + m_currentProcess = null; + } + + if (!Faulted) + { + this.m_client.ReleaseProcess(processId); + } + return; + } + // ReleaseProcess is one-way + catch (TimeoutException te) + { + DryadLogger.LogWarning("Release Process", "Timeout communicating with vertex service on node {0}: {1}", this.m_nodeName, te.ToString()); + if (!SafeOpenConnection()) + { + faultDispatcher = true; + break; + } + } + catch
(CommunicationException ce) + { + DryadLogger.LogWarning("Release Process", "Error communicating with vertex service on node {0}: {1}", this.m_nodeName, ce.ToString()); + if (!SafeOpenConnection()) + { + faultDispatcher = true; + break; + } + } + catch (Exception e) + { + DryadLogger.LogError(0, e, "Error calling ReleaseProcess for node {0}", m_nodeName); + faultDispatcher = false; + break; + } + } + + if (faultDispatcher) + { + RaiseFaultedEvent(); + } + + } + + public bool Reserve() + { + bool acquired = false; + if (!Faulted && Idle) + { + lock (SyncRoot) + { + if (!Faulted && Idle) + { + m_idle = false; + acquired = true; + } + } + } + return acquired; + } + + public bool ScheduleProcess(string replyUri, ScheduleProcessRequest req, AsyncCallback cb) + { + bool faultDispatcher = true; + + for (int numRetries = 0; numRetries < MaxRetries; numRetries++) + { + try + { + // TODO: Why are we taking the lock in this particular case again? + lock (SyncRoot) + { + if (!Faulted && m_schedulingAttempts < MaxRetries) + { + m_schedulingAttempts++; + + // Set the current process so that if the dispatcher faults we know + // which process to kill + m_currentProcess = req; + m_currentReplyUri = replyUri; + m_currentAsyncCallback = cb; + + this.m_client.BeginScheduleProcess(replyUri, req.Id, req.CommandLine, req.Environment, cb, (object)this); + return true; + } + } + return false; + } + catch (FaultException vse) + { + DryadLogger.LogWarning("Schedule Process", "Error scheduling process {0} on node {1}: {2}", req.Id, this.m_nodeName, vse.Reason); + faultDispatcher = false; + break; + } + catch (TimeoutException te) + { + DryadLogger.LogWarning("Schedule Process", "Timeout communicating with vertex service scheduling process {0} on node {1}: {2}", req.Id, this.m_nodeName, te.ToString()); + if (!SafeOpenConnection()) + { + faultDispatcher = true; + break; + } + } + catch (CommunicationException ce) + { + DryadLogger.LogWarning("Schedule Process", "Error communicating with vertex 
service scheduling process {0} on node {1}: {2}", req.Id, this.m_nodeName, ce.ToString()); + if (!SafeOpenConnection()) + { + faultDispatcher = true; + break; + } + } + catch (Exception e) + { + DryadLogger.LogError(0, e, "Error calling ScheduleProcess for process {0} on node {1}", req.Id, m_nodeName); + faultDispatcher = false; + break; + } + } + + if (faultDispatcher) + { + RaiseFaultedEvent(); + } + return false; + + } +
+        /// <summary>
+        /// Completes an asynchronous schedule request started by ScheduleProcess.
+        /// </summary>
+        /// <param name="asyncResult">Async result handed to the BeginScheduleProcess callback.</param>
+        /// <returns>
+        /// Success or Failure according to the vertex service reply; Failure if the dispatcher is
+        /// already faulted, or on a fault / unexpected exception; Pending if a timeout or
+        /// communication error occurred and the request was resubmitted on a fresh connection;
+        /// CommunicationError if the connection could not be reopened or the resubmit was refused.
+        /// </returns>
+        public SchedulingResult EndScheduleProcess(IAsyncResult asyncResult)
+        {
+            // We don't want to retry the async end operation - if it fails retry
+            // the whole scheduling operation
+
+            try
+            {
+                if (Faulted)
+                {
+                    return SchedulingResult.Failure;
+                }
+
+                return this.m_client.EndScheduleProcess(asyncResult)
+                    ? SchedulingResult.Success
+                    : SchedulingResult.Failure;
+            }
+            catch (FaultException vse)
+            {
+                DryadLogger.LogWarning("Schedule Process", "Error completing schedule process {0} on node {1}: {2}", this.m_currentProcess.Id, this.m_nodeName, vse.Reason);
+                return SchedulingResult.Failure;
+            }
+            catch (TimeoutException te)
+            {
+                // Fall through to the retry path below.
+                DryadLogger.LogWarning("Schedule Process", "Timeout communicating with vertex service for process {0} on node {1}: {2}", this.m_currentProcess.Id, this.m_nodeName, te.ToString());
+            }
+            catch (CommunicationException ce)
+            {
+                // Fall through to the retry path below.
+                DryadLogger.LogWarning("Schedule Process", "Error communicating with vertex service for process {0} on node {1}: {2}", this.m_currentProcess.Id, this.m_nodeName, ce.ToString());
+            }
+            catch (Exception e)
+            {
+                // Second placeholder was {0}, which printed the process id twice and dropped the node name.
+                DryadLogger.LogError(0, e, "Error calling EndScheduleProcess for process {0} on node {1}", this.m_currentProcess.Id, m_nodeName);
+                return SchedulingResult.Failure;
+            }
+
+            // If we make it here, then we need to retry the scheduling operation
+            if (SafeOpenConnection())
+            {
+                // ScheduleProcess manages the retry count and returns false if it is exceeded
+                DryadLogger.LogDebug("Schedule Process", "Communication error: retrying process {0} on node {1}", this.m_currentProcess.Id, this.m_nodeName);
+                if (ScheduleProcess(m_currentReplyUri, m_currentProcess, m_currentAsyncCallback))
+                {
+                    return SchedulingResult.Pending;
+                }
+            }
+
+            // SafeOpenConnection failed or retry count exceeded - fault the dispatcher.
+            DryadLogger.LogWarning("Schedule Process", "Connection failed to node {0}", this.m_nodeName);
+            return SchedulingResult.CommunicationError;
+        }
+
+ public bool SetGetProps(string replyUri, int processId, ProcessPropertyInfo[] infos, string blockOnLabel, ulong blockOnVersion, long maxBlockTime, string getPropLabel, bool ProcessStatistics) + { + bool faultDispatcher = true; + + for (int numRetries = 0; numRetries < MaxRetries; numRetries++) + { + try + { + if (!Faulted) + { + return this.m_client.SetGetProps(replyUri, processId, infos, blockOnLabel, blockOnVersion, maxBlockTime, getPropLabel, ProcessStatistics); + } + return false; + } + catch (FaultException) + { + DryadLogger.LogWarning("Set Get Process Properties", "Attempt to get or set properties for unknown process {0} on node {1}", processId, this.m_nodeName); + faultDispatcher = false; + break; + } + catch (FaultException vse) + { + DryadLogger.LogWarning("Set Get Process Properties", "Error setting or getting properties for process {0} on node {1}: {2}", processId, this.m_nodeName, vse.Reason); + faultDispatcher = false; + break; + } + catch (TimeoutException te) + { + DryadLogger.LogWarning("Set Get Process Properties", "Timeout communicating with vertex service for process {0} on node {1}: {2}", processId, this.m_nodeName, te.ToString()); + if (!SafeOpenConnection()) + { + faultDispatcher = true; + break; + } + } + catch (CommunicationException ce) + { + DryadLogger.LogWarning("Set Get Process Properties", "Error communicating with vertex service for process {0} on node {1}: {2}", processId, this.m_nodeName, ce.ToString()); + if (!SafeOpenConnection()) + { + faultDispatcher = true; + break; + } + } + catch (Exception e) + { +
DryadLogger.LogError(0, e, "Error calling SetGetProps for process {0} on node {1}", processId, m_nodeName); + faultDispatcher = false; + break; + } + } + + if (faultDispatcher) + { + RaiseFaultedEvent(); + } + return false; + + } + + /// + /// Call Shutdown method on the vertex service and close the communication channel. + /// After this method is called, the Dispatcher is unusable. + /// + /// uint code - reserved for future use + public void Shutdown(uint code) + { + for (int index = 0; index < MaxRetries; index++) + { + try + { + if (!Faulted) + { + this.m_client.Shutdown(code); + } + return; + } + catch (FaultException vse) + { + DryadLogger.LogWarning("Shutdown", "Error shutting down vertex service on node {0}: {1}", this.m_nodeName, vse.Reason); + break; + } + catch (TimeoutException te) + { + DryadLogger.LogWarning("Shutdown", "Timeout communicating with vertex service on node {0}: {1}", this.m_nodeName, te.ToString()); + if (!SafeOpenConnection()) + { + break; + } + } + catch (CommunicationException ce) + { + DryadLogger.LogWarning("Shutdown", "Error communicating with vertex service on node {0}: {1}", this.m_nodeName, ce.ToString()); + if (!SafeOpenConnection()) + { + DryadLogger.LogWarning("Shutdown", "Failed to reopen connection to node {0}", this.m_nodeName); + break; + } + } + catch (Exception e) + { + DryadLogger.LogWarning("Shutdown", "Exception shutting down vertex service on node {0}: {1}", this.m_nodeName, e.ToString()); + if (!SafeOpenConnection()) + { + break; + } + } + } + + // Not faulting the dispatcher here, even though the WCF connection could not be closed cleanly + // Shutdown is only called when the graphmanger is closing, so there is no need to fault the dispatchers + // Also avoids problems around faulting dispatchers while enumerating them + } + + #endregion + + #region Private Methods + + private void Dispose(bool disposing) + { + bool closeConnection = false; + + if (!this.m_disposed) + { + lock (SyncRoot) + { + if 
(!this.m_disposed) + { + if (disposing) + { + closeConnection = true; + } + + if (m_retryTimer != null) + { + m_retryTimer.Dispose(); + m_retryTimer = null; + } + m_disposed = true; + } + } + } + + if (closeConnection) + { + SafeCloseConnection(); + } + } + + private void SafeCloseConnection() + { + VertexServiceClient client = null; + lock (SyncRoot) + { + if (m_client == null) + { + return; + } + + client = m_client; + m_client = null; + } + + try + { + client.Close(); + } + catch + { + try + { + client.Abort(); + } + catch + { + } + } + } + + private bool SafeOpenConnection() + { + SafeCloseConnection(); + + lock (SyncRoot) + { + if (!Faulted) + { + m_client = new VertexServiceClient(m_backendBinding, new EndpointAddress(m_endpointAddress)); + m_connectionAttempts++; + return true; + } + return false; + } + } + + #endregion + + #region Properties + + public int ConnectionAttempts + { + get { return this.m_connectionAttempts; } + } + + public int CurrentProcess + { + get + { + if (m_currentProcess != null) + { + return m_currentProcess.Id; + } + else + { + return InvalidProcessId; + } + } + } + + public bool Faulted + { + get { return this.m_faulted; } + set { this.m_faulted = value; } + } + + public bool SchedulerTaskFailed + { + get { return this.m_taskFailed; } + set { this.m_taskFailed = value; } + } + + public bool Idle + { + get { return this.m_idle; } + } + + public string NodeName + { + get { return this.m_nodeName; } + } + + public object SyncRoot + { + get { return m_syncRoot; } + } + + public int TaskId + { + get + { + return this.m_taskId; + } + } + + #endregion + + } +} diff --git a/xcompute_managed/DispatcherPool.cs b/xcompute_managed/DispatcherPool.cs new file mode 100644 index 0000000..0fc36b3 --- /dev/null +++ b/xcompute_managed/DispatcherPool.cs @@ -0,0 +1,150 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ +
+namespace Microsoft.Research.Dryad
+{
+    using System.Collections.Generic;
+    using System.Linq;
+    using System;
+
+    /// <summary>
+    /// Collection of Dispatcher objects keyed (case-insensitively) by node name.
+    /// All members except GetEnumerator take SyncRoot.
+    /// </summary>
+    internal class DispatcherPool
+    {
+        // Element type restored to Dispatcher: the list is searched with typed lambdas
+        // (x.NodeName, x.TaskId, x.Faulted), which only compiles against List<Dispatcher>.
+        private List<Dispatcher> m_dispatcherTable = new List<Dispatcher>();
+
+        private object m_syncRoot = new object();
+
+        public DispatcherPool()
+        {
+        }
+
+        /// <summary>
+        /// Adds a dispatcher unless one with the same node name is already present.
+        /// </summary>
+        /// <returns>true if added; false if the node already has a dispatcher.</returns>
+        public bool Add(Dispatcher d)
+        {
+            lock (SyncRoot)
+            {
+                // We never want to have two dispatchers for the same node
+                // since we don't support oversubscription for vertex nodes
+                Dispatcher existing = null;
+                if (GetByNodeName(d.NodeName, out existing))
+                {
+                    return false;
+                }
+
+                m_dispatcherTable.Add(d);
+                return true;
+            }
+        }
+
+        /// <summary>
+        /// Disposes every dispatcher in the pool and empties it.
+        /// </summary>
+        public void Clear()
+        {
+            lock (SyncRoot)
+            {
+                foreach (Dispatcher d in m_dispatcherTable)
+                {
+                    d.Dispose();
+                }
+                m_dispatcherTable.Clear();
+            }
+        }
+
+        /// <summary>
+        /// Finds the dispatcher whose NodeName matches <paramref name="node"/>, ignoring case.
+        /// </summary>
+        /// <returns>true if found; otherwise <paramref name="d"/> is null.</returns>
+        public bool GetByNodeName(string node, out Dispatcher d)
+        {
+            lock (SyncRoot)
+            {
+                d = m_dispatcherTable.Find(x => String.Equals(x.NodeName, node, StringComparison.OrdinalIgnoreCase));
+            }
+            return (d != null);
+        }
+
+        /// <summary>
+        /// Finds the dispatcher whose TaskId equals <paramref name="taskId"/>.
+        /// </summary>
+        /// <returns>true if found; otherwise <paramref name="d"/> is null.</returns>
+        public bool GetByTaskId(int taskId, out Dispatcher d)
+        {
+            lock (SyncRoot)
+            {
+                d = m_dispatcherTable.Find(x => x.TaskId == taskId);
+            }
+            return (d != null);
+        }
+
+        /// <summary>
+        /// Removes a dispatcher from the pool without disposing it.
+        /// </summary>
+        public bool Remove(Dispatcher d)
+        {
+            lock (SyncRoot)
+            {
+                return m_dispatcherTable.Remove(d);
+            }
+        }
+
+        /// <summary>
+        /// Looks up the dispatcher for <paramref name="node"/> and tries to reserve it.
+        /// </summary>
+        /// <returns>true with the reserved dispatcher; false (and null) if the node is
+        /// unknown or Dispatcher.Reserve refused.</returns>
+        public bool TryReserveDispatcher(string node, out Dispatcher dispatcher)
+        {
+            lock (SyncRoot)
+            {
+                Dispatcher candidate = null;
+                if (GetByNodeName(node, out candidate) && candidate.Reserve())
+                {
+                    dispatcher = candidate;
+                    return true;
+                }
+            }
+            dispatcher = null;
+            return false;
+        }
+
+        /// <summary>
+        /// Enumerates the pool. The enumerator itself takes no lock; callers that may race
+        /// with Add/Remove/Clear should hold SyncRoot while iterating (as Nodes does).
+        /// </summary>
+        public IEnumerator<Dispatcher> GetEnumerator()
+        {
+            return m_dispatcherTable.GetEnumerator();
+        }
+
+        /// <summary>
+        /// Number of dispatchers that are not faulted.
+        /// </summary>
+        public int Count
+        {
+            get
+            {
+                lock (SyncRoot)
+                {
+                    return m_dispatcherTable.Count(x => !x.Faulted);
+                }
+            }
+        }
+
+        /// <summary>
+        /// Snapshot of the node names of all dispatchers in the pool (faulted ones included).
+        /// </summary>
+        public List<string> Nodes
+        {
+            get
+            {
+                List<string> nodes = new List<string>();
+                lock (SyncRoot)
+                {
+                    foreach (Dispatcher d in this)
+                    {
+                        nodes.Add(d.NodeName);
+                    }
+                }
+                return nodes;
+            }
+        }
+
+        public object SyncRoot
+        {
+            get { return this.m_syncRoot; }
+        }
+
+    }
+}
diff --git a/xcompute_managed/JobStatus.cs b/xcompute_managed/JobStatus.cs new file mode 100644 index 0000000..4b10e95 --- /dev/null +++ b/xcompute_managed/JobStatus.cs @@ -0,0 +1,128 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License.
+ +*/ +
+namespace Microsoft.Research.Dryad
+{
+    using System;
+    using System.Threading;
+    using Microsoft.Research.Dryad;
+
+    /// <summary>
+    /// Tracks completed and total progress steps for a job and reports the resulting
+    /// percentage through ISchedulerHelper.SetJobProgress.
+    /// </summary>
+    public class JobStatus
+    {
+        private int m_progressStepsCompleted = 0;
+        private int m_totalProgressSteps = 0;
+        private ISchedulerHelper m_schedulerHelper = null;
+
+        public JobStatus(ISchedulerHelper helper)
+        {
+            m_schedulerHelper = helper;
+        }
+
+        /// <summary>Adds one completed step and reports progress with <paramref name="message"/>.</summary>
+        public void IncrementProgress(string message)
+        {
+            Interlocked.Increment(ref m_progressStepsCompleted);
+            ShowProgress(message, false);
+        }
+
+        /// <summary>Adds one to the total step count; reports progress when <paramref name="update"/> is true.</summary>
+        public void IncrementTotalSteps(bool update)
+        {
+            Interlocked.Increment(ref m_totalProgressSteps);
+            if (update)
+            {
+                ShowProgress(null, false);
+            }
+        }
+
+        /// <summary>Subtracts one from the total step count; reports progress when <paramref name="update"/> is true.</summary>
+        public void DecrementTotalSteps(bool update)
+        {
+            Interlocked.Decrement(ref m_totalProgressSteps);
+            if (update)
+            {
+                ShowProgress(null, false);
+            }
+        }
+
+
+        /// <summary>Replaces the total step count; reports progress when <paramref name="update"/> is true.</summary>
+        public void ResetProgress(int totalSteps, bool update)
+        {
+            Interlocked.Exchange(ref m_totalProgressSteps, totalSteps);
+            if (update)
+            {
+                ShowProgress(null, false);
+            }
+        }
+
+        /// <summary>Sends the final progress update, scaled to the full 100%.</summary>
+        public void CompleteProgress(string message)
+        {
+            ShowProgress(message, true);
+        }
+
+        /// <summary>
+        /// Computes the percent complete and pushes it, with a message of at most 80 characters,
+        /// to the scheduler helper. Failures of SetJobProgress are logged and swallowed.
+        /// </summary>
+        /// <param name="message">Progress text; null is sent as the empty string.</param>
+        /// <param name="finished">true for the final update (scale to 100), false to cap at 99.</param>
+        private void ShowProgress(string message, bool finished)
+        {
+            Int32 nPercent = 0;
+            // Progress is incremented as active vertices complete, when they're all done
+            // the GM still has to seal the output stream, which may take a nontrivial amount
+            // of time, so scale to 99% until the final progress update.
+            double scalingFactor = finished ? 100.0 : 99.0;
+
+            // Read each counter once so the computation and the log messages agree.
+            int completedSteps = m_progressStepsCompleted;
+            int totalSteps = m_totalProgressSteps;
+
+            if (totalSteps <= 0)
+            {
+                // With no steps registered the division yields NaN/Infinity, which used to surface
+                // as an OverflowException and a bogus 100% report. Report 0% until finished.
+                nPercent = finished ? 100 : 0;
+                DryadLogger.LogDebug("Set Job Progress", "{0} percent complete", nPercent);
+            }
+            else
+            {
+                try
+                {
+                    nPercent = Convert.ToInt32(Convert.ToDouble(completedSteps) / Convert.ToDouble(totalSteps) * scalingFactor);
+                    DryadLogger.LogDebug("Set Job Progress", "{0} percent complete", nPercent);
+                }
+                catch (OverflowException e)
+                {
+                    DryadLogger.LogWarning("Set Job Progress", "OverflowException calculating percent complete: {0}", e.ToString());
+                    nPercent = 100;
+                }
+            }
+
+            if (nPercent > 100)
+            {
+                DryadLogger.LogWarning("Set Job Progress", "Percent complete greater than 100: {0} / {1} steps reported complete", completedSteps, totalSteps);
+                nPercent = 100;
+            }
+            else if (nPercent < 0)
+            {
+                // A negative completed count must not be reported as negative progress.
+                nPercent = 0;
+            }
+
+            try
+            {
+                if (message == null)
+                {
+                    message = String.Empty;
+                }
+                else if (message.Length > 80)
+                {
+                    // Job progress messages have max length of 80
+                    message = message.Substring(0, 80);
+                }
+                m_schedulerHelper.SetJobProgress(nPercent, message);
+            }
+            catch (Exception e)
+            {
+                DryadLogger.LogWarning("Set Job Progress", "Failed to set job progress: {0}", e.ToString());
+            }
+        }
+
+        /// <summary>Replaces the completed step count and reports progress with <paramref name="message"/>.</summary>
+        public void SetProgress(int completedSteps, string message)
+        {
+            Interlocked.Exchange(ref m_progressStepsCompleted, completedSteps);
+            ShowProgress(message, false);
+        }
+
+    }
+}
diff --git a/xcompute_managed/Microsoft.Research.Dryad.ClusterAdapter.csproj b/xcompute_managed/Microsoft.Research.Dryad.ClusterAdapter.csproj new file mode 100644 index 0000000..91c000e --- /dev/null +++ b/xcompute_managed/Microsoft.Research.Dryad.ClusterAdapter.csproj @@ -0,0 +1,169 @@ + + + + Debug + AnyCPU + 9.0.30729 + 2.0 + {F4B04940-67CF-4796-B6D3-3CFD38FB988A} + Library + Properties + Microsoft.Research.Dryad.ClusterAdapter + Microsoft.Research.Dryad.ClusterAdapter + v4.0 + 512 + + + + + 3.5 + + publish\ + true + Disk + false + Foreground + 7 + Days + false + false + true + 0 + 1.0.0.%2a + false + false + true + + + + true + full + false + bin\Debug\ + DEBUG;TRACE + prompt + 4 + true + AllRules.ruleset + x64 + + + pdbonly + true + bin\Release\ + TRACE + prompt + 4 + AllRules.ruleset +
x64 + + + + + + AzureUtils.cs + + + Constants.cs + + + DiscLocalMonitor.cs + + + DryadTracing.cs + + + DryadVertexServiceAuthorizationManager.cs + + + IDryadVertexCallback.cs + + + IDryadVertexService.cs + + + NativeMethods.cs + + + NetShareWrapper.cs + + + ProcessPathHelper.cs + + + ProcessState.cs + + + QueryUtility.cs + + + RetryFramework.cs + + + SchedulerHelper.cs + + + SoftAffinity.cs + + + + + + + + + + + + + + + + + + 3.5 + + + 3.5 + + + 3.0 + + + 3.0 + + + + 3.5 + + + + + {09FB27C7-D1A5-4A59-B010-67D5886DD9A2} + DryadYarnBridge + + + + + False + .NET Framework 3.5 SP1 Client Profile + false + + + False + .NET Framework 3.5 SP1 + true + + + False + Windows Installer 3.1 + true + + + \ No newline at end of file diff --git a/xcompute_managed/ProcessTable.cs b/xcompute_managed/ProcessTable.cs new file mode 100644 index 0000000..bf995a5 --- /dev/null +++ b/xcompute_managed/ProcessTable.cs @@ -0,0 +1,109 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ +
+namespace Microsoft.Research.Dryad
+{
+    using System;
+    using System.Threading;
+    using System.Collections.Generic;
+
+    // Abstraction to facilitate moving process table to, e.g., SQL in the future
+    /// <summary>
+    /// Map from process id to XComputeProcess. Every accessor except GetEnumerator
+    /// serializes on the table lock exposed through SyncRoot.
+    /// </summary>
+    internal class ProcessTable
+    {
+        // Type arguments restored from usage: keys are int ids, values are XComputeProcess.
+        private Dictionary<int, XComputeProcess> processTable;
+        private object tableLock = new object();
+
+        public ProcessTable()
+        {
+            processTable = new Dictionary<int, XComputeProcess>();
+        }
+
+        /// <summary>Adds a process; Dictionary.Add throws if <paramref name="id"/> is already present.</summary>
+        public void Add(int id, XComputeProcess proc)
+        {
+            lock (tableLock)
+            {
+                this.processTable.Add(id, proc);
+            }
+        }
+
+        public bool ContainsKey(int id)
+        {
+            lock (tableLock)
+            {
+                return this.processTable.ContainsKey(id);
+            }
+        }
+
+        /// <summary>
+        /// Returns the underlying dictionary enumerator. No lock is taken here; callers
+        /// that can race with Add/Remove should hold SyncRoot for the whole iteration.
+        /// </summary>
+        public Dictionary<int, XComputeProcess>.Enumerator GetEnumerator()
+        {
+            return this.processTable.GetEnumerator();
+        }
+
+        /// <summary>Disposes and removes the process with the given id; no-op if it is absent.</summary>
+        public void Remove(int id)
+        {
+            lock (tableLock)
+            {
+                XComputeProcess proc;
+                if (this.processTable.TryGetValue(id, out proc))
+                {
+                    proc.Dispose();
+                    this.processTable.Remove(id);
+                }
+            }
+        }
+
+        /// <summary>Looks up a process by id without throwing when it is absent.</summary>
+        public bool TryGetValue(int id, out XComputeProcess proc)
+        {
+            // Previously read the dictionary without the lock, racing with Add/Remove.
+            lock (tableLock)
+            {
+                return this.processTable.TryGetValue(id, out proc);
+            }
+        }
+
+        /// <summary>
+        /// Gets or sets the process for an id. The getter throws ArgumentException
+        /// when the id is not in the table.
+        /// </summary>
+        public XComputeProcess this[int id]
+        {
+            get
+            {
+                lock (tableLock)
+                {
+                    XComputeProcess proc;
+                    if (this.processTable.TryGetValue(id, out proc))
+                    {
+                        return proc;
+                    }
+
+                    throw new ArgumentException(String.Format("Process ID {0} not found in process table", id));
+                }
+            }
+
+            set
+            {
+                lock (tableLock)
+                {
+                    this.processTable[id] = value;
+                }
+            }
+        }
+
+        public object SyncRoot
+        {
+            get { return this.tableLock; }
+        }
+
+
+    }
+}
diff --git a/xcompute_managed/RequestPool.cs b/xcompute_managed/RequestPool.cs new file mode 100644 index 0000000..73c4064 --- /dev/null +++ b/xcompute_managed/RequestPool.cs @@ -0,0 +1,97 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License.
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ +
+namespace Microsoft.Research.Dryad
+{
+    using System.Collections.Generic;
+
+    /// <summary>
+    /// Ordered list of schedule requests. All members except GetEnumerator take SyncRoot.
+    /// </summary>
+    internal class RequestPool
+    {
+        // Element type restored from Add/Remove, which take ScheduleProcessRequest.
+        private List<ScheduleProcessRequest> m_processRequestPool = new List<ScheduleProcessRequest>();
+        private object m_syncRoot = new object();
+
+        public RequestPool()
+        {
+        }
+
+        /// <summary>Appends a request to the end of the pool.</summary>
+        public void Add(ScheduleProcessRequest req)
+        {
+            lock (SyncRoot)
+            {
+                m_processRequestPool.Add(req);
+            }
+        }
+
+        /// <summary>
+        /// Removes the first request whose Id equals <paramref name="processId"/>.
+        /// </summary>
+        /// <returns>true if a request was removed; false if none matched.</returns>
+        public bool Cancel(int processId)
+        {
+            lock (SyncRoot)
+            {
+                // Locate by index instead of removing from inside a foreach over the same list.
+                int index = m_processRequestPool.FindIndex(r => r.Id == processId);
+                if (index < 0)
+                {
+                    return false;
+                }
+
+                m_processRequestPool.RemoveAt(index);
+                return true;
+            }
+        }
+
+        /// <summary>Removes every request from the pool.</summary>
+        public void Clear()
+        {
+            lock (SyncRoot)
+            {
+                m_processRequestPool.Clear();
+            }
+        }
+
+        /// <summary>Removes the given request; returns false if it was not in the pool.</summary>
+        public bool Remove(ScheduleProcessRequest req)
+        {
+            lock (SyncRoot)
+            {
+                return m_processRequestPool.Remove(req);
+            }
+        }
+
+        /// <summary>
+        /// Enumerates the pool. No lock is taken here; callers that can race with
+        /// Add/Remove/Cancel should hold SyncRoot for the whole iteration.
+        /// </summary>
+        public IEnumerator<ScheduleProcessRequest> GetEnumerator()
+        {
+            return m_processRequestPool.GetEnumerator();
+        }
+
+        public object SyncRoot
+        {
+            get
+            {
+                return this.m_syncRoot;
+            }
+        }
+
+        /// <summary>Number of requests currently in the pool.</summary>
+        public int Count
+        {
+            get
+            {
+                lock (SyncRoot)
+                {
+                    return m_processRequestPool.Count;
+                }
+            }
+        }
+    }
+}
diff --git a/xcompute_managed/ScheduleProcessRequest.cs b/xcompute_managed/ScheduleProcessRequest.cs new file mode 100644 index 0000000..2d4809e --- /dev/null +++ b/xcompute_managed/ScheduleProcessRequest.cs @@ -0,0 +1,110 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License.
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ +
+namespace Microsoft.Research.Dryad
+{
+    using System;
+    using System.Collections.Generic;
+    using System.Collections.Specialized;
+
+    /// <summary>
+    /// Describes one process to schedule: its id, command line, environment,
+    /// an optional hard node affinity and a weighted list of soft node affinities.
+    /// </summary>
+    internal class ScheduleProcessRequest
+    {
+        private int processId;
+        private string commandLine;
+        private string hardAffinity;
+        // Element type restored from AffinityAt, which returns SoftAffinity.
+        private List<SoftAffinity> softAffinityList;
+        private StringDictionary environment;
+
+        // We want the affinity list sorted in descending order by weight,
+        // so use a custom IComparer which just reverses the arguments
+        internal class reverseComparer : IComparer<SoftAffinity>
+        {
+            int IComparer<SoftAffinity>.Compare(SoftAffinity lhs, SoftAffinity rhs)
+            {
+                return Comparer<SoftAffinity>.Default.Compare(rhs, lhs);
+            }
+        }
+
+        /// <summary>
+        /// Creates a request. The supplied soft-affinity list is kept by reference and sorted
+        /// in place; a null list is treated as "no soft affinities".
+        /// </summary>
+        /// <param name="id">Process id.</param>
+        /// <param name="cl">Command line.</param>
+        /// <param name="soft">Soft affinities, may be null.</param>
+        /// <param name="hard">Node the process is pinned to, or null for none.</param>
+        /// <param name="env">Environment variables for the process.</param>
+        public ScheduleProcessRequest(int id, string cl, List<SoftAffinity> soft, string hard, StringDictionary env)
+        {
+            processId = id;
+            commandLine = cl;
+            // A null list used to throw NullReferenceException on Sort below.
+            softAffinityList = soft ?? new List<SoftAffinity>();
+            // Sort the affinity list in descending order by weight so that we try
+            // the highest weight affinity first, etc
+            softAffinityList.Sort(new reverseComparer());
+            environment = env;
+            hardAffinity = hard;
+        }
+
+        public int Id
+        {
+            get { return processId; }
+        }
+
+        public string CommandLine
+        {
+            get { return commandLine; }
+        }
+
+        public StringDictionary Environment
+        {
+            get { return environment; }
+        }
+
+        /// <summary>true only when a hard affinity is set and names <paramref name="node"/> (case-insensitive).</summary>
+        public bool MustRunOnNode(string node)
+        {
+            return ((hardAffinity != null) && String.Equals(hardAffinity, node, StringComparison.OrdinalIgnoreCase));
+        }
+
+        /// <summary>true when there is no hard affinity, or it names <paramref name="node"/> (case-insensitive).</summary>
+        public bool CanRunOnNode(string node)
+        {
+            return (hardAffinity == null || String.Equals(hardAffinity, node, StringComparison.OrdinalIgnoreCase));
+        }
+
+        /// <summary>
+        /// Weight of the first soft affinity whose Node matches <paramref name="node"/>
+        /// (case-insensitive), or 0 when none matches.
+        /// </summary>
+        public ulong GetAffinityWeightForNode(string node)
+        {
+            foreach (SoftAffinity a in softAffinityList)
+            {
+                if (String.Equals(a.Node, node, StringComparison.OrdinalIgnoreCase))
+                {
+                    return a.Weight;
+                }
+            }
+
+            return 0;
+        }
+
+        public string HardAffinity
+        {
+            get { return this.hardAffinity; }
+        }
+
+        /// <summary>Soft affinity at position <paramref name="i"/> in descending weight order.</summary>
+        public SoftAffinity AffinityAt(int i)
+        {
+            return softAffinityList[i];
+        }
+
+        public int AffinityCount
+        {
+            get { return softAffinityList.Count; }
+        }
+    }
+}
diff --git a/xcompute_managed/VertexCallbackService.cs b/xcompute_managed/VertexCallbackService.cs new file mode 100644 index 0000000..35d5c17 --- /dev/null +++ b/xcompute_managed/VertexCallbackService.cs @@ -0,0 +1,80 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License.
+ +*/ + +
+namespace Microsoft.Research.Dryad
+{
+    using System;
+    using System.ServiceModel;
+    using System.Collections.Generic;
+    using Microsoft.Research.Dryad;
+
+    /// <summary>
+    /// Single-instance, multi-threaded WCF service that relays vertex callbacks to the
+    /// VertexScheduler. Each operation logs and swallows any exception the scheduler throws,
+    /// so nothing propagates back to the calling vertex service.
+    /// </summary>
+    [ServiceBehavior(ConcurrencyMode = ConcurrencyMode.Multiple, InstanceContextMode = InstanceContextMode.Single, IncludeExceptionDetailInFaults = true)]
+    internal class VertexCallbackService : IDryadVertexCallback
+    {
+        private VertexScheduler vertexScheduler;
+
+        public VertexCallbackService(VertexScheduler scheduler)
+        {
+            vertexScheduler = scheduler;
+        }
+
+
+        #region IVertexCallbackService
+
+        /// <summary>Forwards a process state change to the scheduler.</summary>
+        public void FireStateChange(int processId, ProcessState newState)
+        {
+            try
+            {
+                this.vertexScheduler.ProcessChangeState(processId, newState);
+            }
+            catch (Exception failure)
+            {
+                DryadLogger.LogError(0, failure, "Failed to change state to {0} for process {1}", newState.ToString(), processId);
+            }
+        }
+
+        /// <summary>Forwards a process exit notification, with its exit code, to the scheduler.</summary>
+        public void ProcessExited(int processId, int exitCode)
+        {
+            try
+            {
+                this.vertexScheduler.ProcessExit(processId, exitCode);
+            }
+            catch (Exception failure)
+            {
+                DryadLogger.LogError(0, failure, "Failed to execute process exit for process {0}", processId);
+            }
+        }
+
+        /// <summary>Forwards completion of a set/get properties request to the scheduler.</summary>
+        public void SetGetPropsComplete(int processId, ProcessInfo info, string[] propertyLabels, ulong[] propertyVersions)
+        {
+            try
+            {
+                this.vertexScheduler.SetGetPropsComplete(processId, info, propertyLabels, propertyVersions);
+            }
+            catch (Exception failure)
+            {
+                DryadLogger.LogError(0, failure, "Failed to complete set / get properties for process {0}", processId);
+            }
+        }
+
+        #endregion
+    }
+}
diff --git a/xcompute_managed/VertexCallbackServiceHost.cs b/xcompute_managed/VertexCallbackServiceHost.cs new file mode 100644 index 0000000..0f343ec --- /dev/null +++ b/xcompute_managed/VertexCallbackServiceHost.cs @@ -0,0 +1,135 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License.
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ +
+namespace Microsoft.Research.Dryad
+{
+    using System;
+    using System.ServiceModel;
+    using System.ServiceModel.Description;
+    using System.Net.Security;
+
+    // Self-hosts a VertexCallbackService in a WCF ServiceHost and manages its lifetime.
+    class VertexCallbackServiceHost
+    {
+        private VertexCallbackService callbackService;
+
+        // Null until Start succeeds in creating a host; reset to null after a failed attempt.
+        private ServiceHost selfHost;
+
+        public VertexCallbackServiceHost(VertexScheduler vs)
+        {
+            callbackService = new VertexCallbackService(vs);
+        }
+
+        // Opens the callback service at listenUri using the binding supplied by schedulerHelper.
+        // Returns true once the host is open; returns false if a CommunicationException
+        // escapes the retry loop. Other exception types are not caught here.
+        public bool Start(string listenUri, ISchedulerHelper schedulerHelper)
+        {
+            DryadLogger.LogMethodEntry(listenUri);
+            Uri baseAddress = new Uri(listenUri);
+
+            try
+            {
+                NetTcpBinding binding = schedulerHelper.GetVertexServiceBinding();
+
+                selfHost = null;
+
+                // Retry opening the service port if address is already in use
+                int maxRetryCount = 20; // Results in retrying for ~1 min
+                for (int retryCount = 0; retryCount < maxRetryCount; retryCount++)
+                {
+                    try
+                    {
+                        //Step 1 of the hosting procedure: Create ServiceHost
+                        selfHost = new ServiceHost(callbackService, baseAddress);
+
+                        //Step 2 of the hosting procedure: Add service endpoints.
+                        ServiceEndpoint vertexEndpoint = selfHost.AddServiceEndpoint(typeof(IDryadVertexCallback), binding, Constants.vertexCallbackServiceName);
+                        ServiceThrottlingBehavior stb = new ServiceThrottlingBehavior();
+                        stb.MaxConcurrentCalls = Constants.MaxConnections;
+                        stb.MaxConcurrentSessions = Constants.MaxConnections;
+                        selfHost.Description.Behaviors.Add(stb);
+
+                        //Step 3 of hosting procedure : Add a security manager
+                        selfHost.Authorization.ServiceAuthorizationManager = new DryadVertexServiceAuthorizationManager();
+
+                        // Step 4 of the hosting procedure: Start the service.
+                        selfHost.Open();
+                        break;
+                    }
+
+                    catch (AddressAlreadyInUseException)
+                    {
+                        // Discard the half-built host; a fresh one is created on the next attempt.
+                        if (selfHost != null)
+                        {
+                            selfHost.Abort();
+                            selfHost = null;
+                        }
+
+                        // If this is the last try, don't sleep. Just rethrow exception to exit.
+                        // NOTE(review): the rethrow presumably lands in the outer CommunicationException
+                        // handler, since AddressAlreadyInUseException derives from it in WCF - confirm.
+                        if (retryCount < maxRetryCount - 1)
+                        {
+                            DryadLogger.LogInformation("Start Vertex Callback Service", "Address already in use. Retrying...");
+                            System.Threading.Thread.Sleep(3000);
+                        }
+                        else
+                        {
+                            throw;
+                        }
+                    }
+                }
+
+                DryadLogger.LogInformation("Start Vertex Callback Service", "Service Host started successfully");
+                return true;
+            }
+            catch (CommunicationException ce)
+            {
+                DryadLogger.LogCritical(0, ce, "Failed to start vertex callback service");
+                // Best-effort cleanup: any failure while aborting is deliberately ignored.
+                try
+                {
+                    if (selfHost != null)
+                    {
+                        selfHost.Abort();
+                    }
+                }
+                catch
+                {
+                }
+                return false;
+            }
+        }
+
+        // Closes the host with a 100 ms timeout, falling back to Abort if Close throws.
+        // Does nothing when no host exists. selfHost is not cleared afterwards.
+        public void Stop()
+        {
+            if (selfHost != null)
+            {
+                try
+                {
+                    selfHost.Close(TimeSpan.FromMilliseconds(100));
+                }
+                catch (Exception)
+                {
+                    try
+                    {
+                        selfHost.Abort();
+                    }
+
+                    catch { }
+                }
+            }
+        }
+    }
+}
diff --git a/xcompute_managed/VertexScheduler.cs b/xcompute_managed/VertexScheduler.cs new file mode 100644 index 0000000..8bfb17c --- /dev/null +++ b/xcompute_managed/VertexScheduler.cs @@ -0,0 +1,1251 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License.
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +namespace Microsoft.Research.Dryad +{ + using System; + using System.Collections.Generic; + using System.Collections.Specialized; + using System.Globalization; + using System.Threading; + using System.Diagnostics; + + using Microsoft.Research.Dryad; + + public class VertexScheduler + { + private ProcessTable processTable = null; + private DispatcherPool dispatcherPool = new DispatcherPool(); + private DispatcherPool badDispatcherPool = new DispatcherPool(); + private RequestPool requestPool = new RequestPool(); + private ISchedulerHelper schedulerHelper = SchedulerHelperFactory.GetInstance(); + private VertexCallbackServiceHost callbackServiceHost; + private JobStatus jobStatus = null; + private string baseUri; + private string replyUri; + private int JobId = 0; + private const int currentProcess = 1; + private int processId = 0; + private object dispatcherChangeLock = new object(); + + #region Public Members + + public void CancelScheduleProcess(int processId) + { + DryadLogger.LogMethodEntry(processId); + + XComputeProcess proc = null; + + if (processTable.TryGetValue(processId, out proc) == false) + { + // We don't know about this process + DryadLogger.LogWarning("Cancel process", "Attempt to cancel unknown process, id {0}", processId); + return; + } + + // Try to remove it from request pool (unassigned) + if (requestPool.Cancel(processId)) + { + DryadLogger.LogInformation("Cancel process", "Process request removed from request pool for process id {0}", processId); + return; + } + + // Handle 
already assigned processes + proc.Cancel(); + } + + public void CloseVertexProcess(int processId) + { + XComputeProcess proc = null; + + if (processId == currentProcess) + { + // We don't maintain an entry in the process table for the current process + return; + } + + if (processTable.TryGetValue(processId, out proc)) + { + // else if it's already assigned, release it at the node + lock (proc.SyncRoot) + { + if (proc.Dispatcher != null) + { + if (proc.CurrentState != ProcessState.Completed) + { + // This can happen when the GM cancels a process and closes the handle right afterward. + // We may not have received the state change from the cancellation yet. + // Note that the handle was closed by the GM, but do nothing else to avoid leaking a Dispatcher. + // ProcessExit will use this to know whether it also needs to close the handle. + DryadLogger.LogDebug("Close vertex process", "Closing handle for process id {0} in state {1} - delaying close until process exit", processId, proc.CurrentState); + proc.HandleClosed = true; + } + else + { + try + { + proc.Dispatcher.ReleaseProcess(processId); + } + finally + { + // Graph Manager is done with the process at this point is called so remove it from the table + processTable.Remove(processId); + } + } + } + else + { + DryadLogger.LogInformation("Close vertex process", "Dispatcher is null for process id {0} - it was either unscheduled or the dispatcher faulted", processId); + } + } + + } + else + { + DryadLogger.LogError(0, null, "Unknown process id {0}", processId); + } + + } + + public void CreateVertexProcess(int processId) + { + XComputeProcess proc = new XComputeProcess(processId); + this.processTable.Add(processId, proc); + proc.ChangeState(ProcessState.Unscheduled); + } + + public string CurrentProcessLocalPath + { + get + { + return ProcessPathHelper.ProcessPath(this.processId); + } + + } + + public string CurrentProcessRemotePath + { + get + { + return GetProcessPath(this.processId, null); + } + } + + public 
/// <summary>
/// Reports the scheduler's view of a process's state.
/// </summary>
/// <param name="processId">Internal process id to query.</param>
/// <returns>
/// <see cref="ProcessState.Running"/> when <paramref name="processId"/> is the current process,
/// the recorded state when the id is present in the process table,
/// and <see cref="ProcessState.Completed"/> when the table has no entry for the id.
/// </returns>
public ProcessState GetProcessState(int processId)
{
    if (processId == currentProcess)
    {
        return ProcessState.Running;
    }

    // One lookup instead of ContainsKey + indexer: CloseVertexProcess removes entries,
    // so the entry must not be fetched a second time after its existence was checked.
    XComputeProcess proc;
    if (this.processTable.TryGetValue(processId, out proc))
    {
        return proc.CurrentState;
    }

    return ProcessState.Completed;
}
/// <summary>
/// Tells whether a process has been cancelled.
/// </summary>
/// <param name="processId">Internal process id to query.</param>
/// <returns>
/// The process's Cancelled flag when the id is present in the process table;
/// false when the table has no entry for the id.
/// </returns>
public bool ProcessCancelled(int processId)
{
    // One lookup instead of ContainsKey + indexer, matching the TryGetValue pattern used
    // elsewhere in this class and avoiding a second fetch of an entry that may have been removed.
    XComputeProcess proc;
    if (processTable.TryGetValue(processId, out proc))
    {
        return proc.Cancelled;
    }

    return false;
}
+ lock (requestPool.SyncRoot) + { + if (!FindNodeForRequest(req, out dispatcher)) + { + if (dispatcherPool.Count > 0) + { + DryadLogger.LogDebug("Schedule Process", "No nodes available, adding process {0} to request pool", processId); + requestPool.Add(req); + return true; + } + else + { + DryadLogger.LogCritical(0, null, "No available dispatchers"); + return false; + } + } + } + + // Found a Dispatcher, schedule the request outside of the lock + retVal = ScheduleProcess(req, dispatcher); + if (!retVal) + { + processTable[processId].ChangeState(ProcessState.SchedulingFailed); + dispatcher.Release(); + } + + return retVal; + } + + public bool SetGetProps(int processId, ProcessPropertyInfo[] infos, string blockOnLabel, ulong blockOnVersion, long maxBlockTime, string getPropLabel, bool ProcessStatistics, GetSetPropertyEventHandler handler) + { + if (this.processTable.ContainsKey(processId)) + { + if (infos != null && infos.Length > 0) + { + // Only add for the first property info since we only want to fire completion once per request + this.processTable[processId].AddPropertyListener(infos[0].propertyLabel, infos[0].propertyVersion, handler); + } + else if (getPropLabel != null && getPropLabel.Length > 0) + { + this.processTable[processId].AddPropertyListener(getPropLabel, 0, handler); + } + else + { + DryadLogger.LogError(0, null, "infos and getPropLabel both empty"); + return false; + } + + lock (this.processTable[processId].SyncRoot) + { + if (this.processTable[processId].Dispatcher != null) + { + if (this.processTable[processId].Dispatcher.SetGetProps(replyUri, processId, infos, blockOnLabel, blockOnVersion, maxBlockTime, getPropLabel, ProcessStatistics)) + { + return true; + } + } + } + + // Keep returning error to GM and let its fault-tolerance kick in + if (dispatcherPool.Count == 0) + { + DryadLogger.LogCritical(0, null, "All dispatchers are faulted."); + } + return false; + } + else + { + DryadLogger.LogError(0, null, "process id {0} not found in process 
/// <summary>
/// When called from the GM, shuts down all the vertex services and closes the communication channels.
/// When called from the vertex host, closes the communication channel to the local vertex service.
/// A call made after a previous shutdown has completed logs a warning and returns without doing anything.
/// </summary>
/// <param name="ShutdownCode">Code to pass to the vertex services. Currently unused: every dispatcher is shut down with code 0.</param>
public void Shutdown(uint ShutdownCode)
{
    DryadLogger.LogMethodEntry(ShutdownCode);

    // schedulerHelper is set to null at the end of a shutdown; without this guard a second
    // call would dereference null before reaching any of the cleanup below.
    ISchedulerHelper helper = schedulerHelper;
    if (helper == null)
    {
        DryadLogger.LogWarning("Shutdown", "Shutdown has already been called, ignoring");
        DryadLogger.LogMethodExit();
        return;
    }

    // Only the GM owns dispatchers for remote vertex services
    if (IsGraphManager)
    {
        // We no longer need to listen for task state changes
        helper.StopTaskMonitorThread();

        lock (dispatcherPool.SyncRoot)
        {
            foreach (Dispatcher disp in dispatcherPool)
            {
                DryadLogger.LogDebug("Shutdown", "Calling Shutdown on dispatcher for node {0}", disp.NodeName);
                Stopwatch sw = Stopwatch.StartNew();
                try
                {
                    disp.Shutdown(0);
                }
                catch (Exception e)
                {
                    // A failure on one node must not prevent shutting down the remaining nodes
                    DryadLogger.LogError(0, e, "Exception calling Shutdown on dispatcher for node {0}", disp.NodeName);
                }
                finally
                {
                    sw.Stop();
                }

                DryadLogger.LogDebug("Shutdown", "Dispatcher.Shutdown took {0} ms", sw.ElapsedMilliseconds);
            }
        }
    }

    // Dispose the SchedulerHelper instance to clean up resources
    helper.Dispose();
    schedulerHelper = null;

    // Clean out the dispatcher pool (this also disposes all dispatchers)
    dispatcherPool.Clear();

    // Stop the callback service
    callbackServiceHost.Stop();

    DryadLogger.LogMethodExit();
}
/// <summary>
/// Builds the scheduler around the supplied process table, reads the job and process ids
/// from the environment, and creates (but does not start) the callback service host.
/// </summary>
/// <param name="table">Process table this scheduler stores its processes in.</param>
private VertexScheduler(ProcessTable table)
{
    this.processTable = table;
    this.jobStatus = new JobStatus(schedulerHelper);

    // Neither variable is set when the vertex rerun command is executed.
    // Falling back to 0 lets later code detect a vertex rerun outside of an HPC job.
    int parsedJobId;
    string jobIdText = Environment.GetEnvironmentVariable("CCP_JOBID");
    this.JobId = Int32.TryParse(jobIdText, out parsedJobId) ? parsedJobId : 0;

    int parsedProcessId;
    string processIdText = Environment.GetEnvironmentVariable("CCP_DRYADPROCID");
    this.processId = Int32.TryParse(processIdText, out parsedProcessId) ? parsedProcessId : 0;

    string callbackBase = String.Format(Constants.vertexCallbackAddrFormat, AzureUtils.CurrentHostName, this.processId);
    this.baseUri = callbackBase;
    this.replyUri = callbackBase + Constants.vertexCallbackServiceName;
    this.callbackServiceHost = new VertexCallbackServiceHost(this);
}
/// <summary>
/// Creates a dispatcher for a node and tries to place it in the good dispatcher pool.
/// </summary>
/// <param name="taskid">HPC Task Id</param>
/// <param name="node">Name of node this dispatcher is for</param>
/// <param name="state">State of task when dispatcher is created (always Running now)</param>
/// <returns>The dispatcher that was added, or null if the pool rejected it because a dispatcher already exists for the node</returns>
private Dispatcher AddDispatcher(int taskid, string node, VertexTaskState state)
{
    VertexComputeNode computeNode = new VertexComputeNode();
    computeNode.instanceId = taskid;
    computeNode.ComputeNode = node;
    computeNode.State = state;

    Dispatcher candidate = new Dispatcher(schedulerHelper, computeNode);
    candidate.FaultedEvent += new DispatcherFaultedEventHandler(OnDispatcherFaulted);

    if (dispatcherPool.Add(candidate))
    {
        return candidate;
    }

    // There's already a dispatcher for this node, so the one just built is not needed
    candidate.Dispose();
    return null;
}
+ if (processTable.ContainsKey(r.Id) && processTable[r.Id].Cancelled) + { + continue; + } + + if (r.MustRunOnNode(node)) + { + req = r; + DryadLogger.LogDebug("Find Request for Node", "process {0} has hard affinity constraint for node {1}", req.Id, node); + break; + } + else if (r.CanRunOnNode(node)) + { + ulong thisAffinity = r.GetAffinityWeightForNode(node); + if (thisAffinity == 0 && req == null) + { + req = r; + DryadLogger.LogDebug("Find Request for Node", "Process {0} has 0 affinity constraint for node {1} but no other process has been selected yet", r.Id, node); + } + else if (thisAffinity > maxAffinity) + { + maxAffinity = thisAffinity; + req = r; + DryadLogger.LogDebug("Find Request for Node", "Process {0} with affinity constraint {1} for node {2} larger than previous max", r.Id, thisAffinity, node); + } + } + } + } + swSearch.Stop(); + + + if (req != null) + { + requestPool.Remove(req); + DryadLogger.LogDebug("Find Request for Node", "Found request {0} for node {1}", req.Id, node); + result = true; + } + else + { + DryadLogger.LogDebug("Find Request for Node", "Did not find any requests for node {0}", node); + dispatcher.Release(); + result = false; + } + } + } + swTotal.Stop(); + + DryadLogger.LogInformation("Find Request for Node", "Searching {0} requests. Block {1} ms. Inner search {2} ms. 
/// <summary>
/// Tries to reserve a dispatcher on which a request can run.
/// A request with a hard affinity is only ever placed on that node. Otherwise the soft
/// affinities are tried in list order, and after that any dispatcher in the pool whose
/// node the request can run on.
/// </summary>
/// <param name="req">Request to place.</param>
/// <param name="dispatcher">The reserved dispatcher on success; null on failure.</param>
/// <returns>true if a dispatcher was reserved for the request; false otherwise.</returns>
private bool FindNodeForRequest(ScheduleProcessRequest req, out Dispatcher dispatcher)
{
    dispatcher = null;

    if (req.HardAffinity != null)
    {
        // Hard affinity: that node or nothing
        return dispatcherPool.TryReserveDispatcher(req.HardAffinity, out dispatcher);
    }

    // First try soft affinity in decreasing order (assumes Soft Affinity list in req is sorted descending by weight)

    // Keep a map of the nodes we've already tried, because Dryad adds each affinity twice
    // once for the node and once for the "pod".
    // Node names are matched with an ordinal, case-insensitive comparer (as OnVertexChanged does)
    // rather than a culture-sensitive ToUpper(), whose result depends on the current culture.
    Dictionary<string, bool> attemptedNodes = new Dictionary<string, bool>(StringComparer.OrdinalIgnoreCase);
    int count = 0;

    for (int i = 0; i < req.AffinityCount; i++)
    {
        string node = req.AffinityAt(i).Node;
        if (attemptedNodes.ContainsKey(node))
        {
            continue;
        }
        attemptedNodes.Add(node, true);
        count++;

        if (dispatcherPool.TryReserveDispatcher(node, out dispatcher))
        {
            DryadLogger.LogDebug("Find Node For Request", "process {0} satisfied affinity constraint: node {1}, weight {2}", req.Id, node, req.AffinityAt(i).Weight);
            return true;
        }

        DryadLogger.LogDebug("Find Node For Request", "process {0} did not satisfy affinity constraint: node {1}, weight {2}", req.Id, node, req.AffinityAt(i).Weight);
    }

    // If we get this far and count > 0, then we failed to satisfy the affinity constraints
    // log a message so we can more easily detect this situation
    if (count > 0)
    {
        DryadLogger.LogInformation("Find Node For Request", "process {0} failed to satisfy any of {1} affinity constraints", req.Id, count);
    }

    // Finally try any available node
    lock (dispatcherPool.SyncRoot)
    {
        foreach (Dispatcher d in dispatcherPool)
        {
            if (req.CanRunOnNode(d.NodeName) && d.Reserve())
            {
                dispatcher = d;
                return true;
            }
        }
    }

    // Make the failure contract explicit regardless of what the last reservation attempt left behind
    dispatcher = null;
    return false;
}
/// <summary>
/// Timer callback for a dispatcher that was parked in the bad dispatcher pool.
/// Creates a fresh dispatcher for the same task and node, discards the faulted one,
/// and, if the fresh dispatcher was added, queues a search for a request to run on it.
/// </summary>
/// <param name="state">The dispatcher to be retried</param>
private void RetryFaultedDispatcher(object state)
{
    DryadLogger.LogMethodEntry();

    Dispatcher faulted = state as Dispatcher;
    if (faulted == null)
    {
        DryadLogger.LogWarning("Retry faulted dispatcher", "state parameter not a valid dispatcher");
        DryadLogger.LogMethodExit();
        return;
    }

    DryadLogger.LogDebug("Retry faulted dispatcher", "Creating new dispatcher for node {0}", faulted.NodeName);

    Dispatcher replacement = null;
    lock (dispatcherChangeLock)
    {
        // Add a new dispatcher for this node
        replacement = AddDispatcher(faulted.TaskId, faulted.NodeName, VertexTaskState.Running);

        // Get rid of the old dispatcher
        badDispatcherPool.Remove(faulted);
        faulted.Dispose();
    }

    if (replacement != null)
    {
        // Look for a request to run on this node
        ThreadPool.QueueUserWorkItem(new WaitCallback(this.FindRequestForNodeThreadFunc), replacement);
    }

    DryadLogger.LogMethodExit();
}
"failed YARN Container" : "communication error"); + + lock (dispatcherChangeLock) + { + // Remove from dispatcher pool + dispatcherPool.Remove(d); + + if (d.SchedulerTaskFailed) + { + // If we're faulting because the scheduler task transitioned to + // a non-running state, then we want to completely remove the dispatcher + badDispatcherPool.Remove(d); + d.Dispose(); + } + else + { + // If we're faulting because of a communication error, then we want to + // add to bad dispatcher pool so that we'll retry it again + badDispatcherPool.Add(d); + + // Set up a timer to move this dispatcher out of the bad pool in the future + d.SetRetryTimer(new TimerCallback(this.RetryFaultedDispatcher)); + } + } + + + if (d.CurrentProcess != Dispatcher.InvalidProcessId) + { + ProcessExit(d.CurrentProcess, unchecked((int)Constants.DrError_ProcessingInterrupted), true); + } + + } + } + + private void CheckForOutOfDispatchers() + { + if (badDispatcherPool.Count == 0 && dispatcherPool.Count == 0) + { + DryadLogger.LogError(0, null, "All vertex tasks have failed"); + lock (requestPool.SyncRoot) + { + foreach (ScheduleProcessRequest r in requestPool) + { + XComputeProcess proc; + if (processTable.TryGetValue(r.Id, out proc)) + { + DryadLogger.LogInformation("No Valid Dispatchers", "Transitioning process {0} to state {1} because all vertex tasks failed", r.Id, ProcessState.SchedulingFailed.ToString()); + proc.ChangeState(ProcessState.SchedulingFailed); + } + else + { + DryadLogger.LogCritical(0, null, "Failed to find process {0} in process table, exiting application.", r.Id); + throw new ApplicationException(String.Format("All vertex tasks failed and unable to cancel pending request id {0}", r.Id)); + } + } + + requestPool.Clear(); + } + } + } + + /// + /// This event handler is called from ISchedulerHelper task monitoring thread in response + /// to an HPC Task state change. 
+ /// + /// Not used + /// Information about the task state transition + private void OnVertexChanged(object sender, VertexChangeEventArgs e) + { + Dispatcher oldDispatcher = null; + Dispatcher newDispatcher = null; + bool addNewDispatcher = false; + bool faultOldDispatcher = false; + + lock (dispatcherChangeLock) + { + bool dispatcherFound = dispatcherPool.GetByTaskId(e.Id, out oldDispatcher); + if (!dispatcherFound) + { + // Check to see if this dispatcher was already faulted due to a communication error + dispatcherFound = badDispatcherPool.GetByTaskId(e.Id, out oldDispatcher); + } + + + // Task state change + if (e.OldState != e.NewState) + { + // Transitioning to, e.g., queued + if (e.NewState < VertexTaskState.Running) + { + DryadLogger.LogInformation("Vertex Task State Change", "Task {0} transitioned to waiting", e.Id); + + // If there is a dispatcher for the task, then the task has previously been running. + // Now it's not, so we need to fault the dispatcher. + if (dispatcherFound) + { + DryadLogger.LogWarning("Vertex Task State Change", "Previously running task {0} transitioned to waiting", e.Id); + faultOldDispatcher = true; + } + } + // Transition to running + else if (e.NewState == VertexTaskState.Running) + { + if (!dispatcherFound) + { + // No dispatcher for task, add a new one + DryadLogger.LogInformation("Vertex Task State Change", "Task {0} transitioned to running", e.Id); + addNewDispatcher = true; + } + else if (String.Compare(e.OldNode, e.NewNode, StringComparison.OrdinalIgnoreCase) != 0) + { + // Dispatcher found, but task is now on a new node + // 1. Make sure old dispatcher is faulted. + // 2. 
Add a new one for the new node + DryadLogger.LogInformation("Vertex Task State Change", "Running task {0} assigned to new node", e.Id); + + faultOldDispatcher = true; + addNewDispatcher = true; + } + else + { + // Dispatcher found, task is on same node + DryadLogger.LogWarning("Vertex Task State Change", "Change notification for running task {0}, but state and node are unchanged in notification", e.Id); + } + } + // Job is exiting, nothing to do + else if (e.NewState == VertexTaskState.Finished) + { + DryadLogger.LogDebug("Vertex Task State Change", "Task {0} transitioned to finished", e.Id); + } + // Failed or Cancelled + else + { + DryadLogger.LogWarning("Vertex Task State Change", "Task {0} transitioned to failed or cancelled", e.Id); + + // Fault dispatcher if it isn't already + if (dispatcherFound) + { + faultOldDispatcher = true; + } + } + } + // Node change + else if (String.Compare(e.OldNode, e.NewNode, StringComparison.OrdinalIgnoreCase) != 0) + { + if (e.NewState == VertexTaskState.Running) + { + DryadLogger.LogDebug("Vertex Task State Change", "Task {0} moved from node {1} to node {2}", e.Id, e.OldNode, e.NewNode); + if (dispatcherFound) + { + faultOldDispatcher = true; + addNewDispatcher = true; + } + } + } + // Running -> Queued -> Running, e.g. + else if (e.OldRequeueCount < e.NewRequeueCount) + { + DryadLogger.LogDebug("Vertex Task State Change", "Task {0} node {1} state {2} unchanged from previous state: likely missed a state change notification.", + e.Id, e.NewNode, e.NewState.ToString()); + + // Was task running previously? If so, fault the old dispatcher. + if (dispatcherFound) + { + faultOldDispatcher = true; + } + + // Is task running now? If so, create a new dispatcher to re-establish connection. 
/// <summary>
/// Records that a vertex process has exited and updates its state, handle and dispatcher.
/// Any exception raised while doing so is logged and not rethrown.
/// </summary>
/// <param name="processId">Internal id of the process that exited; an id missing from the process table is logged as an error.</param>
/// <param name="exitCode">Exit code to record; only stored when the process is moved to Completed.</param>
/// <param name="dispatcherFaulted">true when the exit is being reported because the process's dispatcher faulted.</param>
private void ProcessExit(int processId, int exitCode, bool dispatcherFaulted)
{
    DryadLogger.LogMethodEntry(processId, exitCode, dispatcherFaulted);
    try
    {
        XComputeProcess proc = null;
        if (processTable.TryGetValue(processId, out proc))
        {
            DryadLogger.LogInformation("Process Exit", "found process {0} for vertex {1}.{2}", processId, proc.GraphManagerId, proc.GraphManagerVersion);

            // Update process
            if (proc.CurrentState < ProcessState.AssignedToNode && dispatcherFaulted)
            {
                // If we haven't yet reached AssignedToNode and the dispatcher faulted, then scheduling failed
                DryadLogger.LogInformation("Process Exit", "Process {0} was in state {1}", processId, proc.CurrentState.ToString());

                proc.ChangeState(ProcessState.SchedulingFailed);
            }
            else if (proc.CurrentState <= ProcessState.Running)
            {
                // If we're at AssignedToNode or Running, then the process either did really complete
                // or the Vertex Service failed to start it - so this is not a scheduling error and the
                // exit code has meaning.
                // Note: this branch is also taken for states below AssignedToNode when
                // dispatcherFaulted is false, since only the first condition tests that flag.
                DryadLogger.LogInformation("Process Exit", "Process {0} was in state {1}", processId, proc.CurrentState.ToString());
                proc.ExitCode = (uint)exitCode;
                proc.ChangeState(ProcessState.Completed);
            }
            else
            {
                // we've already reached this state previously, and this call should be idempotent
                DryadLogger.LogInformation("Process Exit", "Process {0} was already in state {1}", processId, proc.CurrentState.ToString());
                DryadLogger.LogMethodExit();
                return;
            }

            if (proc.HandleClosed)
            {
                // This happens if a close handle comes from the GM
                // before we've received notification that the process exited.
                // For example, when the GM does:
                // - Cancel
                // - CloseHandle
                // in rapid succession.
                // CloseVertexProcess sets HandleClosed instead of closing in that case,
                // so the close is completed here, before the dispatcher is touched below.
                DryadLogger.LogDebug("Process Exit", "Delayed close handle for process {0}", processId);
                CloseVertexProcess(processId);
            }

            lock (proc.SyncRoot)
            {
                if (dispatcherFaulted)
                {
                    // The dispatcher is not released here; the process just stops referring to it
                    DryadLogger.LogWarning("Process Exit", "Process exiting due to faulted dispatcher");
                    proc.Dispatcher = null;
                }
                else if (proc.Dispatcher != null)
                {
                    // Release dispatcher
                    DryadLogger.LogInformation("Process Exit", "Releasing dispatcher");
                    proc.Dispatcher.Release();

                    // Look for new request for node
                    ThreadPool.QueueUserWorkItem(new WaitCallback(this.FindRequestForNodeThreadFunc), proc.Dispatcher);
                }
            }
        }
        else
        {
            DryadLogger.LogError(0, null, "Unknown process id {0}", processId);
        }
    }
    catch (Exception e)
    {
        // Deliberately swallowed after logging: callers are not expected to handle failures from this method
        DryadLogger.LogError(0, e, "Failed to transition vertex process {0} to exited gracefully", processId);
    }
    DryadLogger.LogMethodExit();
}
+ /// + /// AsyncState member is the Dispatcher that initiated the operation + private void ScheduleProcessCallback(IAsyncResult asyncResult) + { + try + { + Dispatcher d = asyncResult.AsyncState as Dispatcher; + if (d != null) + { + int currentProcessId = d.CurrentProcess; + SchedulingResult schedulingResult = d.EndScheduleProcess(asyncResult); + if (schedulingResult == SchedulingResult.Failure) + { + // This indicates there was a fatal error (Exception or FaultException) + + // Change process state to scheduling failed + DryadLogger.LogWarning("Schedule Process", "Async operation did not complete successfully for process {0} on node {1}", currentProcessId, d.NodeName); + if (currentProcessId != Dispatcher.InvalidProcessId) + { + // Since we will still be in the Unscheduled state, the return code will be ignored by + // ProcessExit, but we'll pass a nonzero exit code just to be sure we don't + // confuse the GM in case of a race condition. + ProcessExit(currentProcessId, unchecked((int)Constants.DrError_ProcessingInterrupted)); + } + d.Release(); + } + else if (schedulingResult == SchedulingResult.CommunicationError) + { + // This indicates that there was an error communicating with the node. + + // We need to fault the dispatcher so that subsequent attemps don't try to use it again. + // Faulting the dispatcher will take care of exiting the current process, so no need to + // do it here. 
/// <summary>
/// Thread-pool work item that looks for a pending request able to run on a dispatcher's
/// node and, if one is found, schedules it on that dispatcher.
/// </summary>
/// <param name="state">The dispatcher to find work for; anything that is not a Dispatcher is ignored.</param>
private void FindRequestForNodeThreadFunc(Object state)
{
    Dispatcher d = state as Dispatcher;
    if (d == null)
    {
        // Nothing to do without a dispatcher
        return;
    }

    // The previous catch (NullReferenceException) handler only swallowed the exception when
    // d was null, which cannot happen past the guard above, so it always rethrew. It has been
    // removed; exceptions still propagate to the caller exactly as before.

    // FindRequestForNode takes a lock on the request pool
    ScheduleProcessRequest req = null;
    if (!FindRequestForNode(d.NodeName, out req))
    {
        return;
    }

    if (!ScheduleProcess(req, d))
    {
        DryadLogger.LogWarning("Schedule Request on Node", "Failed to schedule process {0} on node {1}", req.Id, d.NodeName);
        try
        {
            processTable[req.Id].ChangeState(ProcessState.SchedulingFailed);
        }
        finally
        {
            // Give the dispatcher back even if the state change throws; otherwise it would
            // stay reserved with nothing running on it.
            d.Release();
        }
    }
}
// volatile: the field is read outside the lock in GetInstance, so the write made under the
// lock must be visible to, and not reordered for, threads taking the lock-free path.
private static volatile VertexScheduler vertexScheduler = null;
private static readonly Object factoryLock = new Object();

/// <summary>
/// Returns the process-wide VertexScheduler, creating and initializing it on first use.
/// </summary>
/// <returns>The single VertexScheduler instance.</returns>
public static VertexScheduler GetInstance()
{
    if (vertexScheduler == null)
    {
        lock (factoryLock)
        {
            // Re-check under the lock: another thread may have created the instance
            // between the unlocked test above and acquiring the lock.
            if (vertexScheduler == null)
            {
                ProcessTable processTable = new ProcessTable();

                // NOTE(review): the instance is published before Initialize() runs, so a
                // thread on the lock-free path can obtain a scheduler that is still
                // initializing. The order is kept in case code reached from Initialize()
                // calls GetInstance() itself - confirm before reordering.
                vertexScheduler = new VertexScheduler(processTable);

                vertexScheduler.Initialize();
            }
        }
    }
    return vertexScheduler;
}
+
+*/
+
+namespace Microsoft.Research.Dryad
+{
+    using System;
+    using System.ServiceModel;
+    using System.Runtime.Serialization;
+    using System.Collections.Specialized;
+
+    /// <summary>
+    /// Begin/End (AsyncPattern) operation pairs of the vertex service contract.
+    /// This is one part of a partial interface; the remaining members are declared elsewhere.
+    /// </summary>
+    public partial interface IDryadVertexService
+    {
+        //
+        // Shutdown (one-way)
+        //
+        [OperationContract(AsyncPattern = true, IsOneWay = true, Action = "http://hpc.microsoft.com/dryadvertex/shutdown")]
+        IAsyncResult BeginShutdown(uint ShutdownCode, AsyncCallback callback, object state);
+
+        void EndShutdown(IAsyncResult result);
+
+        //
+        // ReleaseProcess (one-way)
+        //
+        [OperationContract(AsyncPattern = true, IsOneWay = true, Action = "http://hpc.microsoft.com/dryadvertex/releaseprocess")]
+        IAsyncResult BeginReleaseProcess(int processId, AsyncCallback callback, object state);
+
+        void EndReleaseProcess(IAsyncResult result);
+
+        //
+        // ScheduleProcess (request/reply, returns bool)
+        //
+        [OperationContract(AsyncPattern = true, Action = "http://hpc.microsoft.com/dryadvertex/scheduleprocess")]
+        IAsyncResult BeginScheduleProcess(string replyUri, int processId, string commandLine, StringDictionary environment, AsyncCallback callback, object state);
+
+        bool EndScheduleProcess(IAsyncResult result);
+
+        //
+        // CancelScheduleProcess (one-way)
+        //
+        [OperationContract(AsyncPattern = true, IsOneWay = true, Action = "http://hpc.microsoft.com/dryadvertex/cancelscheduleprocess")]
+        IAsyncResult BeginCancelScheduleProcess(int processId, AsyncCallback callback, object state);
+
+        void EndCancelScheduleProcess(IAsyncResult result);
+
+        //
+        // SetGetProps (request/reply, returns bool)
+        //
+        [OperationContract(AsyncPattern = true, Action = "http://hpc.microsoft.com/dryadvertex/setgetprops")]
+        IAsyncResult BeginSetGetProps(string replyUri, int processId, ProcessPropertyInfo[] infos, string blockOnLabel, ulong blockOnVersion, long maxBlockTime, string getPropLabel, bool ProcessStatistics, AsyncCallback callback, object state);
+
+        bool EndSetGetProps(IAsyncResult result);
+
+        //
+        // CheckStatus (request/reply, returns VertexStatus)
+        //
+        [OperationContract(AsyncPattern = true, Action = "http://hpc.microsoft.com/dryadvertex/checkstatus")]
+        IAsyncResult BeginCheckStatus(AsyncCallback callback, object state);
+
+        VertexStatus EndCheckStatus(IAsyncResult result);
+
+        //
+        // Initialize (one-way)
+        //
+        [OperationContract(AsyncPattern = true, IsOneWay = true, Action = "http://hpc.microsoft.com/dryadvertex/initialize")]
+        IAsyncResult BeginInitialize(StringDictionary vertexEndpointAddresses, AsyncCallback callback, object state);
+
+        void EndInitialize(IAsyncResult result);
+
+    }
+
+    /// <summary>
+    /// WCF client proxy for <see cref="IDryadVertexService"/>. Every member simply
+    /// forwards to the underlying channel; no state is kept in this class.
+    /// </summary>
+    public class VertexServiceClient : ClientBase<IDryadVertexService>, IDryadVertexService
+    {
+
+        /// <summary>Creates a proxy bound to <paramref name="remoteAddress"/> using <paramref name="binding"/>.</summary>
+        public VertexServiceClient(System.ServiceModel.Channels.Binding binding, System.ServiceModel.EndpointAddress remoteAddress) :
+            base(binding, remoteAddress)
+        {
+        }
+
+        // ---- CancelScheduleProcess ----
+
+        public void CancelScheduleProcess(int processId)
+        {
+            base.Channel.CancelScheduleProcess(processId);
+        }
+
+        public IAsyncResult BeginCancelScheduleProcess(int processId, AsyncCallback callback, object state)
+        {
+            return base.Channel.BeginCancelScheduleProcess(processId, callback, state);
+        }
+
+        public void EndCancelScheduleProcess(IAsyncResult result)
+        {
+            base.Channel.EndCancelScheduleProcess(result);
+        }
+
+        // ---- CheckStatus ----
+
+        public VertexStatus CheckStatus()
+        {
+            return base.Channel.CheckStatus();
+        }
+
+        public IAsyncResult BeginCheckStatus(AsyncCallback callback, object state)
+        {
+            return base.Channel.BeginCheckStatus(callback, state);
+        }
+
+        public VertexStatus EndCheckStatus(IAsyncResult result)
+        {
+            return base.Channel.EndCheckStatus(result);
+        }
+
+        // ---- Initialize ----
+
+        public void Initialize(StringDictionary vertexEndpointAddresses)
+        {
+            base.Channel.Initialize(vertexEndpointAddresses);
+        }
+
+        public IAsyncResult BeginInitialize(StringDictionary vertexEndpointAddresses, AsyncCallback callback, object state)
+        {
+            return base.Channel.BeginInitialize(vertexEndpointAddresses, callback, state);
+        }
+
+        public void EndInitialize(IAsyncResult result)
+        {
+            base.Channel.EndInitialize(result);
+        }
+
+        // ---- ReleaseProcess ----
+
+        public void ReleaseProcess(int processId)
+        {
+            base.Channel.ReleaseProcess(processId);
+        }
+
+        public IAsyncResult BeginReleaseProcess(int processId, AsyncCallback callback, object state)
+        {
+            return base.Channel.BeginReleaseProcess(processId, callback, state);
+        }
+
+        public void EndReleaseProcess(IAsyncResult result)
+        {
+            base.Channel.EndReleaseProcess(result);
+        }
+
+        // ---- ScheduleProcess ----
+
+        public bool ScheduleProcess(string replyUri, int processId, string commandLine, System.Collections.Specialized.StringDictionary environment)
+        {
+            return base.Channel.ScheduleProcess(replyUri, processId, commandLine, environment);
+        }
+
+        public IAsyncResult BeginScheduleProcess(string replyUri, int processId, string commandLine, System.Collections.Specialized.StringDictionary environment, AsyncCallback callback, object state)
+        {
+            return base.Channel.BeginScheduleProcess(replyUri, processId, commandLine, environment, callback, state);
+        }
+
+        public bool EndScheduleProcess(IAsyncResult result)
+        {
+            return base.Channel.EndScheduleProcess(result);
+        }
+
+        // ---- SetGetProps ----
+
+        public bool SetGetProps(string replyUri, int processId, ProcessPropertyInfo[] infos, string blockOnLabel, ulong blockOnVersion, long maxBlockTime, string getPropLabel, bool ProcessStatistics)
+        {
+            return base.Channel.SetGetProps(replyUri, processId, infos, blockOnLabel, blockOnVersion, maxBlockTime, getPropLabel, ProcessStatistics);
+        }
+
+        public IAsyncResult BeginSetGetProps(string replyUri, int processId, ProcessPropertyInfo[] infos, string blockOnLabel, ulong blockOnVersion, long maxBlockTime, string getPropLabel, bool ProcessStatistics, AsyncCallback callback, object state)
+        {
+            return base.Channel.BeginSetGetProps(replyUri, processId, infos, blockOnLabel, blockOnVersion, maxBlockTime, getPropLabel, ProcessStatistics, callback, state);
+        }
+
+        public bool EndSetGetProps(IAsyncResult result)
+        {
+            return base.Channel.EndSetGetProps(result);
+        }
+
+        // ---- Shutdown ----
+
+        public void Shutdown(uint ShutdownCode)
+        {
+            base.Channel.Shutdown(ShutdownCode);
+        }
+
+        public IAsyncResult BeginShutdown(uint ShutdownCode, AsyncCallback callback, object state)
+        {
+            return base.Channel.BeginShutdown(ShutdownCode, callback, state);
+        }
+
+        public void EndShutdown(IAsyncResult result)
+        {
+            base.Channel.EndShutdown(result);
+        }
+    }
+}
diff --git a/xcompute_managed/XComputeProcess.cs b/xcompute_managed/XComputeProcess.cs
new file mode 100644
index 0000000..da2bd74
--- /dev/null
+++ b/xcompute_managed/XComputeProcess.cs
@@ -0,0 +1,554 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+
+namespace Microsoft.Research.Dryad
+{
+    using System;
+    using System.Collections.Generic;
+    using System.Threading;
+    using Microsoft.Research.Dryad;
+
+    public delegate void StateChangeEventHandler(object sender, XComputeProcessStateChangeEventArgs e);
+
+    public delegate void GetSetPropertyEventHandler(object sender, XComputeProcessGetSetPropertyEventArgs e);
+
+    /// <summary>
+    /// Tracks one process: its current <see cref="ProcessState"/>, exit code, the dispatcher /
+    /// node it was assigned to, and the listeners, waiters and timers registered against
+    /// state transitions and property updates. Mutable state is guarded by <see cref="SyncRoot"/>.
+    /// </summary>
+    internal class XComputeProcess : IDisposable
+    {
+        private bool m_disposed = false;
+        private int m_id;
+        private int m_graphManagerId = -1;
+        private int m_graphManagerVersion = -1;
+        private ProcessState m_currentState;
+        private uint m_exitCode;
+        private bool m_cancelled = false;
+        private bool m_handleClosed = false;
+        private Dispatcher m_dispatcher = null;
+        private string m_assignedNode = String.Empty;
+        // Handlers to invoke once the process reaches (or passes) the keyed state.
+        private Dictionary<ProcessState, StateChangeEventHandler> m_stateChangeListeners;
+        // Events to signal once the process reaches (or passes) the keyed state.
+        private Dictionary<ProcessState, List<ManualResetEvent>> m_stateChangeWaiters;
+        // Timeout timers, keyed by the individual handler they were registered for.
+        private Dictionary<StateChangeEventHandler, Timer> m_stateChangeTimers;
+        private ManualResetEvent m_assignedToNodeEvent = new ManualResetEvent(false);
+
+        // label -> (version -> handlers)
+        private Dictionary<string, Dictionary<ulong, GetSetPropertyEventHandler>> m_propertyListeners;
+
+        private object m_syncRoot = new object();
+
+        private static readonly char[] cmdLineSeparator = new char[] { ' ' };
+
+        /// <summary>Creates a process record in the Uninitialized state with exit code 0.</summary>
+        public XComputeProcess(int m_id)
+        {
+            this.m_id = m_id;
+            this.m_currentState = ProcessState.Uninitialized;
+            this.m_exitCode = 0;
+            m_stateChangeListeners = new Dictionary<ProcessState, StateChangeEventHandler>();
+            m_stateChangeWaiters = new Dictionary<ProcessState, List<ManualResetEvent>>();
+            m_stateChangeTimers = new Dictionary<StateChangeEventHandler, Timer>();
+
+            m_propertyListeners = new Dictionary<string, Dictionary<ulong, GetSetPropertyEventHandler>>();
+        }
+
+        public void Dispose()
+        {
+            Dispose(true);
+            GC.SuppressFinalize(this);
+        }
+
+        /// <summary>Closes the assigned-to-node event and every registered waiter event.</summary>
+        private void Dispose(bool disposing)
+        {
+            if (!m_disposed)
+            {
+                if (disposing)
+                {
+                    DryadLogger.LogInformation("Dispose Process", "Releasing resources for process id {0}", this.m_id);
+
+                    this.m_assignedToNodeEvent.Close();
+
+                    foreach (KeyValuePair<ProcessState, List<ManualResetEvent>> kvp in m_stateChangeWaiters)
+                    {
+                        foreach (ManualResetEvent e in kvp.Value)
+                        {
+                            try
+                            {
+                                e.Close();
+                            }
+                            catch (Exception ex)
+                            {
+                                DryadLogger.LogError(0, ex);
+                            }
+                        }
+                    }
+                }
+                m_disposed = true;
+            }
+        }
+
+        /// <summary>
+        /// Timer callback: notifies the handler passed as <paramref name="state"/> that its wait
+        /// timed out, then removes and disposes the timer registered for that handler.
+        /// </summary>
+        public void Timeout(object state)
+        {
+            StateChangeEventHandler handler = state as StateChangeEventHandler;
+            if (handler == null)
+            {
+                // Nothing to notify and no timer can be keyed on a null handler.
+                return;
+            }
+
+            handler(this, new XComputeProcessStateChangeEventArgs(m_id, m_currentState, true));
+            lock (SyncRoot)
+            {
+                Timer timer;
+                if (m_stateChangeTimers.TryGetValue(handler, out timer))
+                {
+                    // Release the timer instead of merely dropping the reference to it.
+                    m_stateChangeTimers.Remove(handler);
+                    timer.Dispose();
+                }
+            }
+        }
+
+        /// <summary>Registers <paramref name="handler"/> for the given property label and version.</summary>
+        public void AddPropertyListener(string label, ulong version, GetSetPropertyEventHandler handler)
+        {
+            lock (SyncRoot)
+            {
+                // Is there an entry for this label?
+                if (m_propertyListeners.ContainsKey(label) == false)
+                {
+                    m_propertyListeners.Add(label, new Dictionary<ulong, GetSetPropertyEventHandler>());
+                }
+
+                // Is there an entry for this version
+                if (m_propertyListeners[label].ContainsKey(version))
+                {
+                    m_propertyListeners[label][version] += handler;
+                }
+                else
+                {
+                    m_propertyListeners[label].Add(version, handler);
+                }
+            }
+        }
+
+        /// <summary>
+        /// Invokes <paramref name="handler"/> immediately if the process is already at or past
+        /// <paramref name="targetState"/>; otherwise registers it, with a one-shot timeout timer
+        /// unless <paramref name="timeoutInterval"/> is long.MaxValue.
+        /// </summary>
+        public void AddStateChangeListener(ProcessState targetState, long timeoutInterval, StateChangeEventHandler handler)
+        {
+            lock (SyncRoot)
+            {
+                if (m_currentState >= targetState)
+                {
+                    handler(this, new XComputeProcessStateChangeEventArgs(m_id, m_currentState, false));
+                }
+                else
+                {
+                    if (m_stateChangeListeners.ContainsKey(targetState))
+                    {
+                        m_stateChangeListeners[targetState] += handler;
+                    }
+                    else
+                    {
+                        m_stateChangeListeners.Add(targetState, handler);
+                    }
+                    if (timeoutInterval != long.MaxValue)
+                    {
+                        m_stateChangeTimers[handler] = new Timer(this.Timeout, handler, timeoutInterval, 0);
+                    }
+                }
+            }
+        }
+
+        /// <summary>
+        /// Signals <paramref name="waitEvent"/> immediately if the process is already at or past
+        /// <paramref name="targetState"/>; otherwise queues it to be signalled by ChangeState.
+        /// </summary>
+        public void AddStateChangeWaiter(ProcessState targetState, ManualResetEvent waitEvent)
+        {
+            lock (SyncRoot)
+            {
+                if (m_currentState >= targetState)
+                {
+                    waitEvent.Set();
+                }
+                else
+                {
+                    // The per-state list is created on first use; indexing the dictionary
+                    // directly would throw KeyNotFoundException for a state with no waiters yet.
+                    List<ManualResetEvent> waiters;
+                    if (!m_stateChangeWaiters.TryGetValue(targetState, out waiters))
+                    {
+                        waiters = new List<ManualResetEvent>();
+                        m_stateChangeWaiters.Add(targetState, waiters);
+                    }
+                    // NOTE(review): ChangeState and Dispose close queued events after signalling
+                    // them, so ownership of waitEvent effectively passes to this object - confirm
+                    // callers expect that.
+                    waiters.Add(waitEvent);
+                }
+            }
+        }
+
+        /// <summary>
+        /// Cancels the process. Before node assignment it is marked cancelled and moved straight
+        /// to Completed; after assignment the cancel is forwarded to the dispatcher.
+        /// </summary>
+        public void Cancel()
+        {
+            bool wasRunning = false;
+
+            lock (SyncRoot)
+            {
+                // If the process has already been assigned to a node, then we will need to cancel it at the node
+                if (this.CurrentState < ProcessState.AssignedToNode)
+                {
+                    this.m_cancelled = true;
+                    this.ExitCode = 0x830A0003; // DrError_VertexReceivedTermination
+                    DryadLogger.LogInformation("Cancel process", "Cancelation received for vertex {0}.{1} before it was assigned to a node", m_graphManagerId, m_graphManagerVersion);
+                    wasRunning = false;
+                }
+                else if (this.CurrentState == ProcessState.Completed)
+                {
+                    // nothing to do for this case, process already completed
+                    DryadLogger.LogInformation("Cancel process", "Cancellation received for vertex {0}.{1} after it completed", m_graphManagerId, m_graphManagerVersion);
+                    return;
+                }
+                else if (Dispatcher != null)
+                {
+                    DryadLogger.LogInformation("Cancel process", "Cancellation received for vertex {0}.{1} after it was assigned to node {2}", m_graphManagerId, m_graphManagerVersion, Dispatcher.NodeName);
+                    wasRunning = true;
+                }
+                else
+                {
+                    // This is an unexpected condition
+                    DryadLogger.LogError(0, null, "Cancellation received for vertex {0}.{1} in state {2} with no dispatcher", m_graphManagerId, m_graphManagerVersion, CurrentState.ToString());
+                    return;
+                }
+
+                if (wasRunning)
+                {
+                    if (Dispatcher != null)
+                    {
+                        Dispatcher.CancelScheduleProcess(m_id);
+                    }
+                }
+                else
+                {
+                    ChangeState(ProcessState.Completed);
+                }
+            }
+        }
+
+        /// <summary>
+        /// Moves the process forward to <paramref name="newState"/>, notifying every listener and
+        /// waiter registered for that state or any earlier one. Backward or same-state
+        /// transitions are logged and ignored.
+        /// </summary>
+        public void ChangeState(ProcessState newState)
+        {
+            lock (SyncRoot)
+            {
+                if (newState > m_currentState)
+                {
+                    DryadLogger.LogDebug("Change State", "Transition process {0} from state {1} to state {2}", m_id, m_currentState, newState);
+
+                    m_currentState = newState;
+                    List<ProcessState> listenersToRemove = new List<ProcessState>();
+                    List<ProcessState> waitersToRemove = new List<ProcessState>();
+
+                    // Check for listeners / waiters for earlier states, in case a state is skipped (e.g. process failed to start)
+                    foreach (ProcessState s in m_stateChangeListeners.Keys)
+                    {
+                        if (s <= m_currentState)
+                        {
+                            // Notify listeners
+                            StateChangeEventHandler listeners = m_stateChangeListeners[s];
+                            if (listeners != null)
+                            {
+                                XComputeProcessStateChangeEventArgs e = new XComputeProcessStateChangeEventArgs(m_id, m_currentState, false);
+                                listeners(this, e);
+
+                                // Timers are keyed by the individual handler they were created for,
+                                // while the listener entry may be a multicast combination of several
+                                // handlers, so each invocation-list entry is looked up separately.
+                                foreach (Delegate d in listeners.GetInvocationList())
+                                {
+                                    StateChangeEventHandler single = (StateChangeEventHandler)d;
+                                    Timer timer;
+                                    if (m_stateChangeTimers.TryGetValue(single, out timer))
+                                    {
+                                        timer.Dispose();
+                                        m_stateChangeTimers.Remove(single);
+                                    }
+                                }
+                            }
+                            listenersToRemove.Add(s);
+                        }
+                    }
+                    foreach (ProcessState s in listenersToRemove)
+                    {
+                        m_stateChangeListeners.Remove(s);
+                    }
+
+                    foreach (ProcessState s in m_stateChangeWaiters.Keys)
+                    {
+                        // Signal waiters
+                        if (s <= m_currentState)
+                        {
+                            foreach (ManualResetEvent w in m_stateChangeWaiters[s])
+                            {
+                                w.Set();
+                            }
+                            waitersToRemove.Add(s);
+                        }
+                    }
+                    foreach (ProcessState s in waitersToRemove)
+                    {
+                        foreach (ManualResetEvent e in m_stateChangeWaiters[s])
+                        {
+                            try
+                            {
+                                e.Close();
+                            }
+                            catch (Exception ex)
+                            {
+                                DryadLogger.LogError(0, ex);
+                            }
+                        }
+                        m_stateChangeWaiters.Remove(s);
+                    }
+
+                    if (m_currentState == ProcessState.AssignedToNode)
+                    {
+                        m_assignedToNodeEvent.Set();
+                    }
+                }
+                else
+                {
+                    DryadLogger.LogWarning("Change State", "Unexpected state change attempted for process {0}: from {1} to {2}", this.m_id, this.m_currentState.ToString(), newState.ToString());
+                }
+            }
+        }
+
+        /// <summary>
+        /// Waits up to ten seconds for the AssignedToNode transition, then moves the process to
+        /// Running whether or not the wait succeeded.
+        /// </summary>
+        public void TransitionToRunning(object state)
+        {
+            DryadLogger.LogDebug("Change State", "Transitioning to Running with current state {0} for process {1}", this.m_currentState.ToString(), this.m_id);
+
+            try
+            {
+                // In rare cases (such as a cancelled duplicate), the GM may close the handle to the process while it is transitioning to running.
+                // This results in Dispose being called on this process, which closes the m_assignedToNode handle.
+                // In this case, we want to catch the exception and log it, but do nothing else, since the GM is done with this process.
+                if (m_assignedToNodeEvent.WaitOne(new TimeSpan(0, 0, 10), false))
+                {
+                    DryadLogger.LogDebug("Change State", "Successfully waited for transition to {0} for process {1}", this.m_currentState.ToString(), this.m_id);
+                }
+                else
+                {
+                    DryadLogger.LogWarning("Change State", "Timed out waiting for transition to AssignedToNode for process {0}", this.m_id);
+                    // We want to fire the state change anyway or else we'll get a zombie process.
+                    // The GM will handle the transition, it just may cause a delay.
+                }
+                ChangeState(ProcessState.Running);
+            }
+            catch (ObjectDisposedException ex)
+            {
+                DryadLogger.LogError(0, ex, "Process handle was closed while waiting for transition to assigned to node");
+            }
+        }
+
+        /// <summary>
+        /// Fires and removes the property listeners satisfied by a completed set/get: first those
+        /// matching the labels/versions that were set, then those matching the properties
+        /// returned in <paramref name="info"/>. A listener registered for ulong.MaxValue always fires.
+        /// </summary>
+        public void SetGetPropsComplete(ProcessInfo info, string[] propertyLabels, ulong[] propertyVersions)
+        {
+            lock (SyncRoot)
+            {
+                // For the Set part
+                if (propertyLabels != null && propertyVersions != null)
+                {
+                    for (int i = 0; i < propertyLabels.Length; i++)
+                    {
+                        if (m_propertyListeners.ContainsKey(propertyLabels[i]))
+                        {
+                            List<ulong> versionsToRemove = new List<ulong>();
+                            foreach (KeyValuePair<ulong, GetSetPropertyEventHandler> entry in m_propertyListeners[propertyLabels[i]])
+                            {
+                                if (entry.Key <= propertyVersions[i] || entry.Key == ulong.MaxValue)
+                                {
+                                    DryadLogger.LogDebug("SetGetProsComplete", "Set complete - m_id: {0} state: {1}, label: {2}", m_id, info.processState, propertyLabels[i]);
+                                    XComputeProcessGetSetPropertyEventArgs e = new XComputeProcessGetSetPropertyEventArgs(m_id, info, propertyVersions);
+                                    entry.Value(this, e);
+
+                                    versionsToRemove.Add(entry.Key);
+                                }
+                            }
+                            foreach (ulong version in versionsToRemove)
+                            {
+                                m_propertyListeners[propertyLabels[i]].Remove(version);
+                            }
+                        }
+                    }
+                }
+
+                // For the Get part
+                if (info != null && info.propertyInfos != null)
+                {
+                    foreach (ProcessPropertyInfo propInfo in info.propertyInfos)
+                    {
+                        if (m_propertyListeners.ContainsKey(propInfo.propertyLabel))
+                        {
+                            List<ulong> versionsToRemove = new List<ulong>();
+                            foreach (KeyValuePair<ulong, GetSetPropertyEventHandler> entry in m_propertyListeners[propInfo.propertyLabel])
+                            {
+                                if (entry.Key <= propInfo.propertyVersion || entry.Key == ulong.MaxValue)
+                                {
+                                    DryadLogger.LogDebug("SetGetProsComplete", "Get complete - m_id: {0} state: {1}, label: {2}", m_id, info.processState, propInfo.propertyLabel);
+
+                                    XComputeProcessGetSetPropertyEventArgs e = new XComputeProcessGetSetPropertyEventArgs(m_id, info, propertyVersions);
+                                    entry.Value(this, e);
+
+                                    versionsToRemove.Add(entry.Key);
+                                }
+                            }
+                            foreach (ulong version in versionsToRemove)
+                            {
+                                m_propertyListeners[propInfo.propertyLabel].Remove(version);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        /// <summary>
+        /// Parses the graph manager id and version from a six-token command line
+        /// (tokens 4 and 5); logs a warning if the command line does not have that shape.
+        /// </summary>
+        public void SetIdAndVersion(string commandLine)
+        {
+            bool parsed = false;
+            string[] args = commandLine.Split(cmdLineSeparator, StringSplitOptions.RemoveEmptyEntries);
+            if (args != null)
+            {
+                if (args.Length == 6)
+                {
+                    lock (SyncRoot)
+                    {
+                        if (Int32.TryParse(args[4], out m_graphManagerId))
+                        {
+                            if (Int32.TryParse(args[5], out m_graphManagerVersion))
+                            {
+                                parsed = true;
+                            }
+                        }
+                    }
+                }
+            }
+
+            if (!parsed)
+            {
+                DryadLogger.LogWarning("Set Vertex Id And Version", "Failed to parse vertex command line: {0}", commandLine);
+            }
+        }
+
+        public string AssignedNode
+        {
+            get { return m_assignedNode; }
+            set { m_assignedNode = value; }
+        }
+
+        /// <summary>Dispatcher running this process; setting a non-null value also updates AssignedNode.</summary>
+        public Dispatcher Dispatcher
+        {
+            get { return m_dispatcher; }
+            set
+            {
+                lock (SyncRoot)
+                {
+                    m_dispatcher = value;
+                    try
+                    {
+                        if (m_dispatcher != null)
+                        {
+                            m_assignedNode = m_dispatcher.NodeName;
+                        }
+                    }
+                    catch (Exception e)
+                    {
+                        DryadLogger.LogError(0, e, "Failed to set assigned node from supplied dispatcher");
+                    }
+                }
+            }
+        }
+
+        public ProcessState CurrentState
+        {
+            get { return m_currentState; }
+        }
+
+        public uint ExitCode
+        {
+            get { return m_exitCode; }
+            set
+            {
+                lock (SyncRoot)
+                {
+                    m_exitCode = value;
+                }
+            }
+        }
+
+        public bool Cancelled
+        {
+            get
+            {
+                return this.m_cancelled;
+            }
+        }
+
+        public bool HandleClosed
+        {
+            get { return m_handleClosed; }
+            set { m_handleClosed = value; }
+        }
+
+        public int GraphManagerId
+        {
+            get
+            {
+                return m_graphManagerId;
+            }
+        }
+
+        public int GraphManagerVersion
+        {
+            get
+            {
+                return m_graphManagerVersion;
+            }
+        }
+
+        public object SyncRoot
+        {
+            get { return m_syncRoot; }
+        }
+
+    }
+
+    /// <summary>Payload delivered to <see cref="GetSetPropertyEventHandler"/> callbacks.</summary>
+    public class XComputeProcessGetSetPropertyEventArgs
+    {
+        private int processId;
+        private ProcessInfo processInfo;
+        private ulong[] propertyVersions;
+
+        public XComputeProcessGetSetPropertyEventArgs(int m_id, ProcessInfo info, ulong[] versions)
+        {
+            this.processId = m_id;
+            this.processInfo = info;
+            this.propertyVersions = versions;
+        }
+
+        public int ProcessId
+        {
+            get { return this.processId; }
+        }
+
+        public ProcessInfo ProcessInfo
+        {
+            get { return this.processInfo; }
+        }
+
+        public ulong[] PropertyVersions
+        {
+            get { return this.propertyVersions; }
+        }
+    }
+
+    /// <summary>
+    /// Payload delivered to <see cref="StateChangeEventHandler"/> callbacks; TimedOut is true
+    /// when the callback was raised by the timeout timer rather than by a state transition.
+    /// </summary>
+    public class XComputeProcessStateChangeEventArgs
+    {
+        private bool timedOut;
+        private int processId;
+        private ProcessState m_currentState;
+
+        public XComputeProcessStateChangeEventArgs(int m_id, ProcessState state, bool timedOut)
+        {
+            this.timedOut = timedOut;
+            processId = m_id;
+            m_currentState = state;
+        }
+
+        public int ProcessId
+        {
+            get { return processId; }
+        }
+
+        public ProcessState State
+        {
+            get { return m_currentState; }
+        }
+
+        public bool TimedOut
+        {
+            get { return timedOut; }
+            set { timedOut = value; }
+        }
+    }
+}
diff --git a/xcompute_native/YarnQueryNativeClusterAdapter.vcxproj b/xcompute_native/YarnQueryNativeClusterAdapter.vcxproj
new file mode 100644
index 0000000..76fa083
--- /dev/null
+++ b/xcompute_native/YarnQueryNativeClusterAdapter.vcxproj
@@ -0,0 +1,211 @@
+ + + + + Debug + Win32 + + + Debug + x64 + + + Release + Win32 + + + Release + x64 + + + + {E092E2B9-D3C9-4CE2-8201-BDA442574C97} + YarnQueryNativeClusterAdapter + ManagedCProj + + + + DynamicLibrary + true + + + DynamicLibrary + Unicode + true + + + DynamicLibrary + true + + + DynamicLibrary + Unicode + true + + + + + + + + + + + + + + + + + + + <_ProjectFileVersion>10.0.40219.1
Debug\ + Debug\ + true + ..\bin\$(Configuration)\ + $(Platform)\$(Configuration)\ + true + Release\ + Release\ + true + ..\bin\$(Configuration)\ + $(Platform)\$(Configuration)\ + false + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + AllRules.ruleset + + + false + false + + + + Disabled + inc;%(AdditionalIncludeDirectories) + WIN32;_DEBUG;_WINDOWS;_USRDLL;YARNQUERYNATIVECLUSTERADAPTER_EXPORTS;%(PreprocessorDefinitions) + MultiThreadedDebugDLL + Create + Level3 + ProgramDatabase + + + true + Windows + + + MachineX86 + + + + + X64 + + + Disabled + inc;%(AdditionalIncludeDirectories) + WIN32;_DEBUG;_WINDOWS;_USRDLL;YARNQUERYNATIVECLUSTERADAPTER_EXPORTS;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + MultiThreadedDebugDLL + Create + Level3 + ProgramDatabase + + + true + Console + + + MachineX64 + + + + + WIN32;NDEBUG;_WINDOWS;_USRDLL;YARNQUERYNATIVECLUSTERADAPTER_EXPORTS;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + inc;%(AdditionalIncludeDirectories) + + + true + Windows + true + true + main + MachineX86 + + + + + X64 + + + WIN32;NDEBUG;_WINDOWS;_USRDLL;YARNQUERYNATIVECLUSTERADAPTER_EXPORTS;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + MultiThreadedDLL + + + Level3 + ProgramDatabase + inc;%(AdditionalIncludeDirectories) + + + true + Windows + true + true + + + MachineX64 + + + + + true + true + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + {f4b04940-67cf-4796-b6d3-3cfd38fb988a} + + + + + + \ No newline at end of file diff --git a/xcompute_native/async.cpp b/xcompute_native/async.cpp new file mode 100644 index 0000000..167d7d4 --- /dev/null +++ b/xcompute_native/async.cpp @@ -0,0 +1,133 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +/*++ + +Module Name: + + async.cpp + +Abstract: + + This module contains support routines for implementing the + the xcompute async completion notification. + +--*/ +#include "stdafx.h" + + +// +// ValidateAsync(PCCS_ASYNC_INFO asyncInfo) +// Does some basic validation of the CS_ASYNC_INFO structure +// +// returns S_OK if it's well-formed, E_INVALIDARG otherwise +// +HRESULT +ValidateAsync( + PCXC_ASYNC_INFO async + ) +{ + if (async==NULL) { + return S_OK; + } + if (async->Size < sizeof(XC_ASYNC_INFO)) { + return E_INVALIDARG; + } + if (async->pOperationState == NULL) { + return E_INVALIDARG; + } + + // + // If an IOCP was supplied, there better be a pOverlapped and CompletionKey + // to post to it + // + if (async->IOCP != NULL) + { + if ((async->pOverlapped == NULL) || (async->CompletionKey == NULL)) { + return E_INVALIDARG; + } + } else { + + // + // If no IOCP was supplied, there also must not be a pOverlapped. 
+ // + if (async->pOverlapped != NULL) { + return E_INVALIDARG; + } + } + return S_OK; +} + +ASYNC::ASYNC( + PCXC_ASYNC_INFO pAsyncInfo + ) +{ + + + // + // Capture the supplied completion information + // + this->pOperationState = pAsyncInfo->pOperationState; + this->hEvent = pAsyncInfo->Event; + this->hIOCP = pAsyncInfo->IOCP; + this->pOverlapped = pAsyncInfo->pOverlapped; + this->CompletionKey = pAsyncInfo->CompletionKey; +} + +ASYNC::~ASYNC() +{ +} + +HRESULT +ASYNC::Complete( + HRESULT hr + ) +{ + + + // + // Indicate status + // + *pOperationState = hr; + + // + // Set the event (if present) + // + if (hEvent) { + ::SetEvent(hEvent); + } + + // + // Post to the IOCP (if present) + // + if (hIOCP) { + PostQueuedCompletionStatus(hIOCP, + 0, + CompletionKey, + pOverlapped); + } + + delete this; + + // + // Always return pending + // + return HRESULT_FROM_WIN32(ERROR_IO_PENDING); +} diff --git a/xcompute_native/context.cpp b/xcompute_native/context.cpp new file mode 100644 index 0000000..470ef08 --- /dev/null +++ b/xcompute_native/context.cpp @@ -0,0 +1,147 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +/*++ + +Module Name: + + context.cpp + +Abstract: + + This module contains the public interface and support routines for + managing process context for xcompute on the HPC scheduler + +--*/ +#include "stdafx.h" +#include + +std::map g_Context; + + +/*++ + +XcSetProcessUserContext API + +Description: + +Associates API user related data with process +identified by the Process handle. + +The API user can associate any data with the XCompute process +and get back the data, using the XcGetProcessUserContext API. + +This call is synchronous and does not cross +machine boundaries/process boundaries. + +NOTE: +a. The XcCloseProcessHandle() will not deallocate the user + context data. It is the API users responsibilty to + deallocated data associated with UserContext. + +b. The user context is associated with a process and not with a + ProcessHandle. So if multiple handles identify the same + XCompute process, they will return the same user context. + +Arguments: + + hProcessHandle + Process handle to identify process to which user context + is being associated + + userContext + The user context data + + pPreviousUserContext + If there was a previously associated user context + with the XCompute process , then returns that context data. + Otherwise NULL is returned. If the caller supplies NULL input, + the previous value is not returned + + +Return Value: + + XCERROR_OK + The call succeded + + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcSetProcessUserContext( + IN XCPROCESSHANDLE hProcessHandle, + IN ULONG_PTR userContext, + OUT ULONG_PTR* pPreviousUserContext + ) +{ + *pPreviousUserContext = g_Context[(DWORD)hProcessHandle]; + g_Context[(DWORD)hProcessHandle] = userContext; + + return S_OK; +} + + + +/*++ + +XcGetProcessUserContext API + +Description: + +Gets the API user related data associated with the XCompute Process. +The API user can associate any data with the XCompute process via the +XcAddUserContextToHandle API. 
+ +This call is synchronous and does not cross +machine boundaries/process boundaries. + +Arguments: + + hProcessHandle + Process handle + + pUserContext + The user context data associated with the XCompute Process. + If no user context is associated, then + NULL is returned. + +Return Value: + + XCERROR_OK + The call succeded + + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetProcessUserContext( + IN XCPROCESSHANDLE hProcessHandle, + OUT ULONG_PTR* pUserContext + ) +{ + + *pUserContext = g_Context[(DWORD)hProcessHandle]; + + return S_OK; +} + + diff --git a/xcompute_native/dllmain.cpp b/xcompute_native/dllmain.cpp new file mode 100644 index 0000000..14f95b0 --- /dev/null +++ b/xcompute_native/dllmain.cpp @@ -0,0 +1,40 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// dllmain.cpp : Defines the entry point for the DLL application. 
+#include "stdafx.h" + +#pragma unmanaged +BOOL APIENTRY DllMain( HMODULE hModule, + DWORD ul_reason_for_call, + LPVOID lpReserved + ) +{ + switch (ul_reason_for_call) + { + case DLL_PROCESS_ATTACH: + case DLL_THREAD_ATTACH: + case DLL_THREAD_DETACH: + case DLL_PROCESS_DETACH: + break; + } + return TRUE; +} + diff --git a/xcompute_native/file.cpp b/xcompute_native/file.cpp new file mode 100644 index 0000000..5aafd28 --- /dev/null +++ b/xcompute_native/file.cpp @@ -0,0 +1,257 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +/*++ +Module Name: + + file.cpp + +Abstract: + + This module contains the public interface and support routines for + the xcompute process file functionality on top of the HPC scheduler. + +--*/ +#include "stdafx.h" + +/*++ + +XcOpenProcessFile API + +Description: + +Opens a handle to a remote XCompute processes working File. +Using this handle, an application can read remote files +written by a XCompute process on a given Node. +Writing of files is not supported. Local files can be written +using Ordinary Windows file I/O (restricted to the working +directory and Its children). + +Arguments: + + hSession + Handle to an XCompute session associated with + this call. + + fileUri + the fully qualified file Uri (UTF-8) obtained by calling the + XcGetProcessFileUri API. + + Flags + Reserved. Must be 0. 
+ + phFileHandle + The returned handle to the opened file. + Set to NULL if error + + pAsyncInfo + The async info structure. Its an alias to the + CS_ASYNC_INFO defined in Cosmos.h. If this + parameter is NULL, then the function completes in + synchronous manner and error code is returned as + return value. + + If parameter is not NULL then the operation is carried + on in asynchronous manner. If an asynchronous + operation has been successfully started then + this function terminates immediately with + an HRESULT_FROM_WIN32(ERROR_IO_PENDING) return value. + Any other return value indicates that it was + impossible to start the asynchronous operation. + + + Return Value: + + if pAsyncInfo is NULL + CsError_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation (a SUCCESS HRESULT will never + be returned if pAsyncInfo is not NULL). + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcOpenProcessFile( + IN XCSESSIONHANDLE hSession, + IN PCSTR fileUri, + IN DWORD Flags, + OUT PXCPROCESSFILEHANDLE phFileHandle, + IN PCXC_ASYNC_INFO pAsyncInfo +) +{ + PASYNC async; + CAPTURE_ASYNC(async); + + // + // Our file URI is a UNC path so we can just open it directly + // + HANDLE hFile = ::CreateFileA(fileUri, + GENERIC_READ, + FILE_SHARE_READ | FILE_SHARE_WRITE, + NULL, + OPEN_EXISTING, + 0, + NULL); + if (hFile == INVALID_HANDLE_VALUE) { + return COMPLETE_ASYNC(async,HRESULT_FROM_WIN32(GetLastError())); + } + + *phFileHandle = (XCPROCESSFILEHANDLE)hFile; + + return COMPLETE_ASYNC(async, S_OK); +} + + + +/*++ + +XcCloseProcessFile API + +Description: + +Closes the file opened by the XcOpenProcessFile + +Arguments: + + hFileHandle + The handle to the opened file. 
+
+ Return Value:
+
+    CsError_OK indicates call succeeded
+
+    HRESULT_FROM_WIN32(GetLastError()) is returned if the underlying
+    CloseHandle call fails.
+
+--*/
+XCOMPUTEAPI_EXT
+XCERROR
+XCOMPUTEAPI
+XcCloseProcessFile(
+    IN  XCPROCESSFILEHANDLE    hFileHandle
+)
+{
+    //
+    // Report a failed close instead of claiming success; the handle is
+    // the raw Win32 file handle stored by XcOpenProcessFile.
+    //
+    if (!::CloseHandle((HANDLE)hFileHandle)) {
+        const DWORD closeError = ::GetLastError();
+        return HRESULT_FROM_WIN32(closeError);
+    }
+
+    return S_OK;
+}
+
+
+
+
+/*++
+
+XcReadProcessFile API
+
+Description:
+
+Reads the content of the file opened by the XcOpenProcessFile
+
+Arguments:
+
+
+    phFileHandle
+        The handle to the opened file.
+
+    pBuffer
+        Pointer to the buffer that receives the data read.
+
+    pBytesRead
+        Pointer to variable containing size of the buffer
+        on input. On return this variable receives number
+        of bytes read.
+
+    pReadPosition
+        The offset from the beginning of the file at
+        which to read.
+
+    pAsyncInfo
+        The async info structure. Its an alias to the
+        CS_ASYNC_INFO defined in Cosmos.h. If this
+        parameter is NULL, then the function completes in
+        synchronous manner and error code is returned as
+        return value.
+
+        If parameter is not NULL then the operation is
+        carried on in asynchronous manner. If an asynchronous
+        operation has been successfully started then
+        this function terminates immediately with
+        an HRESULT_FROM_WIN32(ERROR_IO_PENDING) return value.
+        Any other return value indicates that it was
+        impossible to start the asynchronous operation.
+
+
+ Return Value:
+
+    if pAsyncInfo is NULL
+        CsError_OK indicates call succeeded
+
+        Any other error code, indicates the failure reason.
+
+
+    if pAsyncInfo != NULL
+        HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async
+        operation was successfully started
+
+        Any other return value indicates it was impossible to start
+        asynchronous operation (a SUCCESS HRESULT will never
+        be returned if pAsyncInfo is not NULL).
+ +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcReadProcessFile( + IN XCPROCESSFILEHANDLE phFileHandle, + OUT PVOID pBuffer, + IN OUT PSIZE_T pBytesRead, + IN OUT XCPROCESSFILEPOSITION* pReadPosition, + IN PCXC_ASYNC_INFO pAsyncInfo +) +{ + OVERLAPPED ov = {0}; + PASYNC async; + + if (*pBytesRead > (DWORD)-1) + { + return E_NOTIMPL; + } + + CAPTURE_ASYNC(async); + + ov.Offset = (DWORD)(*pReadPosition); + ov.OffsetHigh = (DWORD)(*pReadPosition >> 32); + if (!::ReadFile((HANDLE)phFileHandle, + pBuffer, + (DWORD)*pBytesRead, + (LPDWORD)pBytesRead, + &ov)) { + + return COMPLETE_ASYNC(async,HRESULT_FROM_WIN32(GetLastError())); + } + return COMPLETE_ASYNC(async, S_OK); +} + + diff --git a/xcompute_native/inc/XCompute.h b/xcompute_native/inc/XCompute.h new file mode 100644 index 0000000..da5af00 --- /dev/null +++ b/xcompute_native/inc/XCompute.h @@ -0,0 +1,1741 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +#pragma once + +#pragma warning( push ) +/* 'X' bytes padding added after member 'Y' */ +#pragma warning( disable: 4820 ) + + +#pragma pack( push, 8 ) + + +#if !defined(_PCVOID_DEFINED) +typedef const void* PCVOID; +#define _PCVOID_DEFINED +#endif + + +#include + + +#if defined(__cplusplus) +extern "C" { +#endif + + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcResetProgress( + IN XCSESSIONHANDLE SessionHandle, + IN ULONG nTotalProgressSteps, + IN bool bUpdate +); + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcIncrementTotalSteps( + IN XCSESSIONHANDLE SessionHandle, + IN bool bUpdate +); + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcDecrementTotalSteps( + IN XCSESSIONHANDLE SessionHandle, + IN bool bUpdate +); + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcSetProgress( + IN XCSESSIONHANDLE SessionHandle, + IN ULONG nCompletedProgressSteps, + IN PCSTR pMessage +); + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcIncrementProgress( + IN XCSESSIONHANDLE SessionHandle, + IN PCSTR pMessage +); + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcCompleteProgress( + IN XCSESSIONHANDLE SessionHandle, + IN PCSTR pMessage +); + + + + +/*++ + +XcOpenSession API + +Description: + +Opens an XCompute session for a given cluster. Each session +is associated with a cluster and is independent of other sessiosn. +The session (apart from other things) is associated with +user credientials. + +It is possible to create multiple sessions for the same cluster and +these multiple sessions will behave independent of each other. This +is particularly useful for applications like WebServer which will run +multiple sessions, one per user. + +Use the XcCloseSession to close the handle returned as a result of +XcOpenSession call. + +Arguments: + + pOpenSessionParams + The Open Session Parameters. Passes info about cluster to + establish session with, clientId, etc. + See XC_OPEN_SESSION_PARAMS for details. + + Pass NULL for defaults - Default cluster and a default cliend id. 
+ + pSessionHandle + Handle to session + + pAsyncInfo + The async info structure. Its an alias to the + CS_ASYNC_INFO defined in Cosmos.h. IF this + parameter is NULL, then function completes in + synchronous manner and error code is returned as + return value. + + If parameter is not NULL then operation is carried + on in asynchronous manner. If asynchronous + operation has been successfully started then + function terminates immediately with + HRESULT_FROM_WIN32(ERROR_IO_PENDING) return value. + Any other return value indicates that it was + impossible to start asynchronous operation. + + + Return Value: + + if pAsyncInfo is NULL + XCERROR_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcOpenSession( + IN PCXC_OPEN_SESSION_PARAMS pOpenSessionParams, + OUT PXCSESSIONHANDLE pSessionHandle, + IN PCXC_ASYNC_INFO pAsyncInfo +); + + + +/*++ + +XcCloseSession API + +Description: + +Closes the session. + +Arguments: + + SessionHandle + Handle to session to close + +Return Value: + + XCERROR_OK + Call succeeded. + +--*/ + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcCloseSession( + IN XCSESSIONHANDLE SessionHandle +); + + + +/*++ + +XcInitialize API + +Description: + +Call this function at the start to initialize the various internal +data structures of the XCompute SDK library. + +Arguments: + + ConfigFile + Name of the config file + + ComponentName + The name of the component + +Return Value: + + XCERROR_OK + Call succeeded. + NOTE: + S_FALSE will be returned if the initialize + has already been called. 
+ +--*/ + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcInitialize( + IN PCSTR ConfigFileName, + IN PCSTR ComponentName +); + + + +/*++ + +XcFreeMemory API + +Description: + +Frees the memory allocated by the XCompute API. +All the memory returned as a result of a call to +the XCompute API should use XcFreeMemory to +deallocate the memory. + +Arguments: + + pMem + Pointer to the memory + +Return Value: + + XCERROR_OK + Memory was successfully deallocated + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcFreeMemory( + IN PCVOID pMem +); + + + +/*++ + +XcCreateNewProcessHandle API + +Description: + +Creates a new process handle for a new process in the +given Job. + +This call is synchronous and does not cross +machine boundaries/process boundaries. + +Note: +1. + This method just creates the handle to + the XCompute process. It does not schedule the process itself. + Use the XcScheduleProcess API to schedule the XCompute process. + +2. + Use the XcCloseProcessHandle() to free the handle + +3. + Do not copy the handle using the simple assignment operator. Use the + XcDupProcessHandle() API. Each handle variable needs to be + freed using the XcCloseProcessHandle(). + +Arguments: + + SessionHandle + Handle to a session associated with this call + + pJobId + The Id of the job under which the process will + be created. A NULL value will cause the current + process's JobId to be automatically picked up. + NOTE: + This parameter is only interesting to the Task Scheduler. + For all other cases, it should be assigned NULL + + pProcessHandle + The handle to the process.
+ + +Return Value: + + XCERROR_OK + The call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcCreateNewProcessHandle( + IN XCSESSIONHANDLE SessionHandle, + IN const GUID* pJobId, + OUT PXCPROCESSHANDLE pProcessHandle +); + + + +/*++ + +XcOpenCurrentProcessHandle API + +Description: + +Opens the current process's handle + +This call is synchronous and does not cross +machine boundaries/process boundaries. + +Note: +1. + This method creates the handle to the current process and assigns + it to the session on that process. + +2. + Use the XcCloseProcessHandle() to free the handle + +3. + Do not copy the handle using the simple assignment operator. Use the + XcDupProcessHandle() API. Each handle variable needs to be + freed using the XcCloseProcessHandle(). + +Arguments: + + SessionHandle + Handle to a session associated with this call. + + pProcessHandle + The handle to the process. This must be closed using the + XcCloseProcessHandle() + +Return Value: + + XCERROR_OK + The call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcOpenCurrentProcessHandle( + IN XCSESSIONHANDLE SessionHandle, + OUT PXCPROCESSHANDLE pProcessHandle +); + + + +/*++ + +XcCloseProcessHandle API + +Description: + +Closes a process handle created either by a call to +XcCreateNewProcessHandle() or XcDupProcessHandle(). + +This call is synchronous and does not cross +machine boundaries/process boundaries. + + +NOTE: +Every call to the XcCreateNewProcessHandle() or +XcDupProcessHandle() should ultimately +result in a call to XcCloseProcessHandle() to deallocate the handle. + +Arguments: + + ProcessHandle + Process handle to be closed + +Return Value: + + XCERROR_OK + The call succeeded + + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcCloseProcessHandle ( + IN XCPROCESSHANDLE ProcessHandle +); + + + +/*++ + +XcDupProcessHandle API + +Description: + +Duplicates a process handle. Use this API if a copy of the +process handle is needed.
+ +This call is synchronous and does not cross +machine boundaries/process boundaries. + +NOTE: + a. Every call to XcDupProcessHandle() should ultimately result + in a call to XcCloseProcessHandle() to deallocate the handle. + +Arguments: + + ProcessHandle + Process handle to be duplicated + + pDupProcessHandle + The duplicated process handle + +Return Value: + + XCERROR_OK + The call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcDupProcessHandle( + IN XCPROCESSHANDLE ProcessHandle, + OUT PXCPROCESSHANDLE pDupProcessHandle +); + + + +/*++ + +XcSerializeProcessHandle API + +Description: + +Creates a serialized process handle. An XCompute process can serialize a +process handle, and pass it to another XCompute process where the other +XCompute process can use the XcUnSerializeProcessHandle() API, to +recreate the process handle. Then it can use that process handle to +communicate with the process, e.g. by using the XcSetAndGetProcessInfo() API + +Arguments: + + ProcessHandle + The handle to the process to serialize. + + ppXcSerializedHandleBlock + The serialized process handle. Use the XcFreeMemory() API + to deallocate the serialized handle block. + + pBlockLength + The length in bytes of the serialized process handle block + +NOTE: + The UserContext associated with the process handle will *NOT* + be serialized. + +Return Value: + + XCERROR_OK + The call succeeded + +--*/ + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcSerializeProcessHandle ( + IN XCPROCESSHANDLE ProcessHandle, + OUT PCVOID* ppXcSerializedHandleBlock, + OUT PSIZE_T pBlockLength +); + + + +/*++ + +XcUnSerializeProcessHandle API + +Description: + +Un-serializes a serialized process handle. See XcSerializeProcessHandle() API +for more details + +Arguments: + + SessionHandle + The session with which to associate the un-serialized process handle. + + pXcSerializedHandleBlock + The serialized process handle.
+ + BlockLength + The length in bytes of the serialized process handle block + + pProcessHandle + The un-serialized process handle. + +Return Value: + + XCERROR_OK + The call succeeded + +--*/ + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcUnSerializeProcessHandle ( + IN XCSESSIONHANDLE SessionHandle, + IN PCVOID pXcSerializedHandleBlock, + IN SIZE_T BlockLength, + OUT PXCPROCESSHANDLE pProcessHandle +); + + + +/*++ + +XcSetProcessUserContext API + +Description: + +Associates API user related data with the process +identified by the Process handle. + +The API user can associate any data with the XCompute process +and get back the data, using the XcGetProcessUserContext API. + +This call is synchronous and does not cross +machine boundaries/process boundaries. + +NOTE: +a. The XcCloseProcessHandle() will not deallocate the user + context data. It is the API user's responsibility to + deallocate data associated with UserContext. + +b. The user context is associated with a process and not with a + ProcessHandle. So if multiple handles identify the same + XCompute process, they will return the same user context. + +Arguments: + + ProcessHandle + Process handle to identify process to which user context + is being associated + + pUserContext + The user context data + + pPreviousUserContext + If there was a previously associated user context + with the XCompute process, then returns that context data. + Otherwise NULL is returned. If the caller supplies NULL input, + the previous value is not returned + + +Return Value: + + XCERROR_OK + The call succeeded + + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcSetProcessUserContext( + IN XCPROCESSHANDLE ProcessHandle, + IN ULONG_PTR pUserContext, + OUT ULONG_PTR* pPreviousUserContext +); + + + +/*++ + +XcGetProcessUserContext API + +Description: + +Gets the API user related data associated with the XCompute Process. +The API user can associate any data with the XCompute process via the +XcSetProcessUserContext API.
+ +This call is synchronous and does not cross +machine boundaries/process boundaries. + +Arguments: + + ProcessHandle + Process handle + + pUserContext + The user context data associated with the XCompute Process. + If no user context is associated, then + NULL is returned. + +Return Value: + + XCERROR_OK + The call succeeded + + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetProcessUserContext( + IN XCPROCESSHANDLE ProcessHandle, + OUT ULONG_PTR* pUserContext +); + + + +/*++ + +XcGetProcessState API + +Description: + +Gets the process state information. If XcScheduleProcess has not +yet been called, the API will return an error. + +This call is synchronous and does not cross +machine boundaries/process boundaries. + +Arguments: + + ProcessHandle + Process handle + + pProcessState + Describes the process state. The different states + are described in XComputeTypes.h + + pProcessSchedulingError + If the process state is XCPROCESSSTATE_COMPLETED + then the error code indicates the reason. + S_OK means the process completed without errors. + Other error codes indicate reasons for failed completion. + +Return Value: + + XCERROR_OK + The call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetProcessState( + IN XCPROCESSHANDLE ProcessHandle, + OUT PXCPROCESSSTATE pProcessState, + OUT XCERROR* pProcessSchedulingError +); + + + +/*++ + +XcGetProcessId API + +Description: + +Gets the process Id of the process associated with the process handle. +If the process state is anything less than XCPROCESSSTATE_ASSIGNEDTOPN +an error is returned. + +Arguments: + + ProcessHandle + Process handle + + + pProcessId + The id of the process + +Return Value: + + XCERROR_OK + The call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetProcessId( + IN XCPROCESSHANDLE ProcessHandle, + OUT GUID* pProcessId +); + + + +/*++ + +XcGetProcessNodeId API + +Description: + +Gets the process node to which the process has been assigned.
+If the process state is anything other than XCPROCESSSTATE_ASSIGNEDTOPN +an error is returned. + +Arguments: + + ProcessHandle + Process handle + + pProcessNodeId + Pointer to process node Id + +Return Value: + + XCERROR_OK + The call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetProcessNodeId( + IN XCPROCESSHANDLE ProcessHandle, + OUT PXCPROCESSNODEID pProcessNodeId +); + + + +/*++ + +ProcessScheduler API + +--*/ + + + +/*++ + +XcScheduleProcess API + +Description: + +Contacts the Process Scheduler to schedule an XCompute Process. +Any XCompute Process in a Job may schedule additional +XCompute Processes in the same Job by requesting their creation +through the XCompute Process Scheduler, using this API. + +NOTE: +This call always returns immediately. +A successful return code from the API indicates that the +XcScheduleProcess request was added to the local scheduleProcess queue. +The user should use the XcWaitForStateChange(XCPROCESSSTATE_ASSIGNEDTOPN) +API to see when the process actually gets scheduled to the Process Scheduler + +Arguments: + + ProcessHandle + Handle to the process. + Use the XcCreateNewProcessHandle () API + to obtain the handle to the process + + pScheduleProcessDescriptor + See PCXC_SCHEDULEPROCESS_DESCRIPTOR in + XComputeTypes.h. This data structure is + copied before the function returns and + so it is not necessary for the caller + to preserve the contents during an + async call + + Return Value: + + S_OK indicates the operation was successfully started. + + Any other return value indicates the scheduleprocess request + could not be started + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcScheduleProcess( + IN XCPROCESSHANDLE ProcessHandle, + IN PCXC_SCHEDULEPROCESS_DESCRIPTOR pScheduleProcessDescriptor +); + + + + +/*++ + +XcCancelScheduleProcess API + +Description: + +Contacts the Process Scheduler to cancel the scheduled +XCompute Process.
This API is used by the Parent XCompute process +that originally scheduled the XCompute process to cancel its +creation. +NOTE: The XCompute process will get cancelled only if it has not +already been created on a process node. The returned error code +indicates whether the process was successfully cancelled or not. + +Arguments: + + ProcessHandle + Handle to the process. + + pAsyncInfo + The async info structure. It's an alias to the + CS_ASYNC_INFO defined in Cosmos.h. If this + parameter is NULL, then function completes in + synchronous manner and error code is returned as + return value. + + If parameter is not NULL then operation is carried + on in asynchronous manner. If asynchronous + operation has been successfully started then + function terminates immediately with + HRESULT_FROM_WIN32(ERROR_IO_PENDING) return value. + Any other return value indicates that it was + impossible to start asynchronous operation. + + + Return Value: + + if pAsyncInfo is NULL + XCERROR_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcCancelScheduleProcess( + IN XCPROCESSHANDLE ProcessHandle, + IN PCXC_ASYNC_INFO pAsyncInfo +); + + + +/*++ + +PN API + +--*/ + +/*++ + +XcSetAndGetProcessInfo API + +Description: + +Gets the process related information from the Process Node. +JobManager (e.g. Dryad Job manager), will use this API to get +information about a given XCompute process, of a job. +Various bit flags (explained below) control the amount of data +retrieved for a given process. +It also provides the user with the ability to block on a +particular property, for maxBlockTime amount of time, before the +API finishes (synchronously or asynchronously).
Dryad uses this +to extend the lease period for a given process + +Arguments: + + ProcessHandle + Handle to the process. + Use the XcCreateNewProcessHandle () API + to obtain the handle to the process + + pXcRequestInputs + Pointer to the + XC_SETANDGETPROCESSINFO_REQINPUT struct. + It contains the various inputs to the API + clubbed together. This structure needs to + be preserved by the user till the Async + call is completed + + ppXcRequestResults + The results structure. The user should use + the XcFreeMemory(*ppXcRequestResults) to free + the memory after the results have been + consumed. + See PXC_SETANDGETPROCESSINFO_REQRESULTS for + more info. + + pAsyncInfo + The async info structure. It's an alias to + the CS_ASYNC_INFO defined in Cosmos.h. If + this parameter is NULL, then function + completes in synchronous manner and error + code is returned as return value. + + If parameter is not NULL then operation is + carried on in asynchronous manner. If + asynchronous operation has been successfully + started then function terminates + immediately with + HRESULT_FROM_WIN32(ERROR_IO_PENDING) return + value. + + Any other return value indicates that it was + impossible to start asynchronous operation. + + + Return Value: + + if pAsyncInfo is NULL + XCERROR_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcSetAndGetProcessInfo( + IN XCPROCESSHANDLE ProcessHandle, + IN PXC_SETANDGETPROCESSINFO_REQINPUT pXcRequestInputs, + OUT PXC_SETANDGETPROCESSINFO_REQRESULTS* ppXcRequestResults, + IN PCXC_ASYNC_INFO pAsyncInfo +); + + + +/*++ + +XcGetNetworkLocalityPathOfProcessNode + +Description: + +This API translates a set of process node IDs into +network locality paths.
+ +Arguments: + + SessionHandle + Handle to a session associated with + this call + + ProcessNodeId + The Process Node for which the + path is required + + ppNetworkLocalityPath + Returned network locality path for the ProcessNode. + The pNetworkLocalityPath vector should be freed with + XcFreeMemory(ppNetworkLocalityPath) + + pNetworkLocalityParam + The affinity param to be used to get the locality path. + The affinity param lets the user identify the affinity + level relative to the given ProcessNodeId, + which is reflected in the returned ppNetworkLocalityPath. + Thus given a ProcessNodeId, the user might say, + L2Switch as the NetworkLocalityParam, which means + the affinity is to all process nodes under that L2Switch. + + Different affinity params are defined in the + XComputeTypes.h. See Network Locality Params for + more details. + + Return Value: + + if pAsyncInfo is NULL + XCERROR_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetNetworkLocalityPathOfProcessNode( + IN XCSESSIONHANDLE SessionHandle, + IN XCPROCESSNODEID ProcessNodeId, + IN PSTR pNetworkLocalityParam, + OUT PCSTR* ppNetworkLocalityPath +); + + + +/*++ + +XcEnumerateProcessNodes + + +Description: + +This API enumerates all the process nodes that are controlled +by the Process scheduler and returns an array of processNodeIds + +Arguments: + + SessionHandle + Handle to a session associated with + this call + + pNumNodeIds + Pointer to a int which gets filled with the + number of process Node Ids in the + ppProcessNodeIds array + + ppProcessNodeIds + Pointer to array of processNode Ids. Use the + XcFreeMemory() API to deallocate. + + pAsyncInfo + The async info structure. 
Its an alias to + the CS_ASYNC_INFO defined in Cosmos.h. If + this parameter is NULL, then function + completes in synchronous manner and error + code is returned as return value. + + If parameter is not NULL then operation is + carried on in asynchronous manner. If + asynchronous operation has been successfully + started then function terminates + immediately with + HRESULT_FROM_WIN32(ERROR_IO_PENDING) return + value. + + Any other return value indicates that it was + impossible to start asynchronous operation. + + Return Value: + + if pAsyncInfo is NULL + XCERROR_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcEnumerateProcessNodes( + IN XCSESSIONHANDLE SessionHandle, + OUT UINT32* pNumNodeIds, + OUT PXCPROCESSNODEID* ppProcessNodeIds, + IN PCXC_ASYNC_INFO pAsyncInfo +); + + + +/*++ + +XcFetchProcessNodeMetaData + +Description: + +This API fetches the process node related metadata. This +call can result in a call to the Process Scheduler, if the +metadata for a given process node is missing. + +Arguments: + + SessionHandle + Handle to a session associated with + this call + + pProcessNodeIds + Array of IDs of the nodes for which the + metadata is required + + NumNodeIds + Number of node ids in the + pProcessNodeIds array + + pAsyncInfo + The async info structure. Its an alias to + the CS_ASYNC_INFO defined in Cosmos.h. If + this parameter is NULL, then function + completes in synchronous manner and error + code is returned as return value. + + If parameter is not NULL then operation is + carried on in asynchronous manner. 
If + asynchronous operation has been successfully + started then function terminates + immediately with + HRESULT_FROM_WIN32(ERROR_IO_PENDING) return + value. + + Any other return value indicates that it was + impossible to start asynchronous operation. + + Return Value: + + if pAsyncInfo is NULL + XCERROR_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcFetchProcessNodeMetaData( + IN XCSESSIONHANDLE SessionHandle, + IN UINT32 NumNodeIds, + IN PXCPROCESSNODEID pProcessNodeIds, + IN PCXC_ASYNC_INFO pAsyncInfo +); + + + +/*++ + + Notification/Sync API + +--*/ + + + +/*++ + +XcWaitForStateChange API + +Description: + +The API allows users to get async completion status for +XCompute process when it reaches a desired state. (see XCPROCESSSTATE) +When the desired state is reached the async completion is dispatched. + +NOTE: +1. If the process gets cancelled, then completion is dispatched immediately +2. The pOperationState of the AsyncInfo will have the error code. + +Arguments: + + ProcessHandle + Handle to an XCompute process for which the + state change event is needed + + WaitForState + The state to wait for the XCompute to be in, so + that completion can be dispatched + + MaxWaitInterval + The maximum amount of time (not including network + request latencies) that the API should wait for a + change in the process list before completing. If + XCTIMEINTERVAL_ZERO, the API will return changes + that can be immediately determined without + communication with the process scheduler. If + XC_TIMEINTERVAL_INFINITE, the API will wait until a + change occurs or the process is cancelled. + + pAsyncInfo + The async info structure. 
Its an alias to the + CS_ASYNC_INFO defined in Cosmos.h. If this + parameter is NULL, then the function completes in + synchronous manner and error code is returned as + return value. + + If parameter is not NULL then the operation is carried + on in asynchronous manner. If an asynchronous + operation has been successfully started then + this function terminates immediately with + an HRESULT_FROM_WIN32(ERROR_IO_PENDING) return value. + Any other return value indicates that it was + impossible to start the asynchronous operation. + + Return Value: + + CsError_OK indicates call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcWaitForStateChange( + IN XCPROCESSHANDLE ProcessHandle, + IN XCPROCESSSTATE WaitForState, + IN XCTIMEINTERVAL MaxWaitInterval, + IN PCXC_ASYNC_INFO pAsyncInfo +); + + + +/*++ + + XCompute File access API. + +--*/ + + + +/*++ + +XcGetProcessUri API + +Description: + +Gets the Uri to a file or directory local to XCompute process. +The returned Uri is a fully qualified and can be used to +create paths for file URI's in the processes root directory or +another directory under the root by appending path/s relative +to the initial working directory. + +NOTE: + +The Job does not have access to directories above the +Process's Root Directory. +All directories e.g. Process Working Directory, Data directory +are sub directories under the Process's Root directory + + +Arguments: + + ProcessHandle + The process handle for which to get the + Process File Uri. + + pRelativePath + The path relative to process's working directory + that will be appended to the working directory. + NOTE: + If relative path is NULL. or '.' or '/', then the working directory + path is returned. + + ppProcessRootDirUri + The Processes Root directory Uri. + Use the XcFreeMemory() API to free this buffer. 
+ + Return Value: + + CsError_OK indicates call succeeded + +--*/ + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetProcessUri( + IN XCPROCESSHANDLE ProcessHandle, + IN PCSTR pRelativePath, + OUT PSTR* ppProcessRootDirUri +); + + + +/*++ + +XcGetProcessPath API + +Description: + +Gets the path to a file or directory local to XCompute process. +The returned path is fully qualified and can be used to +create paths for files in the processes root directory or +another directory under the root by appending path/s relative +to the initial working directory. + +The returned path is suitable for passing to OS APIs like CreateFile() + +NOTE: + +The Job does not have access to directories above the +Process's Root Directory. +All directories e.g. Process Working Directory, Data directory +are sub directories under the Process's Root directory + + +Arguments: + + hProcessHandle + The process handle for which to get the + Process File Uri. + + pszRelativePath + The path relative to process's working directory + that will be appended to the working directory. + NOTE: + If relative path is NULL. or '.' or '/', then the working directory + path is returned. + + ppszProcessRootDirPath + The Processes Root directory Uri. + Use the XcFreeMemory() API to free this buffer. + + Return Value: + + CsError_OK indicates call succeeded + +--*/ + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetProcessPath( + IN XCPROCESSHANDLE hProcessHandle, + IN PCSTR pszRelativePath, + OUT PSTR* ppszProcessRootDirPath +); + +/*++ + +XcOpenProcessFile API + +Description: + +Opens a handle to a remote XCompute processes working File. +Using this handle, an application can read remote files +written by a XCompute process on a given Node. +Writing of files is not supported. Local files can be written +using Ordinary Windows file I/O (restricted to the working +directory and Its children). + +Arguments: + + SessionHandle + Handle to an XCompute session associated with + this call. 
+ + pFileUri + the fully qualified file Uri (UTF-8) obtained by calling the + XcGetProcessFileUri API. + + Flags + Reserved. Must be 0. + + pFileHandle + The returned handle to the opened file. + Set to NULL if error + + pAsyncInfo + The async info structure. Its an alias to the + CS_ASYNC_INFO defined in Cosmos.h. If this + parameter is NULL, then the function completes in + synchronous manner and error code is returned as + return value. + + If parameter is not NULL then the operation is carried + on in asynchronous manner. If an asynchronous + operation has been successfully started then + this function terminates immediately with + an HRESULT_FROM_WIN32(ERROR_IO_PENDING) return value. + Any other return value indicates that it was + impossible to start the asynchronous operation. + + + Return Value: + + if pAsyncInfo is NULL + CsError_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation (a SUCCESS HRESULT will never + be returned if pAsyncInfo is not NULL). + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcOpenProcessFile( + IN XCSESSIONHANDLE SessionHandle, + IN PCSTR pFileUri, + IN DWORD Flags, + OUT PXCPROCESSFILEHANDLE pFileHandle, + IN PCXC_ASYNC_INFO pAsyncInfo +); + + + +/*++ + +XcGetProcessFileSize API + +Description: + +Gets the fileSize of the given process file handle. + +Arguments: + + FileHandle + The handle to the opened file. + + Flags + The options for fetching the size. 
These option flags are mutually exclusive. + One of the following is permissible: + XC_REFRESH_AGGRESSIVE (default) + - visit server to find out latest known length + XC_REFRESH_PASSIVE + - return length from local cache if available otherwise + visit server to find out latest known length + XC_REFRESH_FROM_CACHE + - return length from local cache. + Fail if not available. This is a non-blocking call. + + pSize + Pointer to the output size variable. + Must not be NULL. + The memory pointed to by this variable must remain valid and writable + for the duration of the asynchronous operation. + + pAsyncInfo + The async info structure. It's an alias to the + CS_ASYNC_INFO defined in Cosmos.h. If this + parameter is NULL, then the function completes in + synchronous manner and error code is returned as + return value. + + If parameter is not NULL then the operation is carried + on in asynchronous manner. If an asynchronous + operation has been successfully started then + this function terminates immediately with + an HRESULT_FROM_WIN32(ERROR_IO_PENDING) return value. + Any other return value indicates that it was + impossible to start the asynchronous operation. + + + Return Value: + + E_NOTIMPL is returned if the underlying file does not support GetFileSize. + + if pAsyncInfo is NULL + CsError_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation (a SUCCESS HRESULT will never + be returned if pAsyncInfo is not NULL).
+ +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetProcessFileSize( + IN XCPROCESSFILEHANDLE FileHandle, + IN UINT Flags, + OUT PUINT64 pSize, + IN PCXC_ASYNC_INFO pAsyncInfo +); + + + +/*++ + +XcCloseProcessFile API + +Description: + +Closes the file opened by XcOpenProcessFile. + +Arguments: + + FileHandle + The handle to the opened file. + + Return Value: + + CsError_OK indicates call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcCloseProcessFile( + IN XCPROCESSFILEHANDLE FileHandle +); + + + +/*++ + +XcReadProcessFile API + +Description: + +Reads the content of the file opened by XcOpenProcessFile. + +Arguments: + + + FileHandle + The handle to the opened file. + + pBuffer + Pointer to the buffer that receives the data read. + + pBytesRead + Pointer to variable containing size of the buffer + on input. On return this variable receives number + of bytes read. + + pReadPosition + The offset from the beginning of the file at + which to read. + + pAsyncInfo + The async info structure. It's an alias to the + CS_ASYNC_INFO defined in Cosmos.h. If this + parameter is NULL, then the function completes in + synchronous manner and error code is returned as + return value. + + If parameter is not NULL then the operation is + carried on in asynchronous manner. If an asynchronous + operation has been successfully started then + this function terminates immediately with + an HRESULT_FROM_WIN32(ERROR_IO_PENDING) return value. + Any other return value indicates that it was + impossible to start the asynchronous operation. + + + Return Value: + + if pAsyncInfo is NULL + CsError_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation (a SUCCESS HRESULT will never + be returned if pAsyncInfo is not NULL).
+ +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcReadProcessFile( + IN XCPROCESSFILEHANDLE FileHandle, + OUT PVOID pBuffer, + IN OUT PSIZE_T pBytesRead, + IN OUT XCPROCESSFILEPOSITION* pReadPosition, + IN PCXC_ASYNC_INFO pAsyncInfo +); + + + +/*++ + +XcGetCurrentProcessNodeId API + +Description: + +Gets the current Process Node Id. The Process Node Id to +the node name map is maintained internally. + +Arguments: + + SessionHandle + Handle to an XCompute session associated with + this call. + + pProcessNodeId + Pointer to Pointer of the Id of the node + + Return Value: + + CsError_OK + indicates call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetCurrentProcessNodeId( + IN XCSESSIONHANDLE SessionHandle, + OUT PXCPROCESSNODEID pProcessNodeId +); + + + +/*++ + +XcProcessNodeIdFromName API + +Description: + +Gets the Process Node Id for a node given the node name. The +Process Node Id to the node name map is maintained internally. +If a node name is not found in the internal map, then a new +entry is created and the corresponding id is returned. + +Arguments: + + SessionHandle + Handle to an XCompute session associated with this + call. Reserved for future use. Must be NULL. + + + pProcessNodeName + Name of the process node for which Id is needed + + pProcessNodeId + Pointer to Pointer of the Id of the node + + Return Value: + + CsError_OK indicates call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcProcessNodeIdFromName( + IN XCSESSIONHANDLE SessionHandle, + IN PCSTR pProcessNodeName, + OUT PXCPROCESSNODEID pProcessNodeId +); + + + +/*++ + +XcProcessNodeNameFromId API + +Description: + +Gets the Process Node name from the given Process Node Id. The +Process Node Id to the node name map is maintained internally. + +Arguments: + + SessionHandle + Handle to an XCompute session associated with this + call.
+ + + ProcessNodeId + The process Node Id for which the node name + is needed + + ppProcessNodeName + Name of the process node corresponding to Id + Note: the returned process node name string + is permanently allocated and will remain + valid for the life of the process. There + is no need to make a copy of this string. + + Return Value: + + CsError_OK indicates call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcProcessNodeNameFromId( + IN XCSESSIONHANDLE SessionHandle, + IN XCPROCESSNODEID ProcessNodeId, + OUT PCSTR* ppProcessNodeName +); + + + +#pragma pack( pop ) + +#pragma warning( pop ) + +#if defined(__cplusplus) +} +#endif diff --git a/xcompute_native/inc/XComputeTypes.h b/xcompute_native/inc/XComputeTypes.h new file mode 100644 index 0000000..585f4eb --- /dev/null +++ b/xcompute_native/inc/XComputeTypes.h @@ -0,0 +1,1311 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License.
+ +*/ + +#pragma once + +#pragma warning( push ) +/* 'X' bytes padding added after member 'Y' */ +#pragma warning( disable: 4820 ) + + + +#if !defined(_PCVOID_DEFINED) +typedef const void* PCVOID; +#define _PCVOID_DEFINED +#endif + +#if defined(__cplusplus) +extern "C" { +#endif + +#if defined(XCOMPUTE_EXPORTS) +#define XCOMPUTEAPI_EXT __declspec(dllexport) +#else +#define XCOMPUTEAPI_EXT __declspec(dllimport) +#endif +#define XCOMPUTEAPI __stdcall + + +/*++ + +Error codes and exit code typedefs + +--*/ +typedef DWORD XCEXITCODE; +typedef HRESULT XCERROR; + + + +/* + +Process state related typedefs. Used by the sync +API, to depict process state + +*/ +typedef DWORD XCPROCESSSTATE; +typedef XCPROCESSSTATE* PXCPROCESSSTATE; + + + +/* + +Various process states. + +State transitions are shown below +Each state is explained in detail in the section below this + + XCPROCESSSTATE_INVALID -------------> XCPROCESSSTATE_COMPLETED + | + | + \/ + XCPROCESSSTATE_UNSCHEDULED -------------> XCPROCESSSTATE_COMPLETED + | + | + \/ + + XCPROCESSSTATE_SCHEDULING-------------> XCPROCESSSTATE_COMPLETED + | + | + \/ + + XCPROCESSSTATE_SCHEDULED-------------> XCPROCESSSTATE_COMPLETED + | + | + \/ + + XCPROCESSSTATE_ASSIGNEDTONODE-------------> XCPROCESSSTATE_COMPLETED + | + | + \/ + + XCPROCESSSTATE_BINDING-------------> XCPROCESSSTATE_COMPLETED + /\ + | + | + \/ + + XCPROCESSSTATE_BINDCOMPLETED-------------> XCPROCESSSTATE_COMPLETED + | + | + \/ + + XCPROCESSSTATE_LAUNCHING-------------> XCPROCESSSTATE_COMPLETED + | + | + \/ + + XCPROCESSSTATE_RUNNING-------------> XCPROCESSSTATE_COMPLETED + | + | + \/ + XCPROCESSSTATE_TERMINATING + | + | + \/ + XCPROCESSSTATE_COMPLETED + + | + | + \/ + XCPROCESSSTATE_STATEDELETED + | + | + \/ + XCPROCESSSTATE_DELETEDFROMNODE + + + + +XCPROCESSSTATE_INVALID + The process state is invalid. This will be returned, if a call to + the XcGetProcessState() api is made, even before + XcScheduleProcess() has been called.
+ +XCPROCESSSTATE_UNSCHEDULED + The process has NOT been scheduled on the Process Scheduler. + It is possible to XcCancelScheduleProcess() in this state + +XCPROCESSSTATE_SCHEDULING + The scheduling is in flight to the Process Scheduler. But it is + not yet assigned to a process node. + It is possible to XcCancelScheduleProcess() in this state + +XCPROCESSSTATE_SCHEDULED + The process has been scheduled on the Process Scheduler. + It is possible to XcCancelScheduleProcess() in this state + +XCPROCESSSTATE_ASSIGNEDTONODE + The process has been assigned to a process Node. The user can + use XcGetProcessNode() api to get PN related info for the process. + After this point it is possible to interact with process node state for + the process (update process constraints, request resource bindings, + launch the process, get and set properties, open and read process + files, etc.) + +XCPROCESSSTATE_BINDING + The process is assigned to a PN and resource binding (copying) is + in progress. All required resources need to be copied before a + process can be launched. While in this state, it is possible to interact + with process node state for the process (update process constraints, + request additional resource bindings, get and set properties, open + and read process files, etc.) + +XCPROCESSSTATE_BINDCOMPLETED + The resource copying is completed. NOTE: This might signal + completion of binding for a batch of resources. The state can + jump back to XCPROCESSSTATE_BINDING if further bindings are + requested at this point. If the process was originally scheduled + without the XCCREATEPROCESSDESCRIPTOR_LATEBOUNDRESOURCES + option set, or if process flag + XCPROCESS_FLAG_LAUNCH_AFTER_RESOURCE_BIND (TBD) has been set + on the process, then the state will automatically proceed to + XCPROCESSSTATE_LAUNCHING. + +XCPROCESSSTATE_LAUNCHING + All resource binding has finished and the process is being launched.
+ +XCPROCESSSTATE_RUNNING + The corresponding win32 process has been created on the PN + node and is currently running + +XCPROCESSSTATE_TERMINATING + The jobobject for the corresponding win32 process is being terminated/ + GC'ed. + +XCPROCESSSTATE_COMPLETED + The XComputeProcess completed. + If the process successfully reached XCPROCESSSTATE_ASSIGNEDTONODE + before completing, then it is still possible to interact with process node state + for the process (e.g., you can still open and read process files). + + NOTE: The process can complete for various reasons. The error + code associated with the state explains the exact reason. + +XCPROCESSSTATE_STATEDELETED + The XCompute process is completed. Its state (directories, etc.) has + been garbage collected, but the ProcessNode still holds information about + the XCompute process statistics. + +XCPROCESSSTATE_DELETEDFROMNODE + The XComputeProcess has been garbage collected and no info + about it exists at the process node. Only locally cached status + information is available.
+ +*/ +#define XCPROCESSSTATE_ZERO ((XCPROCESSSTATE)0x00000000) +#define XCPROCESSSTATE_INVALID XCPROCESSSTATE_ZERO + +// The following states are managed on the Process Scheduler before the process has been allocated on the Process Node +#define XCPROCESSSTATE_UNSCHEDULED ((XCPROCESSSTATE)0x40000000) +#define XCPROCESSSTATE_SCHEDULING ((XCPROCESSSTATE)0x40100000) +#define XCPROCESSSTATE_SCHEDULED ((XCPROCESSSTATE)0x40200000) +#define XCPROCESSSTATE_SCHEDULINGFAILED ((XCPROCESSSTATE)0x40300000) + +// The following states are managed on the Process Node +#define XCPROCESSSTATE_NODE_UNINITIALIZED ((XCPROCESSSTATE)0x80000000) +#define XCPROCESSSTATE_NODE_PREINITIALIZE ((XCPROCESSSTATE)0x80100000) +#define XCPROCESSSTATE_NODE_ALLOCATED ((XCPROCESSSTATE)0x80200000) +#define XCPROCESSSTATE_ASSIGNEDTONODE XCPROCESSSTATE_NODE_ALLOCATED +#define XCPROCESSSTATE_NODE_READYTOBINDRESOURCES ((XCPROCESSSTATE)0x80300000) +#define XCPROCESSSTATE_NODE_BINDINGRESOURCES ((XCPROCESSSTATE)0x80400000) +#define XCPROCESSSTATE_BINDING XCPROCESSSTATE_NODE_BINDINGRESOURCES +#define XCPROCESSSTATE_NODE_RESOURCEBINDINGCOMPLETE ((XCPROCESSSTATE)0x80500000) +#define XCPROCESSSTATE_BINDCOMPLETED XCPROCESSSTATE_NODE_RESOURCEBINDINGCOMPLETE +#define XCPROCESSSTATE_NODE_LOADPENDING ((XCPROCESSSTATE)0x80600000) +#define XCPROCESSSTATE_LAUNCHING XCPROCESSSTATE_NODE_LOADPENDING +#define XCPROCESSSTATE_NODE_LOADING ((XCPROCESSSTATE)0x80700000) +#define XCPROCESSSTATE_NODE_LOADED ((XCPROCESSSTATE)0x80800000) +#define XCPROCESSSTATE_NODE_APPINITIALIZATION ((XCPROCESSSTATE)0x80900000) +#define XCPROCESSSTATE_NODE_APPRUNPENDING ((XCPROCESSSTATE)0x80A00000) +#define XCPROCESSSTATE_NODE_APPRUNNING ((XCPROCESSSTATE)0x80B00000) +#define XCPROCESSSTATE_RUNNING XCPROCESSSTATE_NODE_APPRUNNING +#define XCPROCESSSTATE_NODE_TERMINATING ((XCPROCESSSTATE)0x80c00000) +#define XCPROCESSSTATE_TERMINATING XCPROCESSSTATE_NODE_TERMINATING +#define XCPROCESSSTATE_NODE_COMPLETE ((XCPROCESSSTATE)0x80d00000) +#define
XCPROCESSSTATE_COMPLETED XCPROCESSSTATE_NODE_COMPLETE +#define XCPROCESSSTATE_NODE_DELETINGSTATE ((XCPROCESSSTATE)0x80e00000) +#define XCPROCESSSTATE_NODE_STATEDELETED ((XCPROCESSSTATE)0x80f00000) +#define XCPROCESSSTATE_STATEDELETED XCPROCESSSTATE_NODE_STATEDELETED +#define XCPROCESSSTATE_NODE_ZOMBIE ((XCPROCESSSTATE)0x8fffffff) + +// The following states are maintained by the client SDK after the Process Node has forgotten about the process + +#define XCPROCESSSTATE_DELETEDFROMNODE ((XCPROCESSSTATE)0xc0000000) + +// End Process States + +#define XCPROCESSSTATE_NEVER ((XCPROCESSSTATE)0xffffffff) + + + + +/*++ + +NOTE: All byte strings in this .h file are UTF8 strings +unless otherwise noted. + +--*/ + + + +/*++ + +XCPROCESSHANDLE +Handle to an XCompute Process. Used in the APIs that require a +process identifier as input. +Example APIs are: XcScheduleProcess(), XcSetAndGetProcessInfo(). + +This is used to assist API users in following ways: + a. Pass this identifier instead of generating their own unique + process GUIDs + + b. If using Async API, then there is no need to keep track of / + or lookup GUID to call back function/object map. The user + context associated with the handle can be used instead. + + c. All operations related to a process, e.g. Process states, etc + will ultimately get tied to the handle, thus making it easy + for users to collect/modify process related information + +--*/ +typedef struct tagXCPROCESSHANDLE +{ + ULONG_PTR Unused; +} *XCPROCESSHANDLE, **PXCPROCESSHANDLE; + +const XCPROCESSHANDLE INVALID_XCPROCESSHANDLE = NULL; + + + +/*++ + +XCPROCESSNODEID +Identifies an XCompute Node. Used in the XcProcessNodeIdFromName() +and the XcProcessNodeNameFromId() APIs. All XCompute APIs take +the XCPROCESSNODEID as input wherever a node related entry +is needed. This is used to assist users to be able to pass +this identifier instead of passing strings around and having to +go through allocating/deallocating/Copying/Comparing them.
+ +Fields: +--*/ +typedef struct tagXCPROCESSNODEID +{ + ULONG_PTR Unused; +} *XCPROCESSNODEID, **PXCPROCESSNODEID; + +const XCPROCESSNODEID INVALID_XCPROCESSNODEID = NULL; + + +/*++ + +XC_RESOURCE_SOURCE_TYPE enum + +Determines the type of resource passed in the XC_RESOURCEFILE_DESCRIPTOR + +Fields: + + XC_RESOURCE_SOURCE_UTF8_PATHNAME + Indicates XC_RESOURCEFILE_DESCRIPTOR's pResourceSource + points to a UTF-8 share path. + The pathname may be an xstream URI or a path to a + working file in another XCompute Process in the same Job. + The resource will be copied as a binary image from this + location + + XC_RESOURCE_SOURCE_EMBEDDED_CONTENT + Indicates XC_RESOURCEFILE_DESCRIPTOR's pResourceSource + points to an embedded resource. The contents of the buffer + pointed to by pResourceSource are written to the resource file + as a binary image. + + +--*/ +typedef enum tagXC_RESOURCE_SOURCE_TYPE { + XC_RESOURCE_SOURCE_UTF8_PATHNAME = 0, + XC_RESOURCE_SOURCE_EMBEDDED_CONTENT +} XC_RESOURCE_SOURCE_TYPE; + + + + +/*++ + +XC_RESOURCEFILE_DESCRIPTOR structure + +Holds information about a single resource file. + +Fields: + + Size + sizeof(XC_RESOURCEFILE_DESCRIPTOR) + + Flags + Reserved for future use + + pFileName + Name of the file on the destination. + The name is a relative path to the + target process's working directory + + ResourceSourceType + See XC_RESOURCE_SOURCE_TYPE above + + NumberOfResourceSourceBytes + Length of the pResourceSource buffer.
+ + pResourceSource + Depending on ResourceSourceType, either + points to an embedded resource or to a share + path from where the resource will be copied + +--*/ +typedef struct tagXC_RESOURCEFILE_DESCRIPTOR{ + SIZE_T Size; + DWORD Flags; + PCSTR pFileName; + XC_RESOURCE_SOURCE_TYPE ResourceSourceType; + SIZE_T NumberOfResourceSourceBytes; + PVOID pResourceSource; +} XC_RESOURCEFILE_DESCRIPTOR, *PXC_RESOURCEFILE_DESCRIPTOR; +typedef const XC_RESOURCEFILE_DESCRIPTOR* PCXC_RESOURCEFILE_DESCRIPTOR; + + + +/*++ + +This bit flag indicates that the resources for a process will +be late bound. This is used in the Flags parameter in the +XC_CREATEPROCESS_DESCRIPTOR. See below. + +--*/ +#define XCCREATEPROCESSDESCRIPTOR_LATEBOUNDRESOURCES 0x00000002 + + + +/*++ + +Typedefs/various aliases +--*/ +typedef GUID XC_PROCESSID; +typedef const XC_PROCESSID* PCXC_PROCESSID; +typedef GUID XC_TASKID_USER; +typedef GUID XC_TASKID; +typedef GUID XC_JOBID; +typedef const XC_JOBID* PCXC_JOBID; + +/*++ + +XC_ASYNC_INFO structure + +Each function that may be executed asynchronously takes a pointer +to an XC_ASYNC_INFO structure as its last parameter. + +If NULL is passed then function completes in synchronous manner +and error code is returned as return value. + +If parameter is not NULL then operation is carried on in asynchronous manner. +If asynchronous operation has been successfully started then function terminates +immediately with HRESULT_FROM_WIN32(ERROR_IO_PENDING) return value. +Any other return value indicates that it was impossible to start asynchronous operation. + +Fields: + + Size Size of structure in bytes. Set to sizeof(XC_ASYNC_INFO). + + pOperationState Pointer to error code returned by completed operation. + While operation is in progress value is set to HRESULT_FROM_WIN32(ERROR_IO_PENDING). + Before completion is reported value is set to an error code of completed operation. + Cannot be NULL. + + Event Handle to event. Event is set once operation is completed. May be NULL.
+ + IOCP Handle to IO completion port. If not NULL then upon completion status is posted + to specified completion port. + + pOverlapped Pointer to OVERLAPPED structure. Used in conjunction with IOCP parameter + to post status to IOCP. + Should be NULL if IOCP is NULL, cannot be NULL if IOCP is not NULL. + + CompletionKey Used in conjunction with IOCP parameter to post status to IOCP. + Should be 0 if IOCP is NULL. + + unusedX Fields reserved for future use. + +Note that XC_ASYNC_INFO structure is not required to be available for the duration of asynchronous call +(for example, this allows the XC_ASYNC_INFO structure to be allocated on the stack). +In contrast variable specified by pOperationState pointer is required to be available +for the duration of the asynchronous call. + +--*/ +typedef struct tagXC_ASYNC_INFO { + SIZE_T Size; + + HRESULT* pOperationState; + + HANDLE Event; + + HANDLE IOCP; + LPOVERLAPPED pOverlapped; + UINT_PTR CompletionKey; + + UINT64 unused0; + UINT64 unused1; +} XC_ASYNC_INFO, *PXC_ASYNC_INFO; +typedef const XC_ASYNC_INFO* PCXC_ASYNC_INFO; + + +/*++ + +A session handle will contain user related information that +can be used to determine what part of information the user +can access. + +--*/ +typedef struct tagXCSESSIONHANDLE +{ + ULONG_PTR Unused; +} *XCSESSIONHANDLE, **PXCSESSIONHANDLE; +const XCSESSIONHANDLE INVALID_XCSESSIONHANDLE = NULL; + + + +/*++ + +XC_OPEN_SESSION_PARAMS structure + +Session related information + +Fields: + + Size + sizeof(XC_OPEN_SESSION_PARAMS) + + Flags + Reserved for future use. Must be 0. + + pCluster + Name of the cluster to connect to. + If this field is NULL then default cluster + will be used. + + ClientId + The unique ID to use for syncing process + related information with the Process Scheduler. + If a NULL GUID, a default ID will be generated. + + Use this field when implementing failover.
+ E.g., if a JobManager wants to provide failover, + it can use a well-known ClientId between various + redundant JobManagers. When the failover happens, + the new Job Manager can use the same client Id to + sync process states with the Process Scheduler. + +--*/ + +typedef struct tagXC_OPEN_SESSION_PARAMS{ + SIZE_T Size; + DWORD Flags; + PCSTR pCluster; + GUID ClientId; +}XC_OPEN_SESSION_PARAMS, *PXC_OPEN_SESSION_PARAMS; +typedef const XC_OPEN_SESSION_PARAMS* PCXC_OPEN_SESSION_PARAMS; + + + +/*++ + +An XCDATETIME is defined as the number of 100-nanosecond intervals +that have elapsed since 12:00 A.M. January 1, 1601 (UTC). It is +the representation of choice whenever an absolute date/time must +be used. + +XCDATETIME is equivalent to a Windows FILETIME value without local +time zone adjustment. + +--*/ +typedef UINT64 XCDATETIME; + +#define XCDATETIME_NEVER _UI64_MAX +#define XCDATETIME_LONGAGO 0 + + + +/*++ + +XCTIMEINTERVAL represents a measurement of elapsed time in 100ns +intervals. It is a signed entity (elapsed time may be negative). +It is the natural type for the result of subtracting two XCDATETIME +values.
+--*/ +typedef INT64 XCTIMEINTERVAL; + + + +/*++ + +The below #defines help define commonly used time intervals + +--*/ +#define XCTIMEINTERVAL_INFINITE 0X7FFFFFFFFFFFFFFF // TODO: Use _I64_MAX after upgrading to VS2010 +#define XCTIMEINTERVAL_NEGATIVEINFINITE _I64_MIN +#define XCTIMEINTERVAL_ZERO 0 +#define XCTIMEINTERVAL_QUANTUM 1 +#define XCTIMEINTERVAL_100NS ( (XCTIMEINTERVAL) (XCTIMEINTERVAL_QUANTUM) ) +#define XCTIMEINTERVAL_MICROSECOND ( (XCTIMEINTERVAL) ( XCTIMEINTERVAL_100NS * 10 ) ) +#define XCTIMEINTERVAL_MILLISECOND ( (XCTIMEINTERVAL) ( XCTIMEINTERVAL_MICROSECOND * 1000 ) ) +#define XCTIMEINTERVAL_SECOND ( (XCTIMEINTERVAL) ( XCTIMEINTERVAL_MILLISECOND * 1000 ) ) +#define XCTIMEINTERVAL_MINUTE ( (XCTIMEINTERVAL) ( XCTIMEINTERVAL_SECOND * 60 ) ) +#define XCTIMEINTERVAL_HOUR ( (XCTIMEINTERVAL) ( XCTIMEINTERVAL_MINUTE * 60 ) ) +#define XCTIMEINTERVAL_DAY ( (XCTIMEINTERVAL) ( XCTIMEINTERVAL_HOUR * 24 ) ) +#define XCTIMEINTERVAL_WEEK ( (XCTIMEINTERVAL) ( XCTIMEINTERVAL_DAY * 7 ) ) + + + +/*++ + +Flags for the various options set in the XC_PROCESS_CONSTRAINTS structure + +--*/ +#define XCPROCESSCONSTRAINTOPTION_SETMAXREMAININGELAPSEDEXECUTIONTIME 0x1 +#define XCPROCESSCONSTRAINTOPTION_SETMAXREMAININGRETAINAFTERTERMINATETIME 0x2 +#define XCPROCESSCONSTRAINTOPTION_SETMAXPERWIN32PROCESSUSERMODETIME 0x4 +#define XCPROCESSCONSTRAINTOPTION_SETMAXREMAININGUSERMODETIME 0x8 +#define XCPROCESSCONSTRAINTOPTION_SETMAXWORKINGSETSIZE 0x10 +#define XCPROCESSCONSTRAINTOPTION_SETMAXNUMWIN32PROCESSES 0x20 +#define XCPROCESSCONSTRAINTOPTION_SETMAXPERWIN32PROCESSMEMORYSIZE 0x40 +#define XCPROCESSCONSTRAINTOPTION_SETMAXMEMORYSIZE 0x80 + + + +/*++ + +Default process priority + +--*/ +#define XCPROCESSPRIORITY_DEFAULT 0x80000000 + + + +/*++ + +XC_PROCESS_CONSTRAINTS structure + +Constraints that will be applied to the process that gets started +on a given node + +Fields: + + Size + sizeof(XC_PROCESS_CONSTRAINTS) + + ProcessConstraintOptions + Bit flag indicating what options have + 
been set + + MaxRemainingElapsedExecutionTime + Maximum amount of time process can + continue to run without terminating. + + MaxRemainingRetainAfterTerminateTime + Amount of time after process + termination before the process + persistent state is discarded + + MaxPerWin32ProcessUserModeTime + Max amount of user-mode CPU time for + each Windows process associated with + the XCompute process + + MaxRemainingUserModeTime + Max amount of total user-mode CPU + time for XCompute process + + MaxWorkingSetSize + Maximum working set size for + Windows processes + + MaxNumWin32Processes + Maximum number of Windows processes + that can be running + + MaxPerWin32ProcessMemorySize + Maximum amount of memory per win32 + process + + MaxMemorySize + Max total memory for the XCompute + process + +--*/ +typedef struct tagXC_PROCESS_CONSTRAINTS{ + SIZE_T Size; + DWORD ProcessConstraintOptions; + XCTIMEINTERVAL MaxRemainingElapsedExecutionTime; + XCTIMEINTERVAL MaxRemainingRetainAfterTerminateTime; + XCTIMEINTERVAL MaxPerWin32ProcessUserModeTime; + XCTIMEINTERVAL MaxRemainingUserModeTime; + UINT64 MaxWorkingSetSize; + UINT32 MaxNumWin32Processes; + UINT64 MaxPerWin32ProcessMemorySize; + UINT64 MaxMemorySize; +} XC_PROCESS_CONSTRAINTS, *PXC_PROCESS_CONSTRAINTS; +typedef const XC_PROCESS_CONSTRAINTS* PCXC_PROCESS_CONSTRAINTS; + + + +/*++ + +XC_CREATEPROCESS_DESCRIPTOR structure + +Used in the XcScheduleProcess API. Has all the information needed to +launch a process on a particular node + +Fields: + + Size + sizeof(XC_CREATEPROCESS_DESCRIPTOR) + + Flags + Option bit flags + + XCCREATEPROCESSDESCRIPTOR_LATEBOUNDRESOURCES + indicates that the resources are late + bound. When the process gets created + on a node, the process is set to + UnInitialized state, and waits for the + parent process to contact the PN to bind + the resources + + pCommandLine + Command line that will launch the process + + pProcessClass + Process class name. User-defined.
+ + pProcessFriendlyName + The process friendly name. User-defined + + pEnvironmentStrings + The environment strings that will be + set before launching the process on a node + The environment strings are represented as + a series of null-terminated UTF8 strings + with an extra NULL at the end + + pAppProcessConstraints + See PCXC_PROCESS_CONSTRAINTS above + + NumberOfResourceFileDescriptors + The number of resource file descriptors in + the pResourceFileDescriptors array + + pResourceFileDescriptors + Pointer to array of + XC_RESOURCEFILE_DESCRIPTOR structures. These + resources will be copied to the process + working directory before launching the + process using the command line + +--*/ +typedef struct tagXC_CREATEPROCESS_DESCRIPTOR{ + SIZE_T Size; + DWORD Flags; + PCSTR pCommandLine; + PCSTR pProcessClass; + PCSTR pProcessFriendlyName; + PCSTR pEnvironmentStrings; + PCXC_PROCESS_CONSTRAINTS pAppProcessConstraints; + SIZE_T NumberOfResourceFileDescriptors; + PCXC_RESOURCEFILE_DESCRIPTOR pResourceFileDescriptors; +} XC_CREATEPROCESS_DESCRIPTOR, *PXC_CREATEPROCESS_DESCRIPTOR; +typedef const XC_CREATEPROCESS_DESCRIPTOR* PCXC_CREATEPROCESS_DESCRIPTOR; + + + +/*++ + +Defines the Network Locality Params used in the +XcGetNetworkLocalityPathOfProcessNode() API. +These params are passed to the API to identify the +Affinity level. The resulting NetworkLocalityParam returned +from the API can then be passed via the XC_AFFINITY struct +(defined below), to the Process Scheduler, to help the +Process Scheduler in making decisions about the choice of +Process Node to pick to run a given XCompute process. + +NOTE: The special Network Locality Param ".." can be combined + with other locality params to represent one level up from + the current level. + E.g. XCLOCALITYPARAM_POD/.. (NOTE the forward slash) +--*/ +#define XCLOCALITYPARAM_ONELEVELUP ".."
+#define XCLOCALITYPARAM_POD "POD" +#define XCLOCALITYPARAM_L2SWITCH "L2" +#define XCLOCALITYPARAM_L3SWITCH "L3" +#define XCLOCALITYPARAM_VLAN "VLAN" +#define XCLOCALITYPARAM_CLUSTER "CLUSTER" +#define XCLOCALITYPARAM_DATACENTER "DC" + + + +/*++ + +Defines the bit flag used in the XC_AFFINITY structure. +If XCAFFINITY_HARD of the Flags in the XC_AFFINITY structure is +set, then the affinity is considered to have hard affinity +to the NetworkNodePath/s. See below for details. + +--*/ +#define XCAFFINITY_HARD 0x01 + + + +/*++ + +XC_AFFINITY structure + +Each Affinity is comprised of a list of network locality paths, +an associated weight and a flag for hard affinity. +A network locality can refer to a data center, a top/middle +level switch, POD, or a specific host machine. + +Fields: + + Size + sizeof(XC_AFFINITY) + + Flags + Bit flags. XCAFFINITY_HARD indicates that + affinity is hard affinity. + + Weight + The Process Scheduler will give preference to + the Affinity (list of Nodes) that has higher + weight, while picking up the Node on which to + run the XCompute Process. + The intended units for Weight are + "estimated bytes of I/O" + + NumberOfNetworkLocalityPaths + Number of Nodes in + pNetworkLocalityPaths array. + + pNetworkLocalityPaths + Pointer to the network locality paths array. + A network locality path is represented as a + string and is an opaque format. The caller + gets the locality path information by calling the + XcGetNetworkLocalityPath API + +--*/ +typedef struct tagXC_AFFINITY{ + SIZE_T Size; + DWORD Flags; + UINT64 Weight; + SIZE_T NumberOfNetworkLocalityPaths; + PCSTR* pNetworkLocalityPaths; +} XC_AFFINITY, *PXC_AFFINITY; +typedef const XC_AFFINITY* PCXC_AFFINITY; + + + +/*++ + +XC_LOCALITY_DESCRIPTOR structure + +Locality is represented as a collection of Affinities. + +Fields: + + Size + sizeof(XC_LOCALITY_DESCRIPTOR) + + Flags + Reserved. Must be 0.
+ + NumberOfAffinities + Number of XC_AFFINITY entries in pAffinities + + pAffinities + Pointer to Array of Affinities + +--*/ +typedef struct tagXC_LOCALITY_DESCRIPTOR{ + SIZE_T Size; + DWORD Flags; + SIZE_T NumberOfAffinities; + PXC_AFFINITY pAffinities; +} XC_LOCALITY_DESCRIPTOR, *PXC_LOCALITY_DESCRIPTOR; +typedef const XC_LOCALITY_DESCRIPTOR* PCXC_LOCALITY_DESCRIPTOR; + + + +/*++ + +XC_SCHEDULEPROCESS_DESCRIPTOR + +The descriptor that has all the information about the process +to be scheduled + +Fields: + + Size + sizeof(XC_SCHEDULEPROCESS_DESCRIPTOR) + + Flags Reserved for later use. Must be 0. + + ProcessPriority + The priority of the process. The priority + is process priority, within all the + processes for a given job. This is + different from job priority. + + pLocalityDescriptor + See XC_LOCALITY_DESCRIPTOR above + + pCreateProcessDescriptor + See XC_CREATEPROCESS_DESCRIPTOR above + +--*/ +typedef struct tagXC_SCHEDULEPROCESS_DESCRIPTOR{ + SIZE_T Size; + DWORD Flags; + UINT32 ProcessPriority; + PCXC_LOCALITY_DESCRIPTOR pLocalityDescriptor; + PCXC_CREATEPROCESS_DESCRIPTOR pCreateProcessDescriptor; +} XC_SCHEDULEPROCESS_DESCRIPTOR, *PXC_SCHEDULEPROCESS_DESCRIPTOR; + +typedef +const XC_SCHEDULEPROCESS_DESCRIPTOR* PCXC_SCHEDULEPROCESS_DESCRIPTOR; + + + +/*++ + +XC_PROCESSPROPERTY_INFO + +The structure is embedded in the XC_PROCESS_INFO struct explained +below.
It has all the information related to a particular property + +Fields: + + Size + sizeof(XC_PROCESSPROPERTY_INFO) + + pPropertyLabel + The property label + + PropertyVersion + The property version + + pPropertyString + The property string value + + PropertyBlockSize + Memory block size of property + + pPropertyBlock + Pointer to memory block related to property + +--*/ +typedef struct tagXC_PROCESSPROPERTY_INFO{ + SIZE_T Size; + PSTR pPropertyLabel; + UINT64 PropertyVersion; +#if __midl + [string] +#endif + PSTR pPropertyString; + UINT32 PropertyBlockSize; + UINT32 bugbugPAD; +#if __midl + [size_is(PropertyBlockSize)] +#endif + char * pPropertyBlock; +} XC_PROCESSPROPERTY_INFO, *PXC_PROCESSPROPERTY_INFO; +typedef const XC_PROCESSPROPERTY_INFO* PCXC_PROCESSPROPERTY_INFO; + + + +/*++ + +XC_PROCESS_STATISTICS + +Contains all the statistics related to a given process/job + +Fields: + + Size + sizeof(XC_PROCESS_STATISTICS) + + Flags + Reserved for later use + + ProcessUserTime + Total user time the whole process + consumed, in 100-nanosecond units + + ProcessKernelTime + Total kernel time the whole process + consumed, in 100-nanosecond units + + PageFaults + Total #page faults for the whole process + + TotalProcessesCreated + Total #win32 processes the process ever + created + + PeakVMUsage + The peak Virtual memory usage + + PeakMemUsage + The peak working set memory usage + + MemUsageSeconds + Working set memory usage * time used + + TotalIo + Total IO transferred + +--*/ +typedef struct tagXC_PROCESS_STATISTICS{ + SIZE_T Size; + DWORD Flags; + XCTIMEINTERVAL ProcessUserTime; + XCTIMEINTERVAL ProcessKernelTime; + INT32 PageFaults; + INT32 TotalProcessesCreated; + UINT64 PeakVMUsage; + UINT64 PeakMemUsage; + UINT64 MemUsageSeconds; + UINT64 TotalIo; +} XC_PROCESS_STATISTICS, *PXC_PROCESS_STATISTICS; +typedef const XC_PROCESS_STATISTICS* PCXC_PROCESS_STATISTICS; + + + +/*++ + +Bit flag definitions for XC_PROCESS_INFO structure that is used +in the GetProcessProperty API + +--*/ +#define
XCPROCESSINFOOPTION_STATICINFO (0x01) +#define XCPROCESSINFOOPTION_TIMINGINFO (0x02) +#define XCPROCESSINFOOPTION_EFFECTIVECONSTRAINTS (0x04) +#define XCPROCESSINFOOPTION_EXTENDEDPROCESSDESCRIPTOR (0x08) +#define XCPROCESSINFOOPTION_EXTENDEDJOBDESCRIPTOR (0x10) +#define XCPROCESSINFOOPTION_PROCESSSTAT (0x20) +#define XCPROCESSINFOOPTION_APPCONSTRAINTS (0x40) +#define XCPROCESSINFOOPTION_SYSTEMCONSTRAINTS (0x80) +/* Convenience mask: bitwise OR of every XCPROCESSINFOOPTION_* flag above. Parenthesized so it expands safely inside larger expressions. */ +#define XCPROCESSINFOOPTION_All \ + (XCPROCESSINFOOPTION_STATICINFO | \ + XCPROCESSINFOOPTION_TIMINGINFO | \ + XCPROCESSINFOOPTION_EFFECTIVECONSTRAINTS | \ + XCPROCESSINFOOPTION_EXTENDEDPROCESSDESCRIPTOR | \ + XCPROCESSINFOOPTION_EXTENDEDJOBDESCRIPTOR | \ + XCPROCESSINFOOPTION_PROCESSSTAT | \ + XCPROCESSINFOOPTION_APPCONSTRAINTS | \ + XCPROCESSINFOOPTION_SYSTEMCONSTRAINTS) + + + +/*++ + +XC_SETANDGETPROCESSINFO_REQINPUT + +The structure is used to make the XcPnSetAndGetProcessInfo call. +It contains the various inputs to the API clubbed together. + +Fields: + + Size + sizeof(XC_SETANDGETPROCESSINFO_REQINPUT) + + pAppProcessConstraints + The process constraints to be set for the + process. The user will need to preserve this + structure till the async call is completed + + NumberOfProcessPropertiesToSet + The number of properties to set in the + ppPropertiesToSet array + + ppPropertiesToSet Pointer to property info array. These are the + properties that will be set in this call. + + pBlockOnPropertyLabel + Name of the property on which to block. The + request finishes when either the process + terminates, or the property is changed, or + after timeout amount of time. + + BlockOnPropertyversionLastSeen + The latest known version number of property + on which to block + + MaxBlockTime Time to wait for property to change or process + to terminate before returning with unchanged + property version. If 0, API returns + immediately with current values. + + pPropertyFetchTemplate + The property fetch template. It supports the + * wild card. 
A set of properties, whose + labels match the propertyFetchTemplate are + returned. If NULL, no properties are returned + + ProcessInfoFetchOptions + bit flag indicating the different + processInfo fields to fetch. + +--*/ +typedef struct tagXC_SETANDGETPROCESSINFO_REQINPUT{ + DWORD Size; + PXC_PROCESS_CONSTRAINTS pAppProcessConstraints; + UINT32 NumberOfProcessPropertiesToSet; + PXC_PROCESSPROPERTY_INFO* ppPropertiesToSet; + PCSTR pBlockOnPropertyLabel; + UINT64 BlockOnPropertyversionLastSeen; + XCTIMEINTERVAL MaxBlockTime; + PCSTR pPropertyFetchTemplate; + DWORD ProcessInfoFetchOptions; +} XC_SETANDGETPROCESSINFO_REQINPUT, + *PXC_SETANDGETPROCESSINFO_REQINPUT; + +typedef +const XC_SETANDGETPROCESSINFO_REQINPUT* PCXC_SETANDGETPROCESSINFO_REQINPUT; + + + +/*++ + +XC_PROCESS_INFO + +The structure gets returned as a result of the XcPnGetProcessProperty +call. Use the XcFreeMemory API to release memory for this structure + +Fields: + + Size + sizeof(XC_PROCESS_INFO) + + Flags + Bit flag that indicates which fields in the + data structure have valid information. The + bit flags are defined above + + ProcessState + The current state of the process from the PN's point of view. + This field is always sent. + + ProcessStatus + The process status. Indicates whether the + process is running or exited, and the reason. + This field is always sent. + + ExitCode + The process exit code. + This field is always sent. + + Win32Pid The Windows processId of the process + This field is always sent. + + NumberofProcessProperties + Number of XC_PROCESSPROPERTY_INFO's returned + + ppProperties + Array of XC_PROCESSPROPERTY_INFO structs + + CurrentPnTime + Always sent. 
This is the time on PN + + CreatedTime + Time when Win32 CreateProcess was + initiated (XCDATETIME_NEVER if not yet created) + Bit flag:XCPROCESSINFOOPTION_TIMINGINFO + + BeginExecutionTime + Time when Win32 process was first resumed + (XCDATETIME_NEVER if not yet resumed) + Bit flag:XCPROCESSINFOOPTION_TIMINGINFO + + TerminatedTime + Time when Win32 Process terminated + (XCDATETIME_NEVER if not yet terminated) + Bit flag:XCPROCESSINFOOPTION_TIMINGINFO + + LastPropertyUpdateTime + Most recent time when any property was set + + pEffectiveProcessConstraints + Effective constraints for the process (combined constraints + from application and system) + pointer to XC_PROCESS_CONSTRAINTS struct + Bit flag:XCPROCESSINFOOPTION_EFFECTIVECONSTRAINTS + + pAppProcessConstraints + Application constraints for the process + pointer to XC_PROCESS_CONSTRAINTS struct + Bit flag:XCPROCESSINFOOPTION_APPCONSTRAINTS + + pSystemProcessConstraints + System constraints for the process + pointer to XC_PROCESS_CONSTRAINTS struct + Bit flag:XCPROCESSINFOOPTION_SYSTEMCONSTRAINTS + + pCommandLine + The command line for the process + + pProcessStatistics + Pointer to the XC_PROCESS_STATISTICS struct + Bit flag:XCPROCESSINFOOPTION_PROCESSSTAT + +--*/ +typedef struct tagXC_PROCESS_INFO{ + SIZE_T Size; + DWORD Flags; + XCPROCESSSTATE ProcessState; + XCERROR ProcessStatus; + XCEXITCODE ExitCode; + UINT32 Win32Pid; + UINT32 NumberofProcessProperties; + PXC_PROCESSPROPERTY_INFO *ppProperties; + XCDATETIME CurrentPnTime; + XCDATETIME CreatedTime; + XCDATETIME BeginExecutionTime; + XCDATETIME TerminatedTime; + XCDATETIME LastPropertyUpdateTime; + PXC_PROCESS_CONSTRAINTS pEffectiveProcessConstraints; + PXC_PROCESS_CONSTRAINTS pAppProcessConstraints; + PXC_PROCESS_CONSTRAINTS pSystemProcessConstraints; + PSTR pCommandLine; + PXC_PROCESS_STATISTICS pProcessStatistics; +} XC_PROCESS_INFO, *PXC_PROCESS_INFO; +typedef const XC_PROCESS_INFO* PCXC_PROCESS_INFO; + + + +/*++ + +XC_SETANDGETPROCESSINFO_REQRESULTS + +The 
structure gets returned as a result of a call to the +XcPnSetAndGetProcessInfo API. +It contains the results that match the ProcessInfoFetchOptions +and the PropertyFetchTemplate passed to the API via the +XC_SETANDGETPROCESSINFO_REQINPUT struct + +Fields: + + Size + sizeof(XC_SETANDGETPROCESSINFO_REQRESULTS) + + pProcessInfo + The process info that has information about + all the properties for which information + was asked to be retrieved (using the + PropertyFetchTemplate). It also has all + the information that was asked to be + retrieved using the ProcessInfoFetchOptions. + + NumberOfPropertyVersions + The number of property versions in the + pPropertyVersions array + + pPropertyVersions + Pointer to array of property versions. + Note: The indexes of version numbers in the + pPropertyVersions array correspond 1:1 with the + ppPropertiesToSet array in the + XC_SETANDGETPROCESSINFO_REQINPUT that gets + passed to the XcPnSetAndGetProcessInfo() API. +--*/ +typedef struct tagXC_SETANDGETPROCESSINFO_REQRESULTS{ + DWORD Size; + PXC_PROCESS_INFO pProcessInfo; + UINT32 NumberOfPropertyVersions; +#if __midl + [size_is(NumberOfPropertyVersions)] +#endif + UINT64* pPropertyVersions; +} XC_SETANDGETPROCESSINFO_REQRESULTS, + *PXC_SETANDGETPROCESSINFO_REQRESULTS; + +typedef +const XC_SETANDGETPROCESSINFO_REQRESULTS* + PCXC_SETANDGETPROCESSINFO_REQRESULTS; + + + +/*++ + +XCPROCESSFILEHANDLE +A handle to represent an open XCompute Process File. 
+This is used in the XCompute Process File API, which gives +the ability to read remote files written by a XComputeProcess +into its working directory + +Fields: +--*/ +typedef struct tagXCPROCESSFILEHANDLE +{ + ULONG_PTR Unused; +} *XCPROCESSFILEHANDLE, **PXCPROCESSFILEHANDLE; + +const XCPROCESSFILEHANDLE INVALID_XCPROCESSFILEHANDLE = NULL; + + + +/* File offset value for XCompute files */ +typedef UINT64 XCPROCESSFILEPOSITION, *PXCPROCESSFILEPOSITION; + + + +/* + +Various XcGetProcessFileSize options + + XCREFRESH_AGGRESSIVE (default) + - visit server to find out latest known length + + XCREFRESH_PASSIVE + - return length from local cache if available otherwise + visit server to find out latest known length + + XCREFRESH_FROM_CACHE + - return length from local cache + fail if not available + +*/ +#define XCREFRESH_AGGRESSIVE 0x10000000u +#define XCREFRESH_PASSIVE 0x20000000u +#define XCREFRESH_FROM_CACHE 0x30000000u + + + +#pragma warning( pop ) + +#if defined(__cplusplus) +} +#endif diff --git a/xcompute_native/inc/auto_any.h b/xcompute_native/inc/auto_any.h new file mode 100644 index 0000000..01c79f4 --- /dev/null +++ b/xcompute_native/inc/auto_any.h @@ -0,0 +1,331 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +//+--------------------------------------------------------------------------- +// +// File: auto_any.h +// +// Contents: automatic resource management, a-la std::auto_ptr +// +// Classes: auto_any<> and various typedefs +// +// Functions: get +// reset +// release +// valid +// address +// +//---------------------------------------------------------------------------- + +#ifndef AUTO_ANY +#define AUTO_ANY +#include +#include "smart_any_fwd.h" + +#pragma warning(push) + +// 4284 warning for operator-> returning non-pointer; +// compiler issues it even if -> is not used for the specific instance +#pragma warning(disable: 4284) + +namespace detail +{ + // friend function definitions go in auto_any_helper + template + struct auto_any_helper; +} + +// proxy reference for auto_any copying +template +struct auto_any_ref +{ + // construct from compatible auto_any + auto_any_ref( auto_any & that ) + : m_that( that ) + { + } + + // reference to constructor argument + auto_any & m_that; + +private: + auto_any_ref * operator=( auto_any_ref const & ); +}; + +// wrap a resource to enforce strict ownership and ensure proper cleanup +template +class auto_any +{ + typedef detail::safe_types safe_types; + + // disallow comparison of auto_any's + bool operator==( detail::safe_bool ) const; + bool operator!=( detail::safe_bool ) const; + +public: + typedef typename detail::holder::type element_type; + typedef close_policy close_policy_type; + typedef typename safe_types::pointer_type pointer_type; + typedef typename safe_types::reference_type reference_type; + + // Fix-up the invalid_value type on older compilers + typedef typename detail::fixup_invalid_value:: + template rebind::type invalid_value_type; + + friend struct detail::auto_any_helper; + + // construct from object pointer + explicit auto_any( T t = invalid_value_type() ) + : m_t( t ) + { + } + + // construct by assuming pointer from right auto_any + auto_any( auto_any & right ) + : m_t( release( right ) ) + { + } 
+ + // construct by assuming pointer from right auto_any_ref + auto_any( auto_any_ref right ) + : m_t( release( right.m_that ) ) + { + } + + // convert to compatible auto_any_ref + operator auto_any_ref() + { + return auto_any_ref( *this ); + } + + // assign compatible right + auto_any & operator=( + auto_any & right ) + { + reset( *this, release( right ) ); + return *this; + } + + // assign compatible right.ref + auto_any & operator=( + auto_any_ref right ) + { + reset( *this, release( right.m_that ) ); + return *this; + } + + // destroy the object + ~auto_any() + { + if( valid() ) + { + close_policy::close( m_t ); + } + } + + // return pointer to class object (assume pointer) + pointer_type operator->() const + { + #ifdef SMART_ANY_PTS + // You better not be applying operator-> to a handle! + static detail::static_assert::value> const cannot_dereference_a_handle; + #endif + assert( valid() ); + return safe_types::to_pointer( m_t ); + } + + // for use when auto_any appears in a conditional + operator detail::safe_bool() const + { + return valid() ? detail::safe_true : detail::safe_false; + } + + // for use when auto_any appears in a conditional + bool operator!() const + { + return ! valid(); + } + + #ifdef SMART_ANY_PTS + // if this auto_any is managing an array, we can use operator[] to index it + typename detail::deref::type operator[]( int i ) const + { + static detail::static_assert::value> const cannot_dereference_a_handle; + static detail::static_assert::value> const accessed_like_an_array_but_not_deleted_like_an_array; + assert( valid() ); + return m_t[ i ]; + } + + // unary operator* lets you write code like: + // auto_any pfoo( new foo ); + // foo & f = *pfoo; + reference_type operator*() const + { + static detail::static_assert::value> const cannot_dereference_a_handle; + assert( valid() ); + return safe_types::to_reference( m_t ); + } + #endif + +private: + + bool valid() const + { + // see if the managed resource is in the invalid state. 
+ return m_t != static_cast( invalid_value_type() ); + } + + // the wrapped object + element_type m_t; +}; + +namespace detail +{ + // friend function definitions go in auto_any_helper + template + struct auto_any_helper + { + // return wrapped pointer + static T get( auto_any const & t ) + { + return t.m_t; + } + + // return wrapped pointer and give up ownership + static T release( auto_any & t ) + { + // Fix-up the invalid_value type on older compilers + typedef typename detail::fixup_invalid_value:: + template rebind::type invalid_value_type; + + T tmpT = t.m_t; + t.m_t = static_cast( invalid_value_type() ); + return tmpT; + } + + // destroy designated object and store new pointer + static void reset( auto_any & t, T newT ) + { + if( t.m_t != newT ) + { + if( t.valid() ) + { + close_policy::close( t.m_t ); + } + t.m_t = newT; + } + } + + typedef typename auto_any::element_type element_type; + + // return the address of the wrapped pointer + static element_type* address( auto_any & t ) + { + // check to make sure the wrapped object is in the invalid state + assert( !t.valid() ); + return address_of( t.m_t ); + } + }; +} + +// return wrapped resource +template +inline T get( auto_any const & t ) +{ + return detail::auto_any_helper::get( t ); +} + +// return true if the auto_any contains a currently valid resource +template +inline bool valid( auto_any const & t ) +{ + return t; +} + +// return wrapped resource and give up ownership +template +inline T release( auto_any & t ) +{ + return detail::auto_any_helper::release( t ); +} + +// destroy designated object and store new resource +template +inline void reset( auto_any & t ) +{ + typedef typename detail::fixup_invalid_value:: + template rebind::type invalid_value_type; + detail::auto_any_helper::reset( t, invalid_value_type() ); +} + +// destroy designated object and store new resource +template +inline void reset( auto_any & t, U newT ) +{ + detail::auto_any_helper::reset( t, newT ); +} + +// swap the contents 
of two shared_any objects +template +void swap( auto_any & left, + auto_any & right ) +{ + auto_any tmp( left ); + left = right; + right = tmp; +} + +// return the address of the wrapped resource +// WARNING: this will assert if the value of the resource is +// anything other than invalid_value. +template +inline typename auto_any::element_type* + address( auto_any & t ) +{ + return detail::auto_any_helper::address( t ); +} + +#pragma warning(pop) + +#endif + +// This causes the auto_* typedefs to be defined +DECLARE_SMART_ANY_TYPEDEFS(auto) + +#if defined(_OBJBASE_H_) & !defined(AUTO_ANY_CO_INIT) +# define AUTO_ANY_CO_INIT + typedef auto_any auto_co_close; + + // Helper class for balancing calls to CoInitialize and CoUninitialize + struct auto_co_init + { + explicit auto_co_init( DWORD dwCoInit = COINIT_APARTMENTTHREADED ) + : m_hr( smart_co_init_helper( dwCoInit ) ) + { + } + HRESULT hresult() const + { + return get(m_hr); + } + auto_co_close const m_hr; + private: + auto_co_init & operator=( auto_co_init const & ); + }; +#endif diff --git a/xcompute_native/inc/scoped_any.h b/xcompute_native/inc/scoped_any.h new file mode 100644 index 0000000..7d0813b --- /dev/null +++ b/xcompute_native/inc/scoped_any.h @@ -0,0 +1,273 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +//+--------------------------------------------------------------------------- +// +// File: scoped_any.h +// +// Contents: automatic resource management, a-la std::scoped_ptr +// +// Classes: scoped_any<> and various typedefs +// +// Functions: get +// reset +// release +// valid +// address +// +//---------------------------------------------------------------------------- + +#ifndef SCOPED_ANY +#define SCOPED_ANY +#include +#include "smart_any_fwd.h" + +#pragma warning(push) + +// 4284 warning for operator-> returning non-pointer; +// compiler issues it even if -> is not used for the specific instance +#pragma warning(disable: 4284) + +namespace detail +{ + // friend function definitions go in scoped_any_helper + template + struct scoped_any_helper; +} + +// wrap a resource to enforce strict ownership and ensure proper cleanup +template +class scoped_any +{ + // disallow copy and assignment + scoped_any( scoped_any const & ); + scoped_any & operator=( + scoped_any const & ); + + // disallow comparison of scoped_any's + bool operator==( detail::safe_bool ) const; + bool operator!=( detail::safe_bool ) const; + + typedef detail::safe_types safe_types; + +public: + typedef typename detail::holder::type element_type; + typedef close_policy close_policy_type; + typedef typename safe_types::pointer_type pointer_type; + typedef typename safe_types::reference_type reference_type; + + // Fix-up the invalid_value type on older compilers + typedef typename detail::fixup_invalid_value:: + template rebind::type invalid_value_type; + + friend struct detail::scoped_any_helper; + + // construct from object pointer + explicit scoped_any( T t = invalid_value_type() ) + : m_t( t ) + { + } + + // destroy the object + ~scoped_any() + { + if( valid() ) + { + close_policy::close( m_t ); + } + } + + // return pointer to class object (assume pointer) + pointer_type operator->() const + { + #ifdef SMART_ANY_PTS + // You better not be applying operator-> to a handle! 
+ static detail::static_assert::value> const cannot_dereference_a_handle; + #endif + assert( valid() ); + return safe_types::to_pointer( m_t ); + } + + // for use when scoped_any appears in a conditional + operator detail::safe_bool() const + { + return valid() ? detail::safe_true : detail::safe_false; + } + + // for use when scoped_any appears in a conditional + bool operator!() const + { + return ! valid(); + } + + #ifdef SMART_ANY_PTS + // if this scoped_any is managing an array, we can use operator[] to index it + typename detail::deref::type operator[]( int i ) const + { + static detail::static_assert::value> const cannot_dereference_a_handle; + static detail::static_assert::value> const accessed_like_an_array_but_not_deleted_like_an_array; + assert( valid() ); + return m_t[ i ]; + } + + // unary operator* lets you write code like: + // scoped_any pfoo( new foo ); + // foo & f = *pfoo; + reference_type operator*() const + { + static detail::static_assert::value> const cannot_dereference_a_handle; + assert( valid() ); + return safe_types::to_reference( m_t ); + } + #endif + +private: + + bool valid() const + { + // see if the managed resource is in the invalid state. 
+ return m_t != static_cast( invalid_value_type() ); + } + + // the wrapped object + element_type m_t; +}; + +namespace detail +{ + // friend function definitions go in scoped_any_helper + template + struct scoped_any_helper + { + // return wrapped pointer + static T get( scoped_any const & t ) + { + return t.m_t; + } + + // return wrapped pointer and give up ownership + static T release( scoped_any & t ) + { + // Fix-up the invalid_value type on older compilers + typedef typename detail::fixup_invalid_value:: + template rebind::type invalid_value_type; + + T tmpT = t.m_t; + t.m_t = static_cast( invalid_value_type() ); + return tmpT; + } + + // destroy designated object and store new pointer + static void reset( scoped_any & t, T newT ) + { + if( t.m_t != newT ) + { + if( t.valid() ) + { + close_policy::close( t.m_t ); + } + t.m_t = newT; + } + } + + typedef typename scoped_any::element_type element_type; + + // return the address of the wrapped pointer + static element_type* address( scoped_any & t ) + { + // check to make sure the wrapped object is in the invalid state + assert( !t.valid() ); + return address_of( t.m_t ); + } + }; +} + +// return wrapped resource +template +inline T get( scoped_any const & t ) +{ + return detail::scoped_any_helper::get( t ); +} + +// return true if the scoped_any contains a currently valid resource +template +inline bool valid( scoped_any const & t ) +{ + return t; +} + +// return wrapped resource and give up ownership +template +inline T release( scoped_any & t ) +{ + return detail::scoped_any_helper::release( t ); +} + +// destroy designated object and store new resource +template +inline void reset( scoped_any & t ) +{ + typedef typename detail::fixup_invalid_value:: + template rebind::type invalid_value_type; + detail::scoped_any_helper::reset( t, invalid_value_type() ); +} + +// destroy designated object and store new resource +template +inline void reset( scoped_any & t, U newT ) +{ + detail::scoped_any_helper::reset( t, 
newT ); +} + +// return the address of the wrapped resource +// WARNING: this will assert if the value of the resource is +// anything other than invalid_value. +template +inline typename scoped_any::element_type* + address( scoped_any & t ) +{ + return detail::scoped_any_helper::address( t ); +} + +#pragma warning(pop) + +#endif + +// This causes the scoped_* typedefs to be defined +DECLARE_SMART_ANY_TYPEDEFS(scoped) + +#if defined(_OBJBASE_H_) & !defined(SCOPED_ANY_CO_INIT) +# define SCOPED_ANY_CO_INIT + typedef scoped_any scoped_co_close; + + // Helper class for balancing calls to CoInitialize and CoUninitialize + struct scoped_co_init + { + explicit scoped_co_init( DWORD dwCoInit = COINIT_APARTMENTTHREADED ) + : m_hr( smart_co_init_helper( dwCoInit ) ) + { + } + HRESULT hresult() const + { + return get(m_hr); + } + scoped_co_close const m_hr; + }; +#endif diff --git a/xcompute_native/inc/smart_any_fwd.h b/xcompute_native/inc/smart_any_fwd.h new file mode 100644 index 0000000..cc8daa6 --- /dev/null +++ b/xcompute_native/inc/smart_any_fwd.h @@ -0,0 +1,1079 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. 
+ +*/ + +//+--------------------------------------------------------------------------- +// +// File: smart_any_fwd.h +// +// Contents: automatic resource management +// +// Classes: auto_any, scoped_any and shared_any +// +// Functions: get +// reset +// release +// valid +// address +// +//---------------------------------------------------------------------------- + +#ifndef SMART_ANY_FWD +#define SMART_ANY_FWD + +#ifdef _MANAGED +#pragma warning( push ) +#pragma warning( disable : 4244 ) +#include +#pragma warning( pop ) + +#ifdef __cplusplus_cli +#define MGDHANDLE ^ +#define MGDHANDLECV ^ +#define MGDREFERENCE % +#else +#define MGDHANDLE __gc* +#define MGDHANDLECV const volatile __gc* +#define MGDREFERENCE __gc& +#endif + +#endif + +// Check to see if partial template specialization is available +#if _MSC_VER >= 1310 +#define SMART_ANY_PTS +#endif + +// forward declare some invalid_value policy classes +struct null_t; + +template +struct value_const; + +struct close_release_com; + +// +// TEMPLATE CLASS auto_any +// +template +class auto_any; + +// return wrapped resource +template +T get( auto_any const & t ); + +// return true if the auto_any contains a currently valid resource +template +bool valid( auto_any const & t ); + +// return wrapped resource and give up ownership +template +T release( auto_any & t ); + +// destroy designated object +template +void reset( auto_any & t ); + +// destroy designated object and store new resource +template +void reset( auto_any & t, U newT ); + +// swap the contents of two auto_any objects +template +void swap( auto_any & left, + auto_any & right ); + +// return the address of the wrapped resource +// WARNING: this will assert if the value of the resource is +// anything other than invalid_value. 
+//template +//T* address( auto_any & t ); + +// +// TEMPLATE CLASS shared_any +// +template +class shared_any; + +// return wrapped resource +template +T get( shared_any const & t ); + +// return true if the auto_any contains a currently valid resource +template +bool valid( shared_any const & t ); + +// destroy designated object +template +void reset( shared_any & t ); + +// destroy designated object and store new resource +template +void reset( shared_any & t, U newT ); + +// swap the contents of two shared_any objects +template +void swap( shared_any & left, + shared_any & right ); + + +// +// TEMPLATE CLASS scoped_any +// +template +class scoped_any; + +// return wrapped resource +template +T get( scoped_any const & t ); + +// return true if the auto_any contains a currently valid resource +template +bool valid( scoped_any const & t ); + +// return wrapped resource and give up ownership +template +T release( scoped_any & t ); + +// destroy designated object +template +void reset( scoped_any & t ); + +// destroy designated object and store new resource +template +void reset( scoped_any & t, U newT ); + +// return the address of the wrapped resource +// WARNING: this will assert if the value of the resource is +// anything other than invalid_value. +//template +//T* address( scoped_any & t ); + +// close policy for objects allocated with new +struct close_delete; + +namespace detail +{ + typedef char (&yes)[1]; + typedef char (&no) [2]; + + struct dummy_struct + { + void dummy_method() {} + }; + + typedef void (dummy_struct::*safe_bool)(); + safe_bool const safe_true = &dummy_struct::dummy_method; + safe_bool const safe_false = 0; + + // Because of older compilers, we can't always use + // null_t when we would like to. 
+ template + struct fixup_invalid_value + { + template struct rebind { typedef invalid_value type; }; + }; + + // for compile-time assertions + template + struct static_assert; + + template<> + struct static_assert + { + static_assert() {} + }; + + template + struct static_init + { + static T const value; + }; + + template + T const static_init::value = T(); + + template + struct null_helper // unmanaged + { + template + struct inner + { + static T const get() + { + return static_init::value; + } + }; + }; + + template<> + struct null_helper // managed + { + template + struct inner + { + static T const get() + { + return 0; + } + }; + }; + + template + struct select_helper + { + template + struct inner { typedef T type; }; + }; + + template<> + struct select_helper + { + template + struct inner { typedef U type; }; + }; + + template + struct select + { + typedef typename select_helper::template inner::type type; + }; + + + template< bool > + struct holder_helper + { + template + struct inner + { + typedef T type; + }; + }; + + template< typename T > + struct remove_ref + { + typedef T type; + }; + + #ifdef SMART_ANY_PTS + template< typename T > + struct remove_ref + { + typedef T type; + }; + #endif + + template + T* address_of( T & v ) + { + return reinterpret_cast( + &const_cast( + reinterpret_cast(v))); + } + + #ifndef _MANAGED + + template + struct is_managed + { + static bool const value = false; + }; + + #else + + struct managed_convertible + { + managed_convertible( System::Object MGDHANDLECV ); + managed_convertible( System::Enum MGDHANDLECV ); + managed_convertible( System::ValueType MGDHANDLECV ); + managed_convertible( System::Delegate MGDHANDLECV ); + }; + + template + struct is_managed + { + private: + static yes check( managed_convertible ); + static no __cdecl check( ... 
); + static typename remove_ref::type & make(); + public: + static bool const value = sizeof( yes ) == sizeof( check( make() ) ); + }; + + #ifdef SMART_ANY_PTS + template + struct is_managed + { + static bool const value = true; + }; + template + struct is_managed + { + static bool const value = true; + }; + template + struct is_managed + { + static bool const value = is_managed::value; + }; + template + struct is_managed + { + static bool const value = is_managed::value; + }; + template + struct is_managed + { + static bool const value = is_managed::value; + }; + #endif +#ifndef __cplusplus_cli + template<> + struct is_managed + { + static bool const value = true; + }; + template<> + struct is_managed + { + static bool const value = true; + }; + template<> + struct is_managed + { + static bool const value = true; + }; + template<> + struct is_managed + { + static bool const value = true; + }; +#endif + + template<> + struct holder_helper + { + template + struct inner + { + typedef gcroot type; + }; + }; + #endif + + template + struct holder + { + typedef typename holder_helper::value>::template inner::type type; + }; + + template + struct is_delete + { + static bool const value = false; + }; + + template<> + struct is_delete + { + static bool const value = true; + }; + + // dummy type, don't define + struct smart_any_cannot_dereference; + + // For use in implementing unary operator* + template + struct deref + { + typedef smart_any_cannot_dereference type; // will cause a compile error by default + }; + + #ifndef SMART_ANY_PTS + + // Old compiler needs extra help + template<> + struct fixup_invalid_value + { + template struct rebind { typedef value_const type; }; + }; + + #else + + template + struct same_type + { + static const int value = false; + }; + + template + struct same_type + { + static const int value = true; + }; + + // Handle reference types + template + struct deref + { + typedef typename deref::type type; + }; + + // Partially specialize for pointer 
types + template + struct deref + { + typedef T& type; // The result of dereferencing a T* + }; + + // Partially specialize for pointer types + template + struct deref + { + typedef typename deref::type type; // The result of dereferencing a T* + }; + + // Partially specialize for pointer types + template + struct deref + { + typedef typename deref::type type; // The result of dereferencing a T* + }; + + // Partially specialize for pointer types + template + struct deref + { + typedef typename deref::type type; // The result of dereferencing a T* + }; + + // Fully specialize for void* + template<> + struct deref + { + typedef smart_any_cannot_dereference type; // cannot dereference a void* + }; + + // Fully specialize for void const* + template<> + struct deref + { + typedef smart_any_cannot_dereference type; // cannot dereference a void* + }; + + // Fully specialize for void volatile* + template<> + struct deref + { + typedef smart_any_cannot_dereference type; // cannot dereference a void* + }; + + // Fully specialize for void const volatile* + template<> + struct deref + { + typedef smart_any_cannot_dereference type; // cannot dereference a void* + }; + + #ifdef _MANAGED + // Handle reference types + template + struct deref + { + typedef typename deref::type type; + }; + + // Partially specialize for pointer types + template + struct deref + { + typedef T MGDREFERENCE type; // The result of dereferencing a T MGDHANDLE + }; + + // Partially specialize for pointer types + template + struct deref + { + typedef typename deref::type type; // The result of dereferencing a T MGDHANDLE + }; + + // Partially specialize for pointer types + template + struct deref + { + typedef typename deref::type type; // The result of dereferencing a T MGDHANDLE + }; + + // Partially specialize for pointer types + template + struct deref + { + typedef typename deref::type type; // The result of dereferencing a T MGDHANDLE + }; + +#ifndef __cplusplus_cli + // Fully specialize for void* + 
template<> + struct deref + { + typedef smart_any_cannot_dereference type; // cannot dereference a System::Void MGDHANDLE + }; + + // Fully specialize for void const* + template<> + struct deref + { + typedef smart_any_cannot_dereference type; // cannot dereference a System::Void MGDHANDLE + }; + + // Fully specialize for void volatile* + template<> + struct deref + { + typedef smart_any_cannot_dereference type; // cannot dereference a System::Void MGDHANDLE + }; + + // Fully specialize for void const volatile* + template<> + struct deref + { + typedef smart_any_cannot_dereference type; // cannot dereference a System::Void MGDHANDLE + }; +#endif // __cplusplus_cli + #endif + + // The DECLARE_HANDLE macro in winnt.h defines a handle to be a pointer + // to a struct containing one member named "unused" of type int. We can + // use that information to make auto_any safer by disallowing actions like + // dereferencing a handle or calling delete on a handle. + template + struct has_unused + { + private: + template struct wrap_t; + template static yes check( wrap_t* ); + template static no __cdecl check( ... ); + public: + static bool const value = ( sizeof(check(0)) == sizeof(yes) ); + }; + + template + struct is_handle_helper + { + static bool const value = ( sizeof(T)==sizeof(int) && has_unused::value ); + }; + + #ifdef _MANAGED + template + struct is_handle_helper + { + static bool const value = false; + }; + #endif + + template<> + struct is_handle_helper + { + static bool const value = false; + }; + + // used to see whether a given type T is a handle type or not. 
+ template + struct is_handle + { + private: + typedef typename remove_ref::type>::type deref_t; + public: + static bool const value = + ( same_type::value || is_handle_helper::value ); + }; + #endif + + template + struct safe_types + { + typedef T pointer_type; + typedef typename deref::type reference_type; + + static pointer_type to_pointer( T t ) + { + return t; + } + static reference_type to_reference( T t ) + { + return *t; + } + }; + + #ifdef SMART_ANY_PTS + template + class no_addref_release : public T + { + unsigned long __stdcall AddRef(); + unsigned long __stdcall Release(); + }; + + // shouldn't be able to call AddRef or Release + // through a smart COM wrapper + template + struct safe_types + { + typedef no_addref_release* pointer_type; + typedef no_addref_release& reference_type; + + static pointer_type to_pointer( T* t ) + { + return static_cast( t ); + } + static reference_type to_reference( T* t ) + { + return *static_cast( t ); + } + }; + #endif +} + +// a generic close policy that uses a ptr to a function +template +struct close_fun +{ + template + static void close( T t ) + { + Pfn( t ); + } +}; + +// free an object allocated with new by calling delete +struct close_delete +{ + template + static void close( T * p ) + { + // This will fail only if T is an incomplete type. + static detail::static_assert<0 != sizeof( T )> const cannot_delete_an_incomplete_type; + + #ifdef SMART_ANY_PTS + // This checks to make sure we're not calling delete on a HANDLE + static detail::static_assert::value> const cannot_delete_a_handle; + #endif + + delete p; + } + + #ifdef _MANAGED + template + static void close( gcroot const & p ) + { + delete static_cast( p ); + } + #endif +}; + +// free an array allocated with new[] by calling delete[] +struct close_delete_array +{ + template + static void close( T * p ) + { + // This will fail only if T is an incomplete type. 
// Close policy for CRT FILE* streams.
//
// stdin, stdout and stderr belong to the runtime and are never closed.
// A null pointer is ignored as well: passing NULL to fclose() is undefined
// behaviour, so the policy must tolerate an empty wrapper.
struct close_file_ptr
{
    static void close( FILE * pfile )
    {
        if( pfile == NULL )
        {
            return; // nothing to release
        }

        if( pfile == stdin || pfile == stdout || pfile == stderr )
        {
            return; // standard streams are not ours to close
        }

        fclose( pfile );
    }
};
endif + + typedef close_fun close_handle; + typedef close_fun close_find; + typedef close_fun close_library; + typedef close_fun close_regkey; + typedef close_fun close_file_view; + typedef close_fun close_hicon; + typedef close_fun close_hgdiobj; + typedef close_fun close_haccel; + typedef close_fun close_hdc; + typedef close_fun close_hmenu; + typedef close_fun close_hcursor; + typedef close_fun close_window; + typedef close_fun close_heap_destroy; + typedef close_fun close_local_free; + typedef close_fun close_hdesk; + typedef close_fun close_hhook; + typedef close_fun close_hwinsta; + typedef close_fun close_event_source; + typedef close_fun close_global_free; + + typedef value_const invalid_handle_t; +#endif + +#ifdef _OLEAUTO_H_ + typedef close_fun close_bstr; +#endif + +#ifdef __MSGQUEUE_H__ + typedef close_fun close_msg_queue; +#endif + +#if defined(_WININET_) | defined(_DUBINET_) + typedef close_fun close_hinternet; +#endif + +#ifdef _RAS_H_ + typedef close_fun close_hrasconn; +#endif + +#if defined(__RPCDCE_H__) & !defined(SMART_ANY_RPC) +# define SMART_ANY_RPC + // for releaseing an rpc binding + struct close_rpc_binding + { + static void close( RPC_BINDING_HANDLE & h ) + { + ::RpcBindingFree( &h ); + } + }; + // for releaseing an rpc binding vector + struct close_rpc_vector + { + static void close( RPC_BINDING_VECTOR __RPC_FAR * & p ) + { + ::RpcBindingVectorFree( &p ); + } + }; + // for releasing a RPC string + struct close_rpc_string + { + static void close( unsigned char __RPC_FAR * & p ) + { + ::RpcStringFreeA(&p); + } + static void close( unsigned short __RPC_FAR * & p ) + { + ::RpcStringFreeW(&p); + } + }; +#endif + +#ifdef _WINSVC_ + typedef close_fun close_service; + typedef close_fun unlock_service; +#endif + +#ifdef _WINSOCKAPI_ + typedef int (__stdcall *pfn_closock_t)( SOCKET ); + typedef close_fun(&closesocket)> close_socket; + typedef value_const invalid_socket_t; +#endif + +#ifdef _OBJBASE_H_ + // For use when releasing memory allocated 
with CoTaskMemAlloc + typedef close_fun close_co_task_free; +#endif + + +// +// Below are useful smart typedefs for some common Windows/CRT resource types. +// + +#undef DECLARE_SMART_ANY_TYPEDEFS_STDIO +#undef DECLARE_SMART_ANY_TYPEDEFS_WINDOWS +#undef DECLARE_SMART_ANY_TYPEDEFS_OLEAUTO +#undef DECLARE_SMART_ANY_TYPEDEFS_MSGQUEUE +#undef DECLARE_SMART_ANY_TYPEDEFS_WININET +#undef DECLARE_SMART_ANY_TYPEDEFS_RAS +#undef DECLARE_SMART_ANY_TYPEDEFS_RPCDCE +#undef DECLARE_SMART_ANY_TYPEDEFS_WINSVC +#undef DECLARE_SMART_ANY_TYPEDEFS_WINSOCKAPI +#undef DECLARE_SMART_ANY_TYPEDEFS_OBJBASE + +#define DECLARE_SMART_ANY_TYPEDEFS_STDIO(prefix) +#define DECLARE_SMART_ANY_TYPEDEFS_WINDOWS(prefix) +#define DECLARE_SMART_ANY_TYPEDEFS_OLEAUTO(prefix) +#define DECLARE_SMART_ANY_TYPEDEFS_MSGQUEUE(prefix) +#define DECLARE_SMART_ANY_TYPEDEFS_WININET(prefix) +#define DECLARE_SMART_ANY_TYPEDEFS_RAS(prefix) +#define DECLARE_SMART_ANY_TYPEDEFS_RPCDCE(prefix) +#define DECLARE_SMART_ANY_TYPEDEFS_WINSVC(prefix) +#define DECLARE_SMART_ANY_TYPEDEFS_WINSOCKAPI(prefix) +#define DECLARE_SMART_ANY_TYPEDEFS_OBJBASE(prefix) + +#ifdef _INC_STDIO +# undef DECLARE_SMART_ANY_TYPEDEFS_STDIO +# define DECLARE_SMART_ANY_TYPEDEFS_STDIO(prefix) \ + typedef prefix ## _any prefix ## _file_ptr; +#endif + +#ifdef _WINDOWS_ +# undef DECLARE_SMART_ANY_TYPEDEFS_WINDOWS +# define DECLARE_SMART_ANY_TYPEDEFS_WINDOWS(prefix) \ + typedef prefix ## _any prefix ## _hkey; \ + typedef prefix ## _any prefix ## _hfind; \ + typedef prefix ## _any prefix ## _hfile; \ + typedef prefix ## _any prefix ## _communications_device; \ + typedef prefix ## _any prefix ## _console_input; \ + typedef prefix ## _any prefix ## _console_input_buffer; \ + typedef prefix ## _any prefix ## _console_output; \ + typedef prefix ## _any prefix ## _mailslot; \ + typedef prefix ## _any prefix ## _pipe; \ + typedef prefix ## _any prefix ## _handle; \ + typedef prefix ## _any prefix ## _access_token; \ + typedef prefix ## _any prefix ## _event; \ + 
typedef prefix ## _any prefix ## _file_mapping; \ + typedef prefix ## _any prefix ## _job; \ + typedef prefix ## _any prefix ## _mutex; \ + typedef prefix ## _any prefix ## _process; \ + typedef prefix ## _any prefix ## _semaphore; \ + typedef prefix ## _any prefix ## _thread; \ + typedef prefix ## _any prefix ## _timer; \ + typedef prefix ## _any prefix ## _completion_port; \ + typedef prefix ## _any prefix ## _hdc; \ + typedef prefix ## _any prefix ## _hicon; \ + typedef prefix ## _any prefix ## _hmenu; \ + typedef prefix ## _any prefix ## _hcursor; \ + typedef prefix ## _any prefix ## _hpen; \ + typedef prefix ## _any prefix ## _hrgn; \ + typedef prefix ## _any prefix ## _hfont; \ + typedef prefix ## _any prefix ## _hbrush; \ + typedef prefix ## _any prefix ## _hbitmap; \ + typedef prefix ## _any prefix ## _hpalette; \ + typedef prefix ## _any prefix ## _haccel; \ + typedef prefix ## _any prefix ## _window; \ + typedef prefix ## _any prefix ## _library; \ + typedef prefix ## _any prefix ## _file_view; \ + typedef prefix ## _any prefix ## _virtual_ptr; \ + typedef prefix ## _any prefix ## _heap; \ + typedef prefix ## _any prefix ## _hlocal; \ + typedef prefix ## _any prefix ## _hdesk; \ + typedef prefix ## _any prefix ## _hhook; \ + typedef prefix ## _any prefix ## _hwinsta; \ + typedef prefix ## _any prefix ## _event_source; \ + typedef prefix ## _any prefix ## _hglobal; +#endif + +// +// Define some other useful typedefs +// + +#ifdef _OLEAUTO_H_ +# undef DECLARE_SMART_ANY_TYPEDEFS_OLEAUTO +# define DECLARE_SMART_ANY_TYPEDEFS_OLEAUTO(prefix) \ + typedef prefix ## _any prefix ## _bstr; +#endif + +#ifdef __MSGQUEUE_H__ +# undef DECLARE_SMART_ANY_TYPEDEFS_MSGQUEUE +# define DECLARE_SMART_ANY_TYPEDEFS_MSGQUEUE(prefix) \ + typedef prefix ## _any prefix ## _msg_queue; +#endif + +#if defined(_WININET_) | defined(_DUBINET_) +# undef DECLARE_SMART_ANY_TYPEDEFS_WININET +# define DECLARE_SMART_ANY_TYPEDEFS_WININET(prefix) \ + typedef prefix ## _any prefix ## _hinternet; 
+#endif + +#ifdef _RAS_H_ +# undef DECLARE_SMART_ANY_TYPEDEFS_RAS +# define DECLARE_SMART_ANY_TYPEDEFS_RAS(prefix) \ + typedef prefix ## _any prefix ## _hrasconn +#endif + +#ifdef __RPCDCE_H__ +# undef DECLARE_SMART_ANY_TYPEDEFS_RPCDCE +# ifdef UNICODE +# define DECLARE_SMART_ANY_TYPEDEFS_RPCDCE(prefix) \ + typedef prefix ## _any prefix ## _rpc_binding; \ + typedef prefix ## _any prefix ## _rpc_binding_vector; \ + typedef prefix ## _any prefix ## _rpc_string_A; \ + typedef prefix ## _any prefix ## _rpc_string_W; \ + typedef prefix ## _rpc_string_W prefix ## _rpc_string; +# else +# define DECLARE_SMART_ANY_TYPEDEFS_RPCDCE(prefix) \ + typedef prefix ## _any prefix ## _rpc_binding; \ + typedef prefix ## _any prefix ## _rpc_binding_vector; \ + typedef prefix ## _any prefix ## _rpc_string_A; \ + typedef prefix ## _any prefix ## _rpc_string_W; \ + typedef prefix ## _rpc_string_A prefix ## _rpc_string; +# endif +#endif + +#ifdef _WINSVC_ +# undef DECLARE_SMART_ANY_TYPEDEFS_WINSVC +# define DECLARE_SMART_ANY_TYPEDEFS_WINSVC(prefix) \ + typedef prefix ## _any prefix ## _service; \ + typedef prefix ## _any prefix ## _service_lock; +#endif + +#ifdef _WINSOCKAPI_ +# undef DECLARE_SMART_ANY_TYPEDEFS_WINSOCKAPI +# define DECLARE_SMART_ANY_TYPEDEFS_WINSOCKAPI(prefix) \ + typedef prefix ## _any prefix ## _socket; +#endif + +#if defined(_OBJBASE_H_) & !defined(SMART_ANY_CO_INIT) +# define SMART_ANY_CO_INIT + inline HRESULT smart_co_init_helper( DWORD dwCoInit ) + { + (void) dwCoInit; +# if (_WIN32_WINNT >= 0x0400 ) | defined(_WIN32_DCOM) + return ::CoInitializeEx(0,dwCoInit); +# else + return ::CoInitialize(0); +# endif + } + inline void smart_co_uninit_helper( HRESULT hr ) + { + if (SUCCEEDED(hr)) + ::CoUninitialize(); + } + typedef close_fun close_co; + typedef value_const co_not_init; +# undef DECLARE_SMART_ANY_TYPEDEFS_OBJBASE +# define DECLARE_SMART_ANY_TYPEDEFS_OBJBASE(prefix) \ + typedef prefix ## _any prefix ## _co_task_ptr; +#endif + + +#define 
DECLARE_SMART_ANY_TYPEDEFS(prefix) \ + DECLARE_SMART_ANY_TYPEDEFS_STDIO(prefix) \ + DECLARE_SMART_ANY_TYPEDEFS_WINDOWS(prefix) \ + DECLARE_SMART_ANY_TYPEDEFS_OLEAUTO(prefix) \ + DECLARE_SMART_ANY_TYPEDEFS_MSGQUEUE(prefix) \ + DECLARE_SMART_ANY_TYPEDEFS_WININET(prefix) \ + DECLARE_SMART_ANY_TYPEDEFS_RAS(prefix) \ + DECLARE_SMART_ANY_TYPEDEFS_RPCDCE(prefix) \ + DECLARE_SMART_ANY_TYPEDEFS_WINSVC(prefix) \ + DECLARE_SMART_ANY_TYPEDEFS_WINSOCKAPI(prefix) \ + DECLARE_SMART_ANY_TYPEDEFS_OBJBASE(prefix) diff --git a/xcompute_native/locality.cpp b/xcompute_native/locality.cpp new file mode 100644 index 0000000..114dffb --- /dev/null +++ b/xcompute_native/locality.cpp @@ -0,0 +1,116 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +/*++ +Module Name: + + locality.cpp + +Abstract: + + This module contains the public interface and support routines for + managing locality for xcompute on the HPC scheduler +--*/ +#include "stdafx.h" + + +/*++ + +XcGetNetworkLocalityPathOfProcessNode + +Description: + +This API translates a set of process node IDs into +network locality paths. + +Arguments: + + hSession + Handle to a session associated with + this call + + pProcessNodeId + The Process Node for which the + path is required + + ppNetworkLocalityPath + Returned network locality path for the ProcessNode. 
+ The pNetworkLocalityPath vector should be freed with + XcFreeMemory(ppNetworkLocalityPath) + + pNetworkLocalityParam + The affinity param to be used to get the locality path. + The affinity param lets the user identify the affinity + level relative to the given ProcessNodeId, + which is reflected in the returned ppNetworkLocalityPath. + Thus given a ProcessNodeId, the user might say, + L2Switch as the NetworkLocalityParam, which means + the affinity is to all process nodes under that L2Switch. + + Different affinity params are defined in the + XComputeTypes.h. See Network Locality Params for + more details. + + Return Value: + + if pAsyncInfo is NULL + XCERROR_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetNetworkLocalityPathOfProcessNode( + IN XCSESSIONHANDLE hSession, + IN XCPROCESSNODEID processNodeId, + IN PSTR pNetworkLocalityParam, + OUT PCSTR* ppNetworkLocalityPath +) +{ + // + // + // This is a bit of a hack - we define each machine to be its own + // distinct "pod". Need better locality hints in the HPC scheduler + // to do anything more clever. + // + if ((pNetworkLocalityParam == NULL) || + (lstrcmpiA(pNetworkLocalityParam, XCLOCALITYPARAM_POD) == 0)) { + PCSTR pszNodeName; + XCERROR hr = XcProcessNodeNameFromId(hSession,processNodeId,&pszNodeName); + if (FAILED(hr)) { + return hr; + } + *ppNetworkLocalityPath = ::StrDupA(pszNodeName); + } else { + *ppNetworkLocalityPath = ::StrDupA("E_NOTIMPL"); + } + return S_OK; +} + diff --git a/xcompute_native/node.cpp b/xcompute_native/node.cpp new file mode 100644 index 0000000..a997d21 --- /dev/null +++ b/xcompute_native/node.cpp @@ -0,0 +1,435 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +/*++ +Module Name: + + node.cpp + +Abstract: + + This module contains the public interface and support routines for + the xcompute node functionality on top of the HPC scheduler. +--*/ +#include "stdafx.h" +using namespace System; +using namespace System::Collections::Generic; +using namespace System::Collections::Specialized; +using namespace System::Runtime::InteropServices; + +gcroot g_NodeList = gcnew System::Collections::Specialized::StringCollection(); +CComAutoCriticalSection g_NodeListLock; + +XCPROCESSNODEID +NodeNameToID(System::String ^name) +{ + String ^N = name->ToUpper(); + CComCritSecLock lock(g_NodeListLock); + int index = g_NodeList->IndexOf(N); + if (index == -1) + { + index = g_NodeList->Add(N); + } + return (XCPROCESSNODEID)(index+1); +} + +System::String ^ +NodeIDToName(XCPROCESSNODEID nodeID) +{ + int i = (int)nodeID; + CComCritSecLock lock(g_NodeListLock); + return (String ^)g_NodeList->default[i-1]; +} + +/*++ + +XcGetProcessNodeId API + +Description: + +Gets the process node on which the process has been assigned. +If the process state anything other than XCPROCESSSTATE_ASSIGNEDTOPN +an error is returned. 
+ +Arguments: + + hProcessHandle + Process handle + + pProcessNodeId + Pointer to process node Id + +Return Value: + + XCERROR_OK + The call succeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetProcessNodeId( + IN XCPROCESSHANDLE hProcessHandle, + OUT PXCPROCESSNODEID pProcessNodeId +) +{ + + try + { + if (hProcessHandle == CURRENT_PROCESS_ID) + { + *pProcessNodeId = NodeNameToID(Microsoft::Research::Dryad::AzureUtils::CurrentHostName); + } + else + { + System::String ^nodeName = Microsoft::Research::Dryad::VertexScheduler::GetInstance()->GetAssignedNode((int)hProcessHandle); + if (System::String::IsNullOrEmpty(nodeName)) + { + return E_FAIL; + } + *pProcessNodeId = NodeNameToID(nodeName); + } + } + catch (System::Exception ^e) + { + return System::Runtime::InteropServices::Marshal::GetHRForException(e); + } + return S_OK; +} + +/*++ + +XcEnumerateProcessNodes + + +Description: + +This API enumerates all the process nodes that are controlled +by the Process scheduler and returns an array of processNodeIds + +Arguments: + + hSession + Handle to a session associated with + this call + + pNumNodeIds + Pointer to a int which gets filled with the + number of process Node Ids in the + ppProcessNodeIds array + + ppProcessNodeIds + Pointer to array of processNode Ids. Use the + XcFreeMemory() API to deallocate. + + pAsyncInfo + The async info structure. Its an alias to + the CS_ASYNC_INFO defined in Cosmos.h. If + this parameter is NULL, then function + completes in synchronous manner and error + code is returned as return value. + + If parameter is not NULL then operation is + carried on in asynchronous manner. If + asynchronous operation has been successfully + started then function terminates + immediately with + HRESULT_FROM_WIN32(ERROR_IO_PENDING) return + value. + + Any other return value indicates that it was + impossible to start asynchronous operation. 
+ + Return Value: + + if pAsyncInfo is NULL + XCERROR_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcEnumerateProcessNodes( + IN XCSESSIONHANDLE hSession, + OUT UINT32* pNumNodeIds, + OUT PXCPROCESSNODEID* ppProcessNodeIds, + IN PCXC_ASYNC_INFO pAsyncInfo +) +{ + PASYNC async; + CAPTURE_ASYNC(async); + + if (pNumNodeIds == NULL || ppProcessNodeIds == NULL) + { + return COMPLETE_ASYNC(async, E_INVALIDARG); + } + + array ^nodes =Microsoft::Research::Dryad::VertexScheduler::GetInstance()->EnumerateProcessNodes(); + + if (nodes == nullptr) + { + return COMPLETE_ASYNC(async, E_UNEXPECTED); + } + + // + // Now get the list of allocated nodes. + // + int numNodes = nodes->Length; + + // + // Allocate a buffer large enough to hold all the node IDs + // + XCPROCESSNODEID *nodeIds = (XCPROCESSNODEID *)::LocalAlloc(LMEM_FIXED, numNodes * sizeof(XCPROCESSNODEID)); + if (nodeIds == NULL) + { + return COMPLETE_ASYNC(async, E_OUTOFMEMORY); + } + + int i=0; + for each (String ^node in nodes) + { + nodeIds[i++] = NodeNameToID(node); + } + + *ppProcessNodeIds = nodeIds; + *pNumNodeIds = i; + + return COMPLETE_ASYNC(async, S_OK); +} + + + +/*++ + +XcFetchProcessNodeMetaData + +Description: + +This API fetches the process node related metadata. This +call can result in a call to the Process Scheduler, if the +metadata for a given process node is missing. + +Arguments: + + hSession + Handle to a session associated with + this call + + pProcessNodeIds + Array of IDs of the nodes for which the + metadata is required + + numNodeIds + Number of node ids in the + pProcessNodeIds array + + pAsyncInfo + The async info structure. Its an alias to + the CS_ASYNC_INFO defined in Cosmos.h. 
If + this parameter is NULL, then function + completes in synchronous manner and error + code is returned as return value. + + If parameter is not NULL then operation is + carried on in asynchronous manner. If + asynchronous operation has been successfully + started then function terminates + immediately with + HRESULT_FROM_WIN32(ERROR_IO_PENDING) return + value. + + Any other return value indicates that it was + impossible to start asynchronous operation. + + Return Value: + + if pAsyncInfo is NULL + XCERROR_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcFetchProcessNodeMetaData( + IN XCSESSIONHANDLE hSession, + IN UINT32 numNodeIds, + IN PXCPROCESSNODEID pProcessNodeIds, + IN PCXC_ASYNC_INFO pAsyncInfo +) +{ + PASYNC async; + CAPTURE_ASYNC(async); + + return COMPLETE_ASYNC(async, E_NOTIMPL); +} + +/*++ + +XcGetCurrentProcessNodeId API + +Description: + +Gets the current Process Node Id. The Process Node Id to +the node name map is maintained internally. + +Arguments: + + hSession + Handle to an XCompute session associated with + this call. + + pProcessNodeId + Pointer to Pointer of the Id of the node + + Return Value: + + CsError_OK + indicates call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetCurrentProcessNodeId( + IN XCSESSIONHANDLE hSession, + OUT PXCPROCESSNODEID pProcessNodeId +) +{ + *pProcessNodeId = NodeNameToID(Microsoft::Research::Dryad::AzureUtils::CurrentHostName); + return S_OK; +} + + + +/*++ + +XcProcessNodeIdFromName API + +Description: + +Gets the Process Node Id for a node given the node name. The +Process Node Id to the node name map is maintained internally. 
+If a node name is not found in the internal map, then a new +entry is created and the corrosponding id is returned back + +Arguments: + + hSession + Handle to an XCompute session associated with this + call. Reserved for future use. Must be NULL. + + + pszProcessNodeName + Name of the process node for which Id is needed + + pProcessNodeId + Pointer to Pointer of the Id of the node + + Return Value: + + CsError_OK indicates call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcProcessNodeIdFromName( + IN XCSESSIONHANDLE hSession, + IN PCSTR pszProcessNodeName, + OUT PXCPROCESSNODEID pProcessNodeId +) +{ + *pProcessNodeId = NodeNameToID(gcnew String(pszProcessNodeName)); + return S_OK; +} + + + +/*++ + +XcProcessNodeNameFromId API + +Description: + +Gets the Process Node name from the given Process Node Id.The +Process Node Id to the node name map is maintained internally. + +Arguments: + + hSession + Handle to an XCompute session associated with this + call. + + + processNodeId + The process Node Id for which the node name + is needed + + ppszProcessNodeName + Name of the process node corrosponding to Id + Note: the returned process node name string + is permanently allocated and will remain + valid for the life of the process. There + is no need to make a copy of this string. + + Return Value: + + CsError_OK indicates call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcProcessNodeNameFromId( + IN XCSESSIONHANDLE hSession, + IN XCPROCESSNODEID processNodeId, + OUT PCSTR* ppszProcessNodeName +) +{ + // + // If we've returned this name before, return the + // same pointer again. 
+ // + String ^s = NodeIDToName(processNodeId); + + *ppszProcessNodeName = (PCSTR)Marshal::StringToHGlobalAnsi(s).ToPointer(); + return S_OK; +} + diff --git a/xcompute_native/path.cpp b/xcompute_native/path.cpp new file mode 100644 index 0000000..5a4381d --- /dev/null +++ b/xcompute_native/path.cpp @@ -0,0 +1,199 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +/*++ +Module Name: + + path.cpp + +Abstract: + + This module contains the public interface and support routines for + mapping between global URIs and local paths for xcompute on the + HPC scheduler + +--*/ +#include "stdafx.h" +using namespace System; +using namespace System::Runtime::InteropServices; + +/*++ + +XcGetProcessUri API + +Description: + +Gets the Uri to a file or directory local to XCompute process. +The returned Uri is a fully qualified and can be used to +create paths for file URI's in the processes root directory or +another directory under the root by appending path/s relative +to the initial working directory. + +NOTE: + +The Job does not have access to directories above the +Process's Root Directory. +All directories e.g. Process Working Directory, Data directory +are sub directories under the Process's Root directory + + +Arguments: + + hProcessHandle + The process handle for which to get the + Process File Uri. 
+ + pszRelativePath + The path relative to process's working directory + that will be appended to the working directory. + NOTE: + If relative path is NULL. or '.' or '/', then the working directory + path is returned. + + ppszProcessRootDirUri + The Processes Root directory Uri. + Use the XcFreeMemory() API to free this buffer. + + Return Value: + + CsError_OK indicates call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetProcessUri( + IN XCPROCESSHANDLE hProcessHandle, + IN PCSTR pszRelativePath, + OUT PSTR* ppszProcessRootDirUri +) +{ + PSTR FilePath = NULL; + PSTR UriPath = NULL; + + try + { + PSTR prefix="file://"; + HRESULT hr = XcGetProcessPath(hProcessHandle, pszRelativePath, &FilePath); + + if (FAILED(hr)) + { + throw System::Runtime::InteropServices::Marshal::GetExceptionForHR( hr ); + } + + size_t cbLen = strlen(prefix)+strlen(FilePath)+2; + UriPath = (PSTR)::LocalAlloc(LMEM_FIXED, cbLen); + if (UriPath == NULL) + { + throw System::Runtime::InteropServices::Marshal::GetExceptionForHR( E_OUTOFMEMORY ); + } + strcpy_s(UriPath, cbLen, prefix); + strcat_s(UriPath, cbLen, FilePath); + + *ppszProcessRootDirUri = UriPath; + } + catch (System::Exception ^e) + { + // UriPath is freed on error + if (UriPath != NULL) + { + XcFreeMemory(UriPath); + UriPath = NULL; + } + + return System::Runtime::InteropServices::Marshal::GetHRForException(e); + } + finally + { + // FilePath is always freed + if (FilePath != NULL) + { + XcFreeMemory(FilePath); + FilePath = NULL; + } + } + + return S_OK; +} + +/*++ + +XcGetProcessPath API + +Description: + +Gets the path to a file or directory local to XCompute process. +The returned path is fully qualified and can be used to +create paths for files in the processes root directory or +another directory under the root by appending path/s relative +to the initial working directory. 
+ +The returned path is suitable for passing to OS APIs like CreateFile() + +NOTE: + +The Job does not have access to directories above the +Process's Root Directory. +All directories e.g. Process Working Directory, Data directory +are sub directories under the Process's Root directory + + +Arguments: + + hProcessHandle + The process handle for which to get the + Process File Uri. + + pszRelativePath + The path relative to process's working directory + that will be appended to the working directory. + NOTE: + If relative path is NULL. or '.' or '/', then the working directory + path is returned. + + ppszProcessRootDirPath + The Processes Root directory Uri. + Use the XcFreeMemory() API to free this buffer. + + Return Value: + + CsError_OK indicates call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetProcessPath( + IN XCPROCESSHANDLE hProcessHandle, + IN PCSTR pszRelativePath, + OUT PSTR* ppszProcessRootDirUri +) +{ + String ^path = Microsoft::Research::Dryad::VertexScheduler::GetInstance()->GetProcessPath((int)hProcessHandle, gcnew String(pszRelativePath)); + + if (System::String::IsNullOrEmpty(path)) + { + return E_FAIL; + } + + *ppszProcessRootDirUri = (PSTR)Marshal::StringToHGlobalAnsi(path).ToPointer(); + + return S_OK; +} diff --git a/xcompute_native/process.cpp b/xcompute_native/process.cpp new file mode 100644 index 0000000..3bae70e --- /dev/null +++ b/xcompute_native/process.cpp @@ -0,0 +1,710 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +/*++ + +Module Name: + + process.cpp + +Abstract: + + This module contains the public interface and support routines for + the xcompute process functionality on top of the HPC scheduler. + + Note that "process" in the xcompute terminology maps to "task" in + the HPC terminology. + +--*/ +#include "stdafx.h" +#include + +static CComAutoCriticalSection g_idLock; +static LONG g_dwNextProcessId = 1; + +DWORD GetNextProcessId() +{ + return InterlockedIncrement(&g_dwNextProcessId); +} + +static CComAutoCriticalSection g_referenceLock; +typedef std::map HandleMapT; +HandleMapT g_handleRefs; + +void AddProcessHandleReference(DWORD dwId) +{ + CComCritSecLock lock(g_referenceLock); + + HandleMapT::const_iterator iter = g_handleRefs.find(dwId); + if (iter == g_handleRefs.end()) + { + g_handleRefs[dwId] = 1; + } + else + { + g_handleRefs[dwId] = ++(g_handleRefs[dwId]); + } +} + +bool ReleaseProcessHandleReference(DWORD dwId) +{ + bool bCanFree = true; + + CComCritSecLock lock(g_referenceLock); + + HandleMapT::iterator iter = g_handleRefs.find(dwId); + if (iter != g_handleRefs.end()) + { + DWORD dwCount = iter->second; + if (--dwCount == 0) + { + g_handleRefs.erase(iter); + } + else + { + g_handleRefs[dwId] = dwCount; + bCanFree = false; + } + } + + return bCanFree; +} + + +XCPROCESSSTATE TranslateProcessState( + IN ProcessState fromState + ) +{ + XCPROCESSSTATE toState = XCPROCESSSTATE_UNSCHEDULED; + + if (fromState == ProcessState::Unscheduled) + { + toState = XCPROCESSSTATE_UNSCHEDULED; + } + else if (fromState == ProcessState::SchedulingFailed) + { + toState = XCPROCESSSTATE_SCHEDULINGFAILED; + } + else if (fromState == ProcessState::AssignedToNode) + { + toState = XCPROCESSSTATE_ASSIGNEDTONODE; + } + else if (fromState == ProcessState::Running) + { + toState = XCPROCESSSTATE_RUNNING; + } + else if (fromState == ProcessState::Completed) + { + 
toState = XCPROCESSSTATE_COMPLETED; + } + + return toState; +} + + +/*++ + +XcCreateNewProcessHandle API + +Description: + +Creates a new process handle for a new process in the +given Job. + +This call is synchronous and does not cross +machine boundaries/process boundaries. + +Note: +1. + This method just creates the handle to + the XCompute process. It does not schedule the process itself. + Use the XcScheduleProcessAPI to schedule the XCompute process. + +2. + Use the XcCloseProcessHandle() to free the handle + +3. + Do not copy handle using the simple assignment operator.Use the + DuplicateProcessHandle() API. Each handle variable needs to be + freed using the XcCloseProcessHandle(). + +Arguments: + + hSession + Handle to a session associated with this call + + pJobId + The Id of the job under which the process will + be created. A NULL value will cause the current + processes JobId to be automatically picked up. + NOTE: + This parameter is only interesting to the Task Scheduler. + For all other cases, it should be assined to NULL + + phProcessHandle + The handle to the process. + + +Return Value: + + XCERROR_OK + The call succeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcCreateNewProcessHandle( + IN XCSESSIONHANDLE hSession, + IN PCXC_JOBID pJobId, + OUT PXCPROCESSHANDLE phProcessHandle +) +{ + DWORD dwId = GetNextProcessId(); + + VertexScheduler::GetInstance()->CreateVertexProcess(dwId); + + *phProcessHandle = (XCPROCESSHANDLE)dwId; + AddProcessHandleReference(*((LPDWORD)phProcessHandle)); + + return S_OK; +} + + +/*++ + +XcOpenCurrentProcessHandle API + +Description: + +Opens the current processes handle + +This call is synchronous and does not cross +machine boundaries/process boundaries. + +Note: +1. + This method creates the handle to the current process and assigns + it to the session on that process. + +2. + Use the XcCloseProcessHandle() to free the handle + +3. 
+ Do not copy handle using the simple assignment operator.Use the + DuplicateProcessHandle() API. Each handle variable needs to be + freed using the XcCloseProcessHandle(). + +Arguments: + + hSession + Handle to a session associated with this call. + + phProcessHandle + The handle to the process. This must be closed using the + XcClosePorcessHandle() + +Return Value: + + XCERROR_OK + The call succeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcOpenCurrentProcessHandle( + IN XCSESSIONHANDLE hSession, + OUT PXCPROCESSHANDLE phProcessHandle +) +{ + *phProcessHandle = CURRENT_PROCESS_ID; + + return S_OK; +} + + +/*++ + +XcCloseProcessHandle API + +Description: + +Closes a process handle created either by a call to +XcCreateNewProcessHandle() or XcDupProcessHandle(). + +This call is synchronous and does not cross +machine boundaries/process boundaries. + + +NOTE: +Every call to the XcCreateNewProcessHandle() or +DupProcessHandle() should ultimately +result in a call to XcCloseProcessHandle() to deallocated the handle. + +Arguments: + + hProcessHandle + Process handle to be closed + +Return Value: + + XCERROR_OK + The call succeded + + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcCloseProcessHandle ( + IN XCPROCESSHANDLE hProcessHandle +) +{ + + if (hProcessHandle != CURRENT_PROCESS_ID) + { + DWORD dwId = (DWORD)hProcessHandle; + if (ReleaseProcessHandleReference(dwId)) + { + // No more references, free associated resources in XComputeLib.dll + VertexScheduler::GetInstance()->CloseVertexProcess(dwId); + } + } + + return S_OK; +} + + + +/*++ + +XcDupProcessHandle API + +Description: + +Duplicates a process handle. Use this api, if a copy of the +process handle is needed. + +This call is synchronous and does not cross +machine boundaries/process boundaries. + +NOTE: + a. Every call to the DupProcessHandle should ultimately result + in a call to XcCloseProcessHandle() to deallocated the handle. 
+
+Arguments:
+
+    hProcessHandle
+        Process handle to be duplicated
+
+    phDupProcessHandle
+        The duplicated process handle
+
+Return Value:
+
+    XCERROR_OK
+        The call succeeded
+
+    E_INVALIDARG
+        phDupProcessHandle is NULL
+
+--*/
+XCOMPUTEAPI_EXT
+XCERROR
+XCOMPUTEAPI
+XcDupProcessHandle(
+    IN  XCPROCESSHANDLE  hProcessHandle,
+    OUT PXCPROCESSHANDLE phDupProcessHandle
+)
+{
+    if (phDupProcessHandle == NULL)
+    {
+        return E_INVALIDARG;
+    }
+
+    // XcCloseProcessHandle() never releases a reference for the
+    // CURRENT_PROCESS_ID pseudo handle, so none is added for it here;
+    // otherwise the count for that id could only ever grow.
+    if (hProcessHandle != CURRENT_PROCESS_ID)
+    {
+        AddProcessHandleReference((DWORD)hProcessHandle);
+    }
+
+    // The duplicate carries the same value as the source handle.
+    *phDupProcessHandle = hProcessHandle;
+
+    return S_OK;
+}
+
+
+
+/*++
+
+XcSerializeProcessHandle API
+
+Description:
+
+Creates a serialized process handle. A XCompute process can serialize a
+process handle, and pass it to another XCompute process where the other
+XCompute process can use the XcUnSerializeProcessHandle() API, to
+recreate the process handle. Then it can use that process handle to
+communicate with the process. e.g by using XcSetAndGetProcessInfo() API
+
+Arguments:
+
+    hProcessHandle
+        The handle to the process to serialize.
+
+    ppXcSerializedHandleBlock
+        The serialized process handle. Use the XcFreeMemory() API
+        to de-allocate the pXcSerializedHandleBlock.
+
+    pcbBlockLength
+        The length in bytes of the serialized process handle block
+
+NOTE:
+    The UserContext associated with the process handle will *NOT*
+    be serialized.
+
+Return Value:
+
+    XCERROR_OK
+        The call succeeded
+
+    E_INVALIDARG
+        ppXcSerializedHandleBlock or pcbBlockLength is NULL
+
+--*/
+
+XCOMPUTEAPI_EXT
+XCERROR
+XCOMPUTEAPI
+XcSerializeProcessHandle (
+    IN  XCPROCESSHANDLE hProcessHandle,
+    OUT PCVOID*         ppXcSerializedHandleBlock,
+    OUT PSIZE_T         pcbBlockLength
+)
+{
+    if (ppXcSerializedHandleBlock == NULL || pcbBlockLength == NULL)
+    {
+        return E_INVALIDARG;
+    }
+
+    // The "block" is the handle value itself stored in the pointer, not a
+    // pointer to allocated memory; XcUnSerializeProcessHandle() reverses
+    // this cast.
+    // NOTE(review): the header above tells callers to XcFreeMemory() the
+    // block - confirm that XcFreeMemory() tolerates this non-heap value.
+    *ppXcSerializedHandleBlock = (PCVOID)hProcessHandle;
+    *pcbBlockLength = sizeof(DWORD);
+
+    return S_OK;
+}
+
+
+/*++
+
+XcUnSerializeProcessHandle API
+
+Description:
+
+Un-serializes a serialized process handle. See XcSerializeProcessHandle() API
+for more details
+
+Arguments:
+
+    hSession
+        The session to which to associate the un-serialized process handle with.
+
+    pXcSerializedHandleBlock
+        The serialized process handle.
+
+    cbBlockLength
+        The length in bytes of the serialized process handle block
+        (not inspected by this implementation)
+
+    phProcessHandle
+        The un-serialized process handle.
+
+Return Value:
+
+    XCERROR_OK
+        The call succeeded
+
+--*/
+
+XCOMPUTEAPI_EXT
+XCERROR
+XCOMPUTEAPI
+XcUnSerializeProcessHandle (
+    IN  XCSESSIONHANDLE     hSession,
+    IN  PCVOID              pXcSerializedHandleBlock,
+    IN  SIZE_T              cbBlockLength,
+    OUT PXCPROCESSHANDLE    phProcessHandle
+)
+{
+    // The block pointer *is* the handle value (see XcSerializeProcessHandle),
+    // so it is cast back rather than dereferenced. hSession and
+    // cbBlockLength are not used, and no handle reference is added here.
+    *phProcessHandle = (XCPROCESSHANDLE)pXcSerializedHandleBlock;
+
+    return S_OK;
+}
+
+
+/*++
+
+XcGetProcessState API
+
+Description:
+
+Gets the process state information. If Schedule process is not
+yet been called, the API will return error.
+
+This call is synchronous and does not cross
+machine boundaries/process boundaries.
+
+Arguments:
+
+    hProcessHandle
+        Process handle
+
+    pProcessState
+        Describes the process state. The different states
+        are described in XComputeTypes.h
+
+    pProcessSchedulingError
+        if process state is XCPROCESSSTATE_COMPLETED
+        then the error code indicates reason.
+        S_OK means process completed without errors.
+        Other error codes indicate reasons for failed completion.
+ +Return Value: + + XCERROR_OK + The call succeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetProcessState( + IN XCPROCESSHANDLE hProcessHandle, + OUT PXCPROCESSSTATE pProcessState, + OUT XCERROR* pProcessSchedulingError + ) +{ + HRESULT hr = S_OK; + + if (pProcessState == NULL) + { + hr = E_INVALIDARG; + goto Exit; + } + + if (pProcessSchedulingError == NULL) + { + hr = E_INVALIDARG; + goto Exit; + } + + try + { + ProcessState state = VertexScheduler::GetInstance()->GetProcessState((int)hProcessHandle); + + *pProcessState = TranslateProcessState(state); + + switch (*pProcessState) + { + case XCPROCESSSTATE_SCHEDULINGFAILED: + *pProcessSchedulingError = E_FAIL; + break; + case XCPROCESSSTATE_COMPLETED: + if (VertexScheduler::GetInstance()->ProcessCancelled((int)hProcessHandle)) + { + *pProcessSchedulingError = E_FAIL; + } + break; + default: + *pProcessSchedulingError = S_OK; + break; + } + } + catch (System::Exception ^e) + { + hr = System::Runtime::InteropServices::Marshal::GetHRForException(e); + } + +Exit: + return hr; +} + +/*++ + +XcGetProcessId API + +Description: + +Gets the process Id of the process associated with the process handle. +If the process state is anything less than XCPROCESSSTATE_ASSIGNEDTOPN +an error is returned. + +Arguments: + + hProcessHandle + Process handle + + + pProcessId + The id of the process + +Return Value: + + XCERROR_OK + The call succeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcGetProcessId( + IN XCPROCESSHANDLE hProcessHandle, + OUT XC_PROCESSID* pProcessId +) +{ + // TODO: anything else needed here? + pProcessId->Data1 = (DWORD)hProcessHandle; + + return S_OK; +} + +/*++ + +XcWaitForStateChange API + +Description: + +The API allows users to get async completion status for +XCompute process when it reaches a desired state. (see XCPROCESSSTATE) +When the desired state is reached the async completion is dispatched. + +NOTE: +1. If the process gets cancelled, then completion is dispatched immediately +2. 
The pOperationState of the AsyncInfo will have the error code. + +Arguments: + + hProcessHandle + Handle to an XCompute process for which the + state change event is needed + + waitForState + The state to wait for the XCompute to be in, so + that completion can be dispatched + + tiMaxWaitInterval + The maximum amount of time (not including network + request latencies) that the API should wait for a + change in the process list before completing. If + XCTIMEINTERVAL_ZERO, the API will return changes + that can be immediately determined without + communication with the process scheduler. If + XCTIMEINTERVAL_INFINITE, the API will wait until a + change occurs or the process is cancelled. + + pAsyncInfo + The async info structure. Its an alias to the + CS_ASYNC_INFO defined in Cosmos.h. If this + parameter is NULL, then the function completes in + synchronous manner and error code is returned as + return value. + + If parameter is not NULL then the operation is carried + on in asynchronous manner. If an asynchronous + operation has been successfully started then + this function terminates immediately with + an HRESULT_FROM_WIN32(ERROR_IO_PENDING) return value. + Any other return value indicates that it was + impossible to start the asynchronous operation. 
+ + Return Value: + + CsError_OK indicates call succeeded + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcWaitForStateChange( + IN XCPROCESSHANDLE hProcessHandle, + IN XCPROCESSSTATE waitForState, + IN XCTIMEINTERVAL tiMaxWaitInterval, + IN PCXC_ASYNC_INFO pAsyncInfo +) +{ + ProcessState targetState; + + PASYNC async; + + // + // Map the requested state into the HPC task state + // + switch (waitForState) + { + case XCPROCESSSTATE_UNSCHEDULED: + targetState = ProcessState::Unscheduled; + break; + + case XCPROCESSSTATE_SCHEDULINGFAILED: + targetState = ProcessState::SchedulingFailed; + break; + + case XCPROCESSSTATE_ASSIGNEDTONODE: + targetState = ProcessState::AssignedToNode; + break; + + case XCPROCESSSTATE_RUNNING: + targetState = ProcessState::Running; + break; + + case XCPROCESSSTATE_COMPLETED: + targetState = ProcessState::Completed; + break; + + default: + return E_NOTIMPL; + } + + CAPTURE_ASYNC(async); + + try + { + + VertexScheduler ^client = VertexScheduler::GetInstance(); + ProcessState currentState = client->GetProcessState((int)hProcessHandle); + + if (currentState >= targetState) + { + return COMPLETE_ASYNC(async, S_OK); + } + + // + // If our timeout has elapsed, return timeout + // + if (tiMaxWaitInterval == XCTIMEINTERVAL_ZERO) + { + return COMPLETE_ASYNC(async, HRESULT_FROM_WIN32(ERROR_TIMEOUT)); + } + + if (async == NULL) + { + if (client->WaitForStateChange((int)hProcessHandle, tiMaxWaitInterval, targetState)) + { + return S_OK; + } + else + { + return HRESULT_FROM_WIN32(ERROR_TIMEOUT); + } + } + else + { + AsyncWrapper ^wrapper = gcnew AsyncWrapper(async); + StateChangeEventHandler ^handler = gcnew StateChangeEventHandler(wrapper, &AsyncWrapper::StateChangeHandler); + client->NotifyStateChange((int)hProcessHandle, tiMaxWaitInterval, targetState, handler); + return HRESULT_FROM_WIN32(ERROR_IO_PENDING); + } + } + catch (System::Exception ^e) + { + return COMPLETE_ASYNC(async,System::Runtime::InteropServices::Marshal::GetHRForException(e)); + } 
+} + + + diff --git a/xcompute_native/property.cpp b/xcompute_native/property.cpp new file mode 100644 index 0000000..1baf6cf --- /dev/null +++ b/xcompute_native/property.cpp @@ -0,0 +1,423 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +/*++ +Module Name: + + property.cpp + +Abstract: + + This module contains the public interface and support routines for + process properties for xcompute on the HPC scheduler + +--*/ +#include "stdafx.h" +//#include + +using namespace System; +using namespace System::Threading; +using namespace System::Runtime::InteropServices; + +// +// Prototypes for functions internal to this module +// +HRESULT +SetGetProps( + XCPROCESSHANDLE hProcessHandle, + long cPropCount, + PXC_PROCESSPROPERTY_INFO setProp[], + LPCSTR blockOnLabel, + UINT64 blockOnVersion, + XCTIMEINTERVAL maxBlockTime, + LPCSTR getPropLabel, + DWORD fetchOptions, + PXC_SETANDGETPROCESSINFO_REQRESULTS *presults, + PCXC_ASYNC_INFO pAsyncInfo + ); + +static void ConvertManagedProcessInfoToNative(ProcessInfo ^managedInfo, PXC_PROCESS_INFO &pNativeInfo) +{ + char *strTemp = NULL; + + pNativeInfo = (PXC_PROCESS_INFO)::LocalAlloc(LMEM_FIXED, sizeof(XC_PROCESS_INFO)); + ZeroMemory(pNativeInfo, sizeof(XC_PROCESS_INFO)); + + pNativeInfo->Size = sizeof(XC_PROCESS_INFO); + pNativeInfo->Flags = managedInfo->flags; + pNativeInfo->ProcessState = 
TranslateProcessState(managedInfo->processState); + pNativeInfo->ProcessStatus = managedInfo->processStatus; + pNativeInfo->ExitCode = managedInfo->exitCode; + + if (managedInfo->propertyInfos != nullptr) + { + pNativeInfo->ppProperties = (PXC_PROCESSPROPERTY_INFO*)LocalAlloc(LMEM_FIXED, sizeof(PXC_PROCESSPROPERTY_INFO) * managedInfo->propertyInfos->Length); + pNativeInfo->NumberofProcessProperties = managedInfo->propertyInfos->Length; + + for (int i = 0; i < managedInfo->propertyInfos->Length; i ++) + { + PXC_PROCESSPROPERTY_INFO processProp = (XC_PROCESSPROPERTY_INFO*)LocalAlloc(LMEM_FIXED, sizeof(XC_PROCESSPROPERTY_INFO)); + ZeroMemory(processProp, sizeof(XC_PROCESSPROPERTY_INFO)); + processProp->Size = sizeof(XC_PROCESSPROPERTY_INFO); + + size_t len = managedInfo->propertyInfos[i]->propertyLabel->Length + 1; + strTemp = (char*)(void*)Marshal::StringToHGlobalAnsi(managedInfo->propertyInfos[i]->propertyLabel); + processProp->pPropertyLabel = (char*)LocalAlloc(LMEM_FIXED, sizeof(char) * len); + strncpy(processProp->pPropertyLabel, strTemp, len); + Marshal::FreeHGlobal((IntPtr)strTemp); + strTemp = NULL; + + len = managedInfo->propertyInfos[i]->propertyString->Length + 1; + strTemp = (char*)(void*)Marshal::StringToHGlobalAnsi(managedInfo->propertyInfos[i]->propertyString); + processProp->pPropertyString = (char*)LocalAlloc(LMEM_FIXED, sizeof(char) * len); + strncpy(processProp->pPropertyString, strTemp, len); + Marshal::FreeHGlobal((IntPtr)strTemp); + strTemp = NULL; + + processProp->PropertyVersion = managedInfo->propertyInfos[i]->propertyVersion; + + if (managedInfo->propertyInfos[i]->propertyBlock != nullptr && managedInfo->propertyInfos[i]->propertyBlock->Length > 0) + { + processProp->PropertyBlockSize = managedInfo->propertyInfos[i]->propertyBlock->Length; + processProp->pPropertyBlock = (char*)LocalAlloc(LMEM_FIXED, sizeof(char) * processProp->PropertyBlockSize); + Marshal::Copy(managedInfo->propertyInfos[i]->propertyBlock, + 0, + 
(IntPtr)processProp->pPropertyBlock, + processProp->PropertyBlockSize); + } + + pNativeInfo->ppProperties[i] = processProp; + } + } + + if (managedInfo->processStatistics != nullptr) + { + PXC_PROCESS_STATISTICS stats = (XC_PROCESS_STATISTICS*)LocalAlloc(LMEM_FIXED, sizeof(XC_PROCESS_STATISTICS)); + ZeroMemory(stats, sizeof(XC_PROCESS_STATISTICS)); + + stats->Size = sizeof(XC_PROCESS_STATISTICS); + stats->Flags = managedInfo->processStatistics->flags; + stats->ProcessUserTime = managedInfo->processStatistics->processUserTime; + stats->ProcessKernelTime = managedInfo->processStatistics->processKernelTime; + stats->PageFaults = managedInfo->processStatistics->pageFaults; + stats->TotalProcessesCreated = managedInfo->processStatistics->totalProcessesCreated; + stats->PeakVMUsage = managedInfo->processStatistics->peakVMUsage; + stats->PeakMemUsage = managedInfo->processStatistics->peakMemUsage; + stats->MemUsageSeconds = managedInfo->processStatistics->memUsageSeconds; + stats->TotalIo = managedInfo->processStatistics->totalIo; + + pNativeInfo->pProcessStatistics = stats; + } + + if (strTemp != NULL) + { + Marshal::FreeHGlobal((IntPtr)strTemp); + strTemp = NULL; + } +} + + +ref class AsyncPropWrapper +{ +public: + ASYNC *m_pASYNC; + ManualResetEvent ^m_event; + XCPROCESSHANDLE m_process; + PXC_SETANDGETPROCESSINFO_REQRESULTS *m_ppResults; + HRESULT m_hResult; + + AsyncPropWrapper(ASYNC *pAsync, XCPROCESSHANDLE hProc, PXC_SETANDGETPROCESSINFO_REQRESULTS *ppResults) + { + m_pASYNC = pAsync; + m_event = nullptr; + if (m_pASYNC == NULL) + { + m_event = gcnew ManualResetEvent(false); + } + m_process = hProc; + m_ppResults = ppResults; + m_hResult = HRESULT_FROM_WIN32(ERROR_IO_PENDING); + } + + ~AsyncPropWrapper() + { + if (m_pASYNC) + { + delete m_pASYNC; + } + } + + void GetSetPropertyHandler(System::Object ^sender, Microsoft::Research::Dryad::XComputeProcessGetSetPropertyEventArgs ^e) + { + if (m_ppResults != NULL) + { + PXC_SETANDGETPROCESSINFO_REQRESULTS pResults = 
(PXC_SETANDGETPROCESSINFO_REQRESULTS)::LocalAlloc(LMEM_FIXED, sizeof(XC_SETANDGETPROCESSINFO_REQRESULTS)); + if (pResults == NULL ) + { + m_hResult = E_OUTOFMEMORY; + goto Exit; + } + ZeroMemory(pResults, sizeof(XC_SETANDGETPROCESSINFO_REQRESULTS)); + + try + { + if (e->ProcessInfo != nullptr) + { + ConvertManagedProcessInfoToNative(e->ProcessInfo, pResults->pProcessInfo); + } + + if (e->PropertyVersions != nullptr && e->PropertyVersions->Length > 0) + { + pResults->NumberOfPropertyVersions = e->PropertyVersions->Length; + pResults->pPropertyVersions = (UINT64 *)LocalAlloc(LMEM_FIXED, sizeof(UINT64) * e->PropertyVersions->Length); + for (UINT32 i = 0; i < pResults->NumberOfPropertyVersions; i++) + { + pResults->pPropertyVersions[i] = e->PropertyVersions[i]; + } + } + } + catch(Exception ^e) + { + m_hResult = Marshal::GetHRForException(e); + Console::WriteLine("[XComputeNative.GetSetPropertyHandler] Exception: {0}", e->Message); + goto Exit; + } + + *m_ppResults = pResults; + } + + m_hResult = S_OK; + +Exit: + if (m_pASYNC) + { + m_pASYNC->Complete(m_hResult); + } + else + { + m_event->Set(); + } + } +}; + + +HRESULT +SetAppProcessConstraints( + IN XCPROCESSHANDLE hProcess, + IN PCXC_PROCESS_CONSTRAINTS Constraints + ) +{ + // TODO implement this! (return success for now because Dryad tries to dink with the runtime + return S_OK; +} + +// #pragma unmanaged +// push managed state on to stack and set unmanaged state +#pragma managed(push, off) +/*++ + +XcSetAndGetProcessInfo API + +Description: + +Gets the process related information from the Process Node. +JobManager (e.g. Dryad Job manager), will use this API to get +information about a given XCompute process, of a job. +Various bit flags (explained below) control the amount of data +retreived for a given process +It also provides the user with the ability to block on a +particular property, for maxBlockTime amount of time, before the +API finishes (synchronously or asynchronously). 
Dryad uses this +to extend the lease period for a given process + +Arguments: + + hProcessHandle + Handle to the process. + Use the XcCreateNewProcessHandle () API + to get obtain the handle to the process + + pXcRequestInputs + Pointer to the + XC_SETANDGETPROCESSINFO_REQINPUT struct. + It contains the various inputs to the API + clubbed together. This structure needs to + be preserverd by the user till the Async + call is completed + + ppXcRequestResults + The results structure.The user should use + the XcFreeMemory(ppXcPnProcessInfo) to free + the memory after the results have been + consumed. + See PXC_SETANDGETPROCESSINFO_REQRESULTS for + more info. + + pAsyncInfo + The async info structure. Its an alias to + the CS_ASYNC_INFO defined in Cosmos.h. If + this parameter is NULL, then function + completes in synchronous manner and error + code is returned as return value. + + If parameter is not NULL then operation is + carried on in asynchronous manner. If + asynchronous operation has been successfully + started then function terminates + immediately with + HRESULT_FROM_WIN32(ERROR_IO_PENDING) return + value. + + Any other return value indicates that it was + impossible to start asynchronous operation. + + + Return Value: + + if pAsyncInfo is NULL + XCERROR_OK indicates call succeeded + + Any other error code, indicates the failure reason. 
+ + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcSetAndGetProcessInfo( + IN XCPROCESSHANDLE hProcessHandle, + IN PXC_SETANDGETPROCESSINFO_REQINPUT pXcRequestInputs, + OUT PXC_SETANDGETPROCESSINFO_REQRESULTS* ppXcRequestResults, + IN PCXC_ASYNC_INFO pAsyncInfo +) +{ + HRESULT hr; + printf("In native XcSetAndGetProcessInfo\n"); + // check for malformed input + if (pXcRequestInputs->Size < sizeof(XC_SETANDGETPROCESSINFO_REQINPUT)) + { + return E_NOTIMPL; + } + + // no support for process constraints + if (pXcRequestInputs->pAppProcessConstraints != NULL) + { + hr = SetAppProcessConstraints(hProcessHandle, pXcRequestInputs->pAppProcessConstraints); + if (FAILED(hr)) + { + return hr; + } + } + *ppXcRequestResults = NULL; + + if (hProcessHandle == CURRENT_PROCESS_ID) + { + // We need to use the real dryad process id so the vertex service host can find it + CStringA procId; + procId.GetEnvironmentVariable("CCP_DRYADPROCID"); + hProcessHandle = (XCPROCESSHANDLE)atoi((LPCSTR)procId); + } + + hr = SetGetProps( + hProcessHandle, + pXcRequestInputs->NumberOfProcessPropertiesToSet, + pXcRequestInputs->ppPropertiesToSet, + pXcRequestInputs->pBlockOnPropertyLabel, + pXcRequestInputs->BlockOnPropertyversionLastSeen, + pXcRequestInputs->MaxBlockTime, + pXcRequestInputs->pPropertyFetchTemplate, + pXcRequestInputs->ProcessInfoFetchOptions, + ppXcRequestResults, + pAsyncInfo); + + return hr; +} +// #pragma managed +#pragma managed(pop) + + +HRESULT +SetGetProps( + XCPROCESSHANDLE hProcessHandle, + long cPropCount, + PXC_PROCESSPROPERTY_INFO setProp[], + LPCSTR blockOnLabel, + UINT64 blockOnVersion, + XCTIMEINTERVAL maxBlockTime, + LPCSTR getPropLabel, + DWORD fetchOptions, + PXC_SETANDGETPROCESSINFO_REQRESULTS *ppResults, + PCXC_ASYNC_INFO pAsyncInfo + ) +{ + PASYNC async; + 
CAPTURE_ASYNC(async); + HRESULT hr = S_OK; + + // Build a managed ProcessPropertyInfo array from the setProp[] array + array ^infos = gcnew array(cPropCount); + for (int i = 0; i < cPropCount; i++) + { + ProcessPropertyInfo ^info = gcnew ProcessPropertyInfo(); + info->propertyLabel = gcnew String(setProp[i]->pPropertyLabel); + info->propertyVersion = setProp[i]->PropertyVersion; + info->propertyString = gcnew String(setProp[i]->pPropertyString); + if (setProp[i]->PropertyBlockSize > 0) + { + info->propertyBlock = gcnew array(setProp[i]->PropertyBlockSize); + Marshal::Copy((IntPtr)setProp[i]->pPropertyBlock, info->propertyBlock, 0, setProp[i]->PropertyBlockSize); + } + + infos[i] = info; + } + + + AsyncPropWrapper ^wrapper = gcnew AsyncPropWrapper(async, hProcessHandle, ppResults); + GetSetPropertyEventHandler ^handler = gcnew GetSetPropertyEventHandler(wrapper, &AsyncPropWrapper::GetSetPropertyHandler); + + VertexScheduler ^vs = VertexScheduler::GetInstance(); + bool bRetVal = vs->SetGetProps( + (int)hProcessHandle, + infos, + gcnew String(blockOnLabel), + blockOnVersion, + maxBlockTime, + gcnew String(getPropLabel), + (fetchOptions & XCPROCESSINFOOPTION_PROCESSSTAT) != 0, + handler); + + if (pAsyncInfo == NULL) + { + wrapper->m_event->WaitOne(); + return wrapper->m_hResult; + } + else + { + return HRESULT_FROM_WIN32(ERROR_IO_PENDING); + } + +} diff --git a/xcompute_native/scheduler.cpp b/xcompute_native/scheduler.cpp new file mode 100644 index 0000000..14dd4e1 --- /dev/null +++ b/xcompute_native/scheduler.cpp @@ -0,0 +1,258 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +/*++ + +Module Name: + + scheduler.cpp + +Abstract: + + This module contains the public interface and support routines for + the xcompute scheduling functionality on top of the HPC scheduler. + +--*/ +#include "stdafx.h" +using namespace System; +using namespace System::Threading; +using namespace System::Collections::Generic; +using namespace System::Collections::Specialized; +using namespace System::Runtime::InteropServices; +using namespace Microsoft::Research::Dryad; + +/*++ + +XcScheduleProcess API + +Description: + +Contacts the Process Scheduler to schedule an XCompute Process. +Any XCompute Process in a Job may schedule additional +XCompute Processes in the same Job by requesting their creation +through the XCompute Process Scheduler, using this API. + +NOTE: +This call always returns immediately. +A successful return code from the API indicates that the +XcScheduleProcess request was added to the local scheduleProcess queue. +The user should use the XcWaitForStateChange(XCPROCESSSTATE_ASSIGNEDTOPN) +API to see when the process actually gets scheduled to the Process Scheduler + +Arguments: + + hProcessHandle + + + Handle to the process. + Use the XcCreateNewProcessHandle () API + to get obtain the handle to the process + + pScheduleProcessDescriptor + See PCXC_SCHEDULEPROCESS_DESCRIPTOR in + XComputeTypes.h. 
This datastructure is + copied before the function returns and + so it is not necessary for the caller + to preserve the contents during a + async call + + Return Value: + + S_OK indicating the operation was successfully started. + + Any other return value indicates the scheduleprocess request + could not be started + +--*/ + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcScheduleProcess( + IN XCPROCESSHANDLE hProcessHandle, + IN PCXC_SCHEDULEPROCESS_DESCRIPTOR pScheduleProcessDescriptor +) +{ + HRESULT hr = S_OK; + + try + { + if (pScheduleProcessDescriptor->Size != sizeof(XC_SCHEDULEPROCESS_DESCRIPTOR)) + { + return E_INVALIDARG; + } + PCXC_CREATEPROCESS_DESCRIPTOR proc = pScheduleProcessDescriptor->pCreateProcessDescriptor; + if (proc->Size != sizeof(XC_CREATEPROCESS_DESCRIPTOR)) + { + return E_INVALIDARG; + } + + // + // Set command line and friendly name + // + StringDictionary ^env = gcnew StringDictionary(); + + String ^cmdLine = gcnew String(proc->pCommandLine); + String ^name = gcnew String(proc->pProcessFriendlyName); + + // + // Walk through the list of environment strings and set each one + // + PCSTR envValue = NULL; + PCSTR envName = proc->pEnvironmentStrings; + + while (envName && *envName) + { + envValue = envName + strlen(envName) + 1; + + env->Add(gcnew System::String(envName), gcnew System::String(envValue)); + envName = envValue + strlen(envValue) + 1; + } + + + // + // If the caller asked for a single machine affinity, and it's a hard affinity, + // then put the machine requirement on the task. Otherwise let the job scheduler + // pick whatever it likes. 
+ // + + List ^requestedNodes = gcnew List(); + String ^requiredNode = nullptr; + + if (pScheduleProcessDescriptor->pLocalityDescriptor) + { + if (pScheduleProcessDescriptor->pLocalityDescriptor->NumberOfAffinities == 1 && + (pScheduleProcessDescriptor->pLocalityDescriptor->pAffinities[0].Flags & XCAFFINITY_HARD)) + { + requiredNode = gcnew String(pScheduleProcessDescriptor->pLocalityDescriptor->pAffinities->pNetworkLocalityPaths[0]); + } + else + { + for (unsigned int i = 0; i < pScheduleProcessDescriptor->pLocalityDescriptor->NumberOfAffinities; i++) + { + if (pScheduleProcessDescriptor->pLocalityDescriptor->pAffinities[i].NumberOfNetworkLocalityPaths > 0) + { + String ^machineName = gcnew String(pScheduleProcessDescriptor->pLocalityDescriptor->pAffinities[i].pNetworkLocalityPaths[0]); + + if (!machineName->Equals(Microsoft::Research::Dryad::AzureUtils::CurrentHostName, StringComparison::OrdinalIgnoreCase)) + { + SoftAffinity ^affinity = gcnew SoftAffinity( + machineName, + pScheduleProcessDescriptor->pLocalityDescriptor->pAffinities[i].Weight + ); + requestedNodes->Add(affinity); + } + } + } + } + } + + + VertexScheduler ^client = VertexScheduler::GetInstance(); + + if ( client->ScheduleProcess((int)hProcessHandle, cmdLine, requestedNodes, requiredNode, env)) + { + hr = S_OK; + } + else + { + hr = E_FAIL; + } + } + catch (Exception ^e) + { + hr = System::Runtime::InteropServices::Marshal::GetHRForException(e); + } + + return hr; +} + + + + +/*++ + +XcCancelScheduleProcess API + +Description: + +Contacts the Process Scheduler to cancel the scheduled +XCompute Process. This API is used by the Parent XCompute process +that originally scheduled the XCompute process to cancel its +creation. +NOTE: The XCompute process will get cancelled, only if has not +already been created on a process node. The returned error code +indicates whether the process was successfully cancelled or not. + +Arguments: + + hProcessHandle + Handle to the process. 
+ + pProcessId + The processId that needs to be cancelled + + pAsyncInfo + The async info structure. Its an alias to the + CS_ASYNC_INFO defined in Cosmos.h. IF this + parameter is NULL, then function completes in + synchronous manner and error code is returned as + return value. + + If parameter is not NULL then operation is carried + on in asynchronous manner. If asynchronous + operation has been successfully started then + function terminates immediately with + HRESULT_FROM_WIN32(ERROR_IO_PENDING) return value. + Any other return value indicates that it was + impossible to start asynchronous operation. + + + Return Value: + + if pAsyncInfo is NULL + XCERROR_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcCancelScheduleProcess( + IN XCPROCESSHANDLE hProcessHandle, + IN PCXC_ASYNC_INFO pAsyncInfo +) +{ + PASYNC async; + CAPTURE_ASYNC(async); + + Microsoft::Research::Dryad::VertexScheduler::GetInstance()->CancelScheduleProcess((int)hProcessHandle); + + return COMPLETE_ASYNC(async, S_OK); +} diff --git a/xcompute_native/session.cpp b/xcompute_native/session.cpp new file mode 100644 index 0000000..b3eaefc --- /dev/null +++ b/xcompute_native/session.cpp @@ -0,0 +1,152 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +/*++ + +Module Name: + + session.cpp + +Abstract: + + This module contains the support routines for the xcompute + session functionality on top of the HPC scheduler. + +--*/ + +#include "stdafx.h" + +using namespace System::Runtime::InteropServices; +using namespace Microsoft::Research::Dryad; + + + +/*++ + +XcOpenSession API + +Description: + +Opens an XCompute session for a given cluster. Each session +is associated with a cluster and is independent of other sessiosn. +The session (apart from other things) is associated with +user credientials. + +It is possible to create multiple sessions for the same cluster and +these multiple sessions will behave independent of each other. This +is particularly useful for applications like WebServer which will run +multiple sessions, one per user. + +Use the XcCloseSession to close the handle returned as a result of +XcOpenSession call. + +Arguments: + + pOpenSessionParams + The Open Session Parameters. Passes info about cluster to + establish session with, clientId, etc. + See XC_OPEN_SESSION_PARAMS for details. + + Pass NULL for defaults - Default cluster and a default cliend id. + + phSessionHandle + Handle to session + + pAsyncInfo + The async info structure. Its an alias to the + CS_ASYNC_INFO defined in Cosmos.h. IF this + parameter is NULL, then function completes in + synchronous manner and error code is returned as + return value. + + If parameter is not NULL then operation is carried + on in asynchronous manner. 
If asynchronous + operation has been successfully started then + function terminates immediately with + HRESULT_FROM_WIN32(ERROR_IO_PENDING) return value. + Any other return value indicates that it was + impossible to start asynchronous operation. + + + Return Value: + + if pAsyncInfo is NULL + XCERROR_OK indicates call succeeded + + Any other error code, indicates the failure reason. + + + if pAsyncInfo != NULL + HRESULT_FROM_WIN32(ERROR_IO_PENDING) indicates the async + operation was successfully started + + Any other return value indicates it was impossible to start + asynchronous operation + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcOpenSession( + IN PCXC_OPEN_SESSION_PARAMS pOpenSessionParams, + OUT PXCSESSIONHANDLE pSessionHandle, + IN PCXC_ASYNC_INFO pAsyncInfo +) +{ + HRESULT hr = S_OK; + + *pSessionHandle = (XCSESSIONHANDLE)1; + + PASYNC async; + CAPTURE_ASYNC(async); + + return COMPLETE_ASYNC(async, hr); +} + +/*++ + +XcCloseSession API + +Description: + +Closes the session. + +Arguments: + + hSessionHandle + Handle to session to close + +Return Value: + + XCERROR_OK + Call succeeded. + +--*/ + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcCloseSession( + IN XCSESSIONHANDLE hSessionHandle +) +{ + VertexScheduler::GetInstance()->Shutdown(0); + return S_OK; +} diff --git a/xcompute_native/status.cpp b/xcompute_native/status.cpp new file mode 100644 index 0000000..9fab306 --- /dev/null +++ b/xcompute_native/status.cpp @@ -0,0 +1,111 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. 
You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+/*++
+Module Name:
+
+    status.cpp
+
+Abstract:
+
+    This module contains the public interface and support routines for
+    the xcompute job progress reporting functionality on top of the HPC scheduler.
+
+--*/
+#include "stdafx.h"
+
+using namespace Microsoft::Research::Dryad;
+
+
+XCOMPUTEAPI_EXT
+XCERROR
+XCOMPUTEAPI
+XcResetProgress(
+    IN XCSESSIONHANDLE SessionHandle,
+    IN ULONG nTotalProgressSteps,
+    IN bool bUpdate
+)
+{ // SessionHandle is not used; the call is forwarded to the scheduler's JobStatus
+    VertexScheduler::GetInstance()->JobStatus->ResetProgress(nTotalProgressSteps, bUpdate);
+    return S_OK;
+}
+
+XCOMPUTEAPI_EXT
+XCERROR
+XCOMPUTEAPI
+XcIncrementTotalSteps(
+    IN XCSESSIONHANDLE SessionHandle,
+    IN bool bUpdate
+    )
+{ // SessionHandle is not used; the call is forwarded to the scheduler's JobStatus
+    VertexScheduler::GetInstance()->JobStatus->IncrementTotalSteps(bUpdate);
+    return S_OK;
+}
+
+XCOMPUTEAPI_EXT
+XCERROR
+XCOMPUTEAPI
+XcDecrementTotalSteps(
+    IN XCSESSIONHANDLE SessionHandle,
+    IN bool bUpdate
+    )
+{ // SessionHandle is not used; the call is forwarded to the scheduler's JobStatus
+    VertexScheduler::GetInstance()->JobStatus->DecrementTotalSteps(bUpdate);
+    return S_OK;
+}
+
+XCOMPUTEAPI_EXT
+XCERROR
+XCOMPUTEAPI
+XcSetProgress(
+    IN XCSESSIONHANDLE SessionHandle,
+    IN ULONG nCompletedProgressSteps,
+    IN PCSTR pMessage
+)
+{ // SessionHandle is not used; pMessage is converted to a managed String before forwarding
+    VertexScheduler::GetInstance()->JobStatus->SetProgress(nCompletedProgressSteps, gcnew System::String(pMessage));
+    return S_OK;
+}
+
+
+XCOMPUTEAPI_EXT
+XCERROR
+XCOMPUTEAPI
+XcIncrementProgress(
+    IN XCSESSIONHANDLE SessionHandle,
+    IN PCSTR pMessage
+)
+{ // SessionHandle is not used; pMessage is converted to a managed String before forwarding
+    VertexScheduler::GetInstance()->JobStatus->IncrementProgress(gcnew System::String(pMessage));
+    return S_OK;
+}
+
+
+XCOMPUTEAPI_EXT
+XCERROR
+XCOMPUTEAPI
+XcCompleteProgress(
+    IN XCSESSIONHANDLE SessionHandle,
+    IN PCSTR pMessage
+)
+{ // SessionHandle is not used; pMessage is converted to a managed String before forwarding
+    VertexScheduler::GetInstance()->JobStatus->CompleteProgress(gcnew System::String(pMessage));
+    return S_OK;
+}
diff --git a/xcompute_native/stdafx.cpp b/xcompute_native/stdafx.cpp
new file mode 100644
index 0000000..5b94ae8
--- /dev/null
+++ b/xcompute_native/stdafx.cpp
@@ -0,0 +1,28 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License
+at http://www.apache.org/licenses/LICENSE-2.0
+
+
+THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER
+EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF
+TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
+
+
+See the Apache Version 2.0 License for specific language governing permissions and
+limitations under the License.
+
+*/
+
+// stdafx.cpp : source file that includes just the standard includes
+// xcompute.pch will be the pre-compiled header
+// stdafx.obj will contain the pre-compiled type information
+
+#include "stdafx.h"
+
+// TODO: reference any additional headers you need in STDAFX.H
+// and not in this file
diff --git a/xcompute_native/stdafx.h b/xcompute_native/stdafx.h
new file mode 100644
index 0000000..88aefe1
--- /dev/null
+++ b/xcompute_native/stdafx.h
@@ -0,0 +1,47 @@
+/*
+Copyright (c) Microsoft Corporation
+
+All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License.
You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +// stdafx.h : include file for standard system include files, +// or project specific include files that are used frequently, but +// are changed infrequently +// + +#pragma once + +#include "targetver.h" + +#define WIN32_LEAN_AND_MEAN // Exclude rarely-used stuff from Windows headers +// Windows Header Files: +#include +#include +#include + + +#define _ATL_CSTRING_EXPLICIT_CONSTRUCTORS // some CString constructors will be explicit + +#include +#include +#include +#include +#include +#define XCOMPUTE_EXPORTS +#include "XComputeTypes.h" +#include "XCompute.h" +#include "xcimpl.h" diff --git a/xcompute_native/targetver.h b/xcompute_native/targetver.h new file mode 100644 index 0000000..76fc280 --- /dev/null +++ b/xcompute_native/targetver.h @@ -0,0 +1,44 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +#pragma once + +// The following macros define the minimum required platform. 
The minimum required platform +// is the earliest version of Windows, Internet Explorer etc. that has the necessary features to run +// your application. The macros work by enabling all features available on platform versions up to and +// including the version specified. + +// Modify the following defines if you have to target a platform prior to the ones specified below. +// Refer to MSDN for the latest info on corresponding values for different platforms. +#ifndef WINVER // Specifies that the minimum required platform is Windows Vista. +#define WINVER 0x0600 // Change this to the appropriate value to target other versions of Windows. +#endif + +#ifndef _WIN32_WINNT // Specifies that the minimum required platform is Windows Vista. +#define _WIN32_WINNT 0x0600 // Change this to the appropriate value to target other versions of Windows. +#endif + +#ifndef _WIN32_WINDOWS // Specifies that the minimum required platform is Windows 98. +#define _WIN32_WINDOWS 0x0410 // Change this to the appropriate value to target Windows Me or later. +#endif + +#ifndef _WIN32_IE // Specifies that the minimum required platform is Internet Explorer 7.0. +#define _WIN32_IE 0x0700 // Change this to the appropriate value to target other versions of IE. +#endif diff --git a/xcompute_native/xcimpl.h b/xcompute_native/xcimpl.h new file mode 100644 index 0000000..ab0d6bf --- /dev/null +++ b/xcompute_native/xcimpl.h @@ -0,0 +1,127 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. 
+ + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +/*++ +Module Name: + + xcimpl.h + +Abstract: + + This module contains the private prototypes and definitions for + the xcompute API implementation on top of the HPC scheduler. + +--*/ + +using namespace Microsoft::Research::Dryad; + + +// +// timing infrastructure +// +extern DWORD g_StartTime; +#define TS ((::GetTickCount() - g_StartTime)/1000.0) + + +// +// +// + +#define INVALID_PROCESS_ID ((XCPROCESSHANDLE)0) +#define CURRENT_PROCESS_ID ((XCPROCESSHANDLE)1) + +DWORD GetNextProcessId(); + +// +// Utility routines in async.cpp for managing the asynchronous completion +// notification functionality of xcompute +// + +// +// The copy of the async structure we keep around until the +// operation completes. +// +class ASYNC +{ +public: + ASYNC(PCXC_ASYNC_INFO pAsyncInfo); + virtual ~ASYNC(); + static ASYNC *Capture(PCXC_ASYNC_INFO pAsyncInfo); + HRESULT Complete(HRESULT hr); +private: + HRESULT *pOperationState; + HANDLE hEvent; + HANDLE hIOCP; + LPOVERLAPPED pOverlapped; + UINT_PTR CompletionKey; +}; +typedef ASYNC *PASYNC; + + +ref class AsyncWrapper +{ + PASYNC ap; +public: + AsyncWrapper(PASYNC ap):ap(ap) {} + + void StateChangeHandler(System::Object ^sender,Microsoft::Research::Dryad::XComputeProcessStateChangeEventArgs ^e) + { + if (e->TimedOut) + { + ap->Complete(HRESULT_FROM_WIN32(ERROR_TIMEOUT)); + } + else + { + ap->Complete(S_OK); + } + } + + void GetSetPropertyHandler(System::Object ^sender, Microsoft::Research::Dryad::XComputeProcessGetSetPropertyEventArgs ^e) + { + ap->Complete(S_OK); + } +}; + + +// +// N.B. This macro will RETURN HRESULT on failure, so be sure no additional +// error cleanup is required. 
+// +#define CAPTURE_ASYNC(_pasync_) \ + { \ + if (pAsyncInfo) { \ + _pasync_ = new ASYNC(pAsyncInfo); \ + if (_pasync_ == NULL) \ + { \ + return E_OUTOFMEMORY; \ + } \ + } else { \ + _pasync_ = NULL; \ + } \ + } + +#define COMPLETE_ASYNC(_pasync_, _hr_) ((_pasync_) ? _pasync_->Complete(_hr_) : _hr_) + +// +// Translate from managed ProcessState to native XCPROCESSSTATE +// +XCPROCESSSTATE TranslateProcessState( + IN ProcessState fromState +); diff --git a/xcompute_native/xcompute.cpp b/xcompute_native/xcompute.cpp new file mode 100644 index 0000000..f910241 --- /dev/null +++ b/xcompute_native/xcompute.cpp @@ -0,0 +1,111 @@ +/* +Copyright (c) Microsoft Corporation + +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License +at http://www.apache.org/licenses/LICENSE-2.0 + + +THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER +EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF +TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + + +See the Apache Version 2.0 License for specific language governing permissions and +limitations under the License. + +*/ + +/*++ +Module Name: + + xcompute.cpp + +Abstract: + + This module contains initialization, cleanup, and some + utility functions for xcompute on top of the HPC scheduler. + +--*/ + +#include "stdafx.h" + +/*++ + +XcInitialize API + +Description: + +Call this function at the start to initialize the various internal +data structures of the XCompute SDK library. + +Arguments: + + configFile + Name of the config file + + componentName + The name of the component + +Return Value: + + XCERROR_OK + Call succeeded. + NOTE: + S_FALSE will be returned if the initialize + has already been called. 
+ +--*/ + +DWORD g_StartTime; + +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcInitialize( + IN PCSTR configFile, + IN PCSTR componentName +) +{ + g_StartTime = ::GetTickCount(); + return S_OK; +} + + +/*++ + +XcFreeMemory API + +Description: + +Frees the memory allocated by the XCompute API. +All the memory returned as a result of call to +the XCompute API should use the XcFreeMemory to +deallocate the memory + +Arguments: + + prt + Pointer to the memory + +Return Value: + + XCERROR_OK + Memory was successfully deallocated + +--*/ +XCOMPUTEAPI_EXT +XCERROR +XCOMPUTEAPI +XcFreeMemory( + IN PCVOID ptr +) +{ + if (ptr) + { + LocalFree((HLOCAL)ptr); + } + return S_OK; +}