diff --git a/libraries/bxdf/gltf_pbr.mtlx b/libraries/bxdf/gltf_pbr.mtlx
index 8dde49ca9e..03b5617f0a 100644
--- a/libraries/bxdf/gltf_pbr.mtlx
+++ b/libraries/bxdf/gltf_pbr.mtlx
@@ -140,7 +140,6 @@
-
@@ -150,6 +149,11 @@
+
+
+
+
+
@@ -159,21 +163,6 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
@@ -185,20 +174,18 @@
-
-
-
-
-
-
-
-
-
+
+
+
-
+
+
+
+
+
@@ -206,11 +193,7 @@
-
-
-
-
+
@@ -219,18 +202,16 @@
-
-
+
-
+
-
-
-
+
+
diff --git a/source/MaterialXGenShader/GenOptions.h b/source/MaterialXGenShader/GenOptions.h
index 5b77c5ca35..4fb71df718 100644
--- a/source/MaterialXGenShader/GenOptions.h
+++ b/source/MaterialXGenShader/GenOptions.h
@@ -94,7 +94,8 @@ class MX_GENSHADER_API GenOptions
hwNormalizeUdimTexCoords(false),
hwWriteAlbedoTable(false),
hwWriteEnvPrefilter(false),
- hwImplicitBitangents(true)
+ hwImplicitBitangents(true),
+ optReplaceBsdfMixWithLinearCombination(false)
{
}
virtual ~GenOptions() { }
@@ -197,6 +198,10 @@ class MX_GENSHADER_API GenOptions
/// Calculate fallback bitangents from existing normals and tangents
/// inside the bitangent node.
bool hwImplicitBitangents;
+
+ /// Analyse the graph of ShaderNodes and replace any ND_mix_bsdf nodes
+ /// with a linear combination of their weighted inputs
+ bool optReplaceBsdfMixWithLinearCombination;
};
MATERIALX_NAMESPACE_END
diff --git a/source/MaterialXGenShader/ShaderGraph.cpp b/source/MaterialXGenShader/ShaderGraph.cpp
index 5658de328c..d9c8b3d71e 100644
--- a/source/MaterialXGenShader/ShaderGraph.cpp
+++ b/source/MaterialXGenShader/ShaderGraph.cpp
@@ -629,7 +629,7 @@ ShaderGraphPtr ShaderGraph::create(const ShaderGraph* parent, const string& name
}
// Apply color and unit transforms to each input.
- graph->applyInputTransforms(node, newNode, context);
+ graph->applyInputTransforms(node, newNode.get(), context);
// Set root for upstream dependency traversal
root = node;
@@ -651,7 +651,7 @@ ShaderGraphPtr ShaderGraph::create(const ShaderGraph* parent, const string& name
return graph;
}
-void ShaderGraph::applyInputTransforms(ConstNodePtr node, ShaderNodePtr shaderNode, GenContext& context)
+void ShaderGraph::applyInputTransforms(ConstNodePtr node, ShaderNode* shaderNode, GenContext& context)
{
ColorManagementSystemPtr colorManagementSystem = context.getShaderGenerator().getColorManagementSystem();
UnitSystemPtr unitSystem = context.getShaderGenerator().getUnitSystem();
@@ -701,20 +701,29 @@ void ShaderGraph::applyInputTransforms(ConstNodePtr node, ShaderNodePtr shaderNo
}
}
-ShaderNode* ShaderGraph::createNode(ConstNodePtr node, GenContext& context)
+ShaderNode* ShaderGraph::createNode(const string& name, ConstNodeDefPtr nodeDef, GenContext& context)
{
- NodeDefPtr nodeDef = node->getNodeDef();
if (!nodeDef)
{
- throw ExceptionShaderGenError("Could not find a nodedef for node '" + node->getName() + "'");
+ throw ExceptionShaderGenError("Could not find a nodedef for node '" + name + "'");
}
+ // Create this node in the graph.
+ ShaderNodePtr newNode = ShaderNode::create(this, name, *nodeDef, context);
+ _nodeMap[name] = newNode;
+ _nodeOrder.push_back(newNode.get());
+
+ return newNode.get();
+}
+
+ShaderNode* ShaderGraph::createNode(ConstNodePtr node, GenContext& context)
+{
+ ConstNodeDefPtr nodeDef = node->getNodeDef();
+
// Create this node in the graph.
context.pushParentNode(node);
- ShaderNodePtr newNode = ShaderNode::create(this, node->getName(), *nodeDef, context);
+ ShaderNode* newNode = createNode(node->getName(), nodeDef, context);
newNode->initialize(*node, *nodeDef, context);
- _nodeMap[node->getName()] = newNode;
- _nodeOrder.push_back(newNode.get());
context.popParentNode();
// Check if any of the node inputs should be connected to the graph interface
@@ -757,7 +766,7 @@ ShaderNode* ShaderGraph::createNode(ConstNodePtr node, GenContext& context)
// Apply color and unit transforms to each input.
applyInputTransforms(node, newNode, context);
- return newNode.get();
+ return newNode;
}
ShaderGraphInputSocket* ShaderGraph::addInputSocket(const string& name, TypeDesc type)
@@ -787,6 +796,22 @@ ShaderNode* ShaderGraph::getNode(const string& name)
return it != _nodeMap.end() ? it->second.get() : nullptr;
}
+bool ShaderGraph::removeNode(ShaderNode* node)
+{
+ auto mapIt = _nodeMap.find(node->getName());
+ auto vecIt = std::find(_nodeOrder.begin(), _nodeOrder.end(), node);
+ if (mapIt == _nodeMap.end() || vecIt == _nodeOrder.end())
+ return false;
+
+ for (ShaderInput* input : node->getInputs())
+ {
+ input->breakConnection();
+ }
+ _nodeMap.erase(mapIt);
+ _nodeOrder.erase(vecIt);
+ return true;
+}
+
const ShaderNode* ShaderGraph::getNode(const string& name) const
{
return const_cast(this)->getNode(name);
@@ -835,8 +860,8 @@ void ShaderGraph::finalize(GenContext& context)
_inputUnitTransformMap.clear();
_outputUnitTransformMap.clear();
- // Optimize the graph, removing redundant paths.
- optimize();
+ // Optimize the graph, removing redundant paths and applying graph transformations
+ optimize(context);
// Sort the nodes in topological order.
topologicalSort();
@@ -896,7 +921,7 @@ void ShaderGraph::disconnect(ShaderNode* node) const
}
}
-void ShaderGraph::optimize()
+void ShaderGraph::optimize(GenContext& context)
{
size_t numEdits = 0;
for (ShaderNode* node : getNodes())
@@ -960,8 +985,382 @@ void ShaderGraph::optimize()
_nodeOrder = usedNodesVec;
}
+
+ // we take a copy of the node list because we might modify it during the optimization
+ if (context.getOptions().optReplaceBsdfMixWithLinearCombination)
+ {
+ const vector nodeList = getNodes();
+ for (ShaderNode* node : nodeList)
+ {
+ // first check the node is still in the graph, and hasn't been removed by a
+ // prior optimization
+ if (!getNode(node->getName()))
+ continue;
+
+ if (node->hasClassification(ShaderNode::Classification::MIX_BSDF))
+ {
+ if (!optimizeMixMixBsdf(node, context))
+ {
+ optimizeMixBsdf(node, context);
+ }
+ }
+ }
+ }
+}
+
+// Optimize the MixBsdf node by replacing it with a new node graph.
+//
+// The current nodegraph
+//
+// ┌────────┐┌────────┐┌───────┐
+// │A.weight││B.weight││Mix.mix│
+// └┬───────┘└┬───────┘└┬──────┘
+// ┌▽───────┐┌▽───────┐ │
+// │A (BSDF)││B (BSDF)│ │
+// └┬───────┘└┬───────┘ │
+// ┌▽─────────▽─────────▽┐
+// │Mix │
+// └─────────────────────┘
+//
+// New nodegraph
+// ┌────────┐
+// │Mix.mix │
+// └┬──────┬┘
+// ┌▽─────┐│┌────────┐┌────────┐
+// │Invert│││A.weight││B.weight│
+// └─────┬┘│└┬───────┘└──┬─────┘
+// └─│─│──┐ │
+// ┌───────▽─▽┐┌▽────────▽┐
+// │Multiply A││Multiply B│
+// └┬─────────┘└┬─────────┘
+// ┌▽───────┐┌──▽─────┐
+// │A (BSDF)││B (BSDF)│
+// └┬───────┘└┬───────┘
+// ┌▽─────────▽┐
+// │Add │
+// └───────────┘
+//
+// Motivation - the new graph is more efficient for shader backends to optimize
+// away possible expensive BSDF nodes that might not be used at all.
+bool ShaderGraph::optimizeMixBsdf(ShaderNode* mixNode, GenContext& context)
+{
+ // criteria for optimization...
+ // * upstream nodes for MixBsdf node both have `weight` input ports
+ // * We have the following node definitions available in the library
+ // * ND_add_bsdf
+ // * ND_invert_float
+ // * ND_multiply_float
+
+ auto mixFgInput = mixNode->getInput("fg");
+ auto mixBgInput = mixNode->getInput("bg");
+
+ // the standard data library ND_mix_bsdf should always have "fg" and "bg" inputs
+ // but we check here anyway to ensure we're not using a custom data library that doesn't follow that convention
+ if (!mixFgInput || !mixBgInput)
+ return false;
+
+ auto fgNode = mixFgInput->getConnectedSibling();
+ auto bgNode = mixBgInput->getConnectedSibling();
+
+ // We check to see we have two upstream nodes - there are almost certainly other optimizations possible
+ // if this isn't true, but we will leave those for a later PR.
+ if (!fgNode && !bgNode)
+ return false;
+
+ auto fgNodeWeightInput = fgNode->getInput("weight");
+ auto bgNodeWeightInput = bgNode->getInput("weight");
+
+ // We require both upstream nodes to have a "weight" input for this optimization to work.
+ if (!fgNodeWeightInput || !bgNodeWeightInput)
+ return false;
+
+ // we also require the following list of node definitions
+ auto addBsdfNodeDef = _document->getNodeDef("ND_add_bsdf");
+ auto floatInvertNodeDef = _document->getNodeDef("ND_invert_float");
+ auto floatMultNodeDef = _document->getNodeDef("ND_multiply_float");
+ if (!addBsdfNodeDef || !floatInvertNodeDef || !floatMultNodeDef)
+ return false;
+
+ // We meet the requirements for the optimization.
+ // We can now create the new nodes and connect them up.
+
+ // Helper function to redirect the incoming connection to from one input port
+ // to another.
+ // We intentionally skip error checking here, as we're doing it below.
+ // If this proves useful we should make it a method somewhere, and add
+ // more robust error checking.
+ auto redirectInput = [](ShaderInput* fromPort, ShaderInput* toPort) -> void
+ {
+ auto connection = fromPort->getConnection();
+ if (connection)
+ {
+ // we have a connection - so transfer it
+ toPort->makeConnection(connection);
+ }
+ else
+ {
+ // we just remap the value.
+ toPort->setValue(fromPort->getValue());
+ }
+ };
+
+ // Helper function to connect two nodes together, consolidating the valiation of the ports existance.
+ // If this proves useful we should make it a method somewhere.
+ auto connectNodes = [](ShaderNode* fromNode, const string& fromPortName, ShaderNode* toNode, const string& toPortName) -> void
+ {
+ auto fromPort = fromNode->getOutput(fromPortName);
+ auto toPort = toNode->getInput(toPortName);
+ if (!fromPort || !toPort)
+ return;
+
+ fromPort->makeConnection(toPort);
+ };
+
+ auto mixWeightInput = mixNode->getInput("mix");
+
+ // create a node that represents the inverted mix value, ie. 1.0-mix
+ // to be used for the "bg" side of the mix
+ auto invertMixNode = this->createNode(mixNode->getName()+"_INV", floatInvertNodeDef, context);
+ redirectInput(mixWeightInput, invertMixNode->getInput("in"));
+
+ // create a multiply node to calculate the new weight value, weighted by the mix value.
+ auto multFgWeightNode = this->createNode(mixNode->getName()+"_MULT_FG", floatMultNodeDef, context);
+ redirectInput(fgNodeWeightInput, multFgWeightNode->getInput("in1"));
+ redirectInput(mixWeightInput, multFgWeightNode->getInput("in2"));
+
+ // create a multiply node to calculate the new weight value, weighted by the inverted mix value.
+ auto multBgWeightNode = this->createNode(mixNode->getName()+"_MULT_BG", floatMultNodeDef, context);
+ redirectInput(bgNodeWeightInput, multBgWeightNode->getInput("in1"));
+ connectNodes(invertMixNode, "out", multBgWeightNode, "in2");
+
+ // connect the two newly created weights to the fg and bg BSDF nodes.
+ connectNodes(multFgWeightNode, "out", fgNode, "weight");
+ connectNodes(multBgWeightNode, "out", bgNode, "weight");
+
+ // Create the ND_add_bsdf node that will add the two BSDF nodes with the modified weights
+ // this replaces the original mix node.
+ auto addNode = this->createNode(mixNode->getName()+"_ADD", addBsdfNodeDef, context);
+ connectNodes(bgNode, "out", addNode, "in1");
+ connectNodes(fgNode, "out", addNode, "in2");
+
+ // Finally for all the previous outgoing connections from the original mix node
+ // replace those with the outgoing connection from the new add node.
+ auto mixNodeOutput = mixNode->getOutput("out");
+ auto mixNodeOutputConns = mixNodeOutput->getConnections();
+ for (auto conn : mixNodeOutputConns)
+ {
+ addNode->getOutput("out")->makeConnection(conn);
+ }
+
+ removeNode(mixNode);
+
+ return true;
+}
+
+
+// Optimize the combination of two MixBsdf nodes by replacing it with a new node graph.
+//
+// The current nodegraph
+// ┌────────┐┌────────┐┌────────┐
+// │A.weight││B.weight││Mix1.mix│
+// └┬───────┘└┬───────┘└┬───────┘
+// ┌▽───────┐┌▽───────┐ │
+// │A (BSDF)││B (BSDF)│ │
+// └┬───────┘└┬───────┘ │
+// ┌▽─────────▽─────────▽┐┌────────┐┌────────┐
+// │Mix_1 ││C.weight││Mix2.mix│
+// └┬────────────────────┘└┬───────┘└┬───────┘
+// | ┌▽───────┐ |
+// | │C (BSDF)│ |
+// | └┬───────┘ |
+// ┌▽──────────────────────▽─────────▽───────┐
+// |Mix_2 |
+// └─────────────────────────────────────────┘
+//
+// New nodegraph //TODO
+// ┌────────┐
+// │Mix.mix │
+// └┬──────┬┘
+// ┌▽─────┐│┌────────┐┌────────┐
+// │Invert│││A.weight││B.weight│
+// └─────┬┘│└┬───────┘└──┬─────┘
+// └─│─│──┐ │
+// ┌───────▽─▽┐┌▽────────▽┐
+// │Multiply A││Multiply B│
+// └┬─────────┘└┬─────────┘
+// ┌▽───────┐┌──▽─────┐
+// │A (BSDF)││B (BSDF)│
+// └┬───────┘└┬───────┘
+// ┌▽─────────▽┐
+// │Add │
+// └───────────┘
+//
+// Motivation - the new graph is more efficient for shader backends to optimize
+// away possible expensive BSDF nodes that might not be used at all.
+bool ShaderGraph::optimizeMixMixBsdf(ShaderNode* mixNode_x, GenContext& context)
+{
+ // criteria for optimization...
+ // * upstream nodes for MixBsdf node both have `weight` input ports
+ // * We have the following node definitions available in the library
+ // * ND_add_bsdf
+ // * ND_invert_float
+ // * ND_multiply_float
+
+ ShaderNode* mix2Node = mixNode_x;
+
+ auto mix2FgInput = mix2Node->getInput("fg");
+ auto mix2BgInput = mix2Node->getInput("bg");
+
+ // the standard data library ND_mix_bsdf should always have "fg" and "bg" inputs
+ // but we check here anyway to ensure we're not using a custom data library that doesn't follow that convention
+ if (!mix2FgInput || !mix2BgInput)
+ return false;
+
+ auto mix2FgNode = mix2FgInput->getConnectedSibling();
+ auto mix2BgNode = mix2BgInput->getConnectedSibling();
+
+ // We check to see we have two upstream nodes - there are almost certainly other optimizations possible
+ // if this isn't true, but we will leave those for a later PR.
+ if (!mix2FgNode && !mix2BgNode)
+ return false;
+
+ // we require the node connected to the "fg" input to also be a mixBsdf node
+ if (!mix2FgNode->hasClassification(ShaderNode::Classification::MIX_BSDF))
+ return false;
+
+ ShaderNode* mix1Node = mix2FgNode;
+
+
+ auto mix1FgInput = mix1Node->getInput("fg");
+ auto mix1BgInput = mix1Node->getInput("bg");
+
+ // the standard data library ND_mix_bsdf should always have "fg" and "bg" inputs
+ // but we check here anyway to ensure we're not using a custom data library that doesn't follow that convention
+ if (!mix1FgInput || !mix1BgInput)
+ return false;
+
+ auto mix1FgNode = mix1FgInput->getConnectedSibling();
+ auto mix1BgNode = mix1BgInput->getConnectedSibling();
+
+ // We check to see we have two upstream nodes - there are almost certainly other optimizations possible
+ // if this isn't true, but we will leave those for a later PR.
+ if (!mix1FgNode && !mix1BgNode)
+ return false;
+
+ auto mix1FgNodeWeightInput = mix1FgNode->getInput("weight");
+ auto mix1BgNodeWeightInput = mix1BgNode->getInput("weight");
+ auto mix2BgNodeWeightInput = mix2BgNode->getInput("weight");
+
+ // We require both upstream nodes to have a "weight" input for this optimization to work.
+ if (!mix1FgNodeWeightInput || !mix1BgNodeWeightInput || !mix2BgNodeWeightInput)
+ return false;
+
+
+ // we also require the following list of node definitions
+ auto addBsdfNodeDef = _document->getNodeDef("ND_add_bsdf");
+ auto floatInvertNodeDef = _document->getNodeDef("ND_invert_float");
+ auto floatMultNodeDef = _document->getNodeDef("ND_multiply_float");
+ if (!addBsdfNodeDef || !floatInvertNodeDef || !floatMultNodeDef)
+ return false;
+
+ // We meet the requirements for the optimization.
+ // We can now create the new nodes and connect them up.
+
+ // Helper function to redirect the incoming connection to from one input port
+ // to another.
+ // We intentionally skip error checking here, as we're doing it below.
+ // If this proves useful we should make it a method somewhere, and add
+ // more robust error checking.
+ auto redirectInput = [](ShaderInput* fromPort, ShaderInput* toPort) -> void
+ {
+ auto connection = fromPort->getConnection();
+ if (connection)
+ {
+ // we have a connection - so transfer it
+ toPort->makeConnection(connection);
+ }
+ else
+ {
+ // we just remap the value.
+ toPort->setValue(fromPort->getValue());
+ }
+ };
+
+ // Helper function to connect two nodes together, consolidating the valiation of the ports existance.
+ // If this proves useful we should make it a method somewhere.
+ auto connectNodes = [](ShaderNode* fromNode, const string& fromPortName, ShaderNode* toNode, const string& toPortName) -> void
+ {
+ auto fromPort = fromNode->getOutput(fromPortName);
+ auto toPort = toNode->getInput(toPortName);
+ if (!fromPort || !toPort)
+ return;
+
+ fromPort->makeConnection(toPort);
+ };
+
+ auto mix1WeightInput = mix1Node->getInput("mix");
+ auto mix2WeightInput = mix2Node->getInput("mix");
+
+ // create nodes that represents the inverted mix values, ie. 1.0-mix
+ // to be used for the "bg" side of the mix
+ auto invertMix1Node = this->createNode(mix1Node->getName()+"_INV", floatInvertNodeDef, context);
+ redirectInput(mix1WeightInput, invertMix1Node->getInput("in"));
+ auto invertMix2Node = this->createNode(mix2Node->getName()+"_INV", floatInvertNodeDef, context);
+ redirectInput(mix2WeightInput, invertMix2Node->getInput("in"));
+
+
+ // create a multiply node to calculate the new weight value, weighted by the mix value.
+ auto multFg1WeightNode_intermediate = this->createNode(mix1Node->getName()+"_MULT_FG1_INTERMEDIATE", floatMultNodeDef, context);
+ redirectInput(mix1FgNodeWeightInput, multFg1WeightNode_intermediate->getInput("in1"));
+ redirectInput(mix1WeightInput, multFg1WeightNode_intermediate->getInput("in2"));
+ auto multFg1WeightNode = this->createNode(mix1Node->getName()+"_MULT_FG1", floatMultNodeDef, context);
+ redirectInput(mix2WeightInput, multFg1WeightNode->getInput("in1"));
+ connectNodes(multFg1WeightNode_intermediate, "out", multFg1WeightNode, "in2");
+
+ auto multBg1WeightNode_intermediate = this->createNode(mix1Node->getName()+"_MULT_BG1_INTERMEDIATE", floatMultNodeDef, context);
+ redirectInput(mix1BgNodeWeightInput, multBg1WeightNode_intermediate->getInput("in1"));
+ connectNodes(invertMix1Node, "out", multBg1WeightNode_intermediate, "in2");
+ auto multBg1WeightNode = this->createNode(mix1Node->getName()+"_MULT_BG1", floatMultNodeDef, context);
+ redirectInput(mix2WeightInput, multBg1WeightNode->getInput("in1"));
+ connectNodes(multBg1WeightNode_intermediate, "out", multBg1WeightNode, "in2");
+
+ auto multBg2WeightNode = this->createNode(mix2Node->getName()+"_MULT_BG2", floatMultNodeDef, context);
+ redirectInput(mix2BgNodeWeightInput, multBg2WeightNode->getInput("in1"));
+ connectNodes(invertMix2Node, "out", multBg2WeightNode, "in2");
+
+
+ // connect the two newly created weights to the fg and bg BSDF nodes.
+ connectNodes(multFg1WeightNode, "out", mix1FgNode, "weight");
+ connectNodes(multBg1WeightNode, "out", mix1BgNode, "weight");
+ connectNodes(multBg2WeightNode, "out", mix2BgNode, "weight");
+
+
+ // Create the ND_add_bsdf nodes that will add the three BSDF nodes with the modified weights
+ // this replaces the original mix nodes.
+ auto addNode_intermediate = this->createNode(mix2Node->getName()+"_ADD_INTERMEDIATE", addBsdfNodeDef, context);
+ connectNodes(mix1BgNode, "out", addNode_intermediate, "in1");
+ connectNodes(mix1FgNode, "out", addNode_intermediate, "in2");
+
+ auto addNode = this->createNode(mix2Node->getName()+"_ADD", addBsdfNodeDef, context);
+ connectNodes(addNode_intermediate, "out", addNode, "in1");
+ connectNodes(mix2BgNode, "out", addNode, "in2");
+
+ // Finally for all the previous outgoing connections from the original mix node
+ // replace those with the outgoing connection from the new add node.
+ auto mixNodeOutput = mix2Node->getOutput("out");
+ auto mixNodeOutputConns = mixNodeOutput->getConnections();
+ for (auto conn : mixNodeOutputConns)
+ {
+ addNode->getOutput("out")->makeConnection(conn);
+ }
+
+ removeNode(mix1Node);
+ removeNode(mix2Node);
+
+ return true;
}
+
void ShaderGraph::bypass(ShaderNode* node, size_t inputIndex, size_t outputIndex)
{
ShaderInput* input = node->getInput(inputIndex);
diff --git a/source/MaterialXGenShader/ShaderGraph.h b/source/MaterialXGenShader/ShaderGraph.h
index dda19de78a..0fa1f72654 100644
--- a/source/MaterialXGenShader/ShaderGraph.h
+++ b/source/MaterialXGenShader/ShaderGraph.h
@@ -93,7 +93,7 @@ class MX_GENSHADER_API ShaderGraph : public ShaderNode
const vector& getOutputSockets() const { return _inputOrder; }
/// Apply color and unit transforms to each input of a node.
- void applyInputTransforms(ConstNodePtr node, ShaderNodePtr shaderNode, GenContext& context);
+ void applyInputTransforms(ConstNodePtr node, ShaderNode* shaderNode, GenContext& context);
/// Create a new node in the graph
ShaderNode* createNode(ConstNodePtr node, GenContext& context);
@@ -129,9 +129,16 @@ class MX_GENSHADER_API ShaderGraph : public ShaderNode
ElementPtr connectingElement,
GenContext& context);
+ /// Create a new node in a graph from a node definition.
+ /// Note - this does not initialize the node instance with any concrete values, but
+ /// instead creates an empty instance of the provided node definition
+ ShaderNode* createNode(const string& name, ConstNodeDefPtr nodeDef, GenContext& context);
+
/// Add a node to the graph
void addNode(ShaderNodePtr node);
+ bool removeNode(ShaderNode* node);
+
/// Add input sockets from an interface element (nodedef, nodegraph or node)
void addInputSockets(const InterfaceElement& elem, GenContext& context);
@@ -159,7 +166,10 @@ class MX_GENSHADER_API ShaderGraph : public ShaderNode
void finalize(GenContext& context);
/// Optimize the graph, removing redundant paths.
- void optimize();
+ void optimize(GenContext& context);
+
+ bool optimizeMixBsdf(ShaderNode* node, GenContext& context);
+ bool optimizeMixMixBsdf(ShaderNode* node, GenContext& context);
/// Bypass a node for a particular input and output,
/// effectively connecting the input's upstream connection
diff --git a/source/MaterialXGenShader/ShaderNode.cpp b/source/MaterialXGenShader/ShaderNode.cpp
index 73f56b17e6..505d2772b2 100644
--- a/source/MaterialXGenShader/ShaderNode.cpp
+++ b/source/MaterialXGenShader/ShaderNode.cpp
@@ -306,6 +306,11 @@ ShaderNodePtr ShaderNode::create(const ShaderGraph* parent, const string& name,
newNode->_classification = Classification::TEXTURE | Classification::FILETEXTURE;
}
+ if (nodeDef.getName() == "ND_mix_bsdf")
+ {
+ newNode->_classification |= Classification::MIX_BSDF;
+ }
+
// Add in classification based on group name
if (groupName == TEXTURE2D_GROUPNAME || groupName == PROCEDURAL2D_GROUPNAME)
{
diff --git a/source/MaterialXGenShader/ShaderNode.h b/source/MaterialXGenShader/ShaderNode.h
index 527c588fc7..c86dcb3b2f 100644
--- a/source/MaterialXGenShader/ShaderNode.h
+++ b/source/MaterialXGenShader/ShaderNode.h
@@ -351,6 +351,8 @@ class MX_GENSHADER_API ShaderNode
static const uint32_t SAMPLE3D = 1 << 18; /// Can be sampled in 3D (position)
static const uint32_t GEOMETRIC = 1 << 19; /// Geometric input
static const uint32_t DOT = 1 << 20; /// A dot node
+ // Types based on node definition
+ static const uint32_t MIX_BSDF = 1 << 21; /// A BSDF mix node
};
static const ShaderNodePtr NONE;
diff --git a/source/MaterialXGraphEditor/RenderView.cpp b/source/MaterialXGraphEditor/RenderView.cpp
index cf44c4a925..48c5d41217 100644
--- a/source/MaterialXGraphEditor/RenderView.cpp
+++ b/source/MaterialXGraphEditor/RenderView.cpp
@@ -162,6 +162,7 @@ RenderView::RenderView(mx::DocumentPtr doc,
// Make sure all uniforms are added so value updates can
// find the corresponding uniform.
_genContext.getOptions().shaderInterfaceType = mx::SHADER_INTERFACE_COMPLETE;
+ _genContext.getOptions().optReplaceBsdfMixWithLinearCombination = true;
setDocument(doc);
_stdLib = stdLib;
diff --git a/source/MaterialXView/Viewer.cpp b/source/MaterialXView/Viewer.cpp
index 1fc7afc399..960a4058c0 100644
--- a/source/MaterialXView/Viewer.cpp
+++ b/source/MaterialXView/Viewer.cpp
@@ -255,6 +255,7 @@ Viewer::Viewer(const std::string& materialFilename,
_genContext.getOptions().fileTextureVerticalFlip = true;
_genContext.getOptions().hwShadowMap = true;
_genContext.getOptions().hwImplicitBitangents = false;
+ _genContext.getOptions().optReplaceBsdfMixWithLinearCombination = true;
#ifdef MATERIALXVIEW_METAL_BACKEND
_renderPipeline = MetalRenderPipeline::create(this);
@@ -267,6 +268,7 @@ Viewer::Viewer(const std::string& materialFilename,
_genContextEssl.getOptions().targetColorSpaceOverride = "lin_rec709";
_genContextEssl.getOptions().fileTextureVerticalFlip = false;
_genContextEssl.getOptions().hwMaxActiveLightSources = 1;
+ _genContextEssl.getOptions().optReplaceBsdfMixWithLinearCombination = true;
#endif
#if MATERIALX_BUILD_GEN_OSL
// Set OSL generator options.
@@ -846,6 +848,13 @@ void Viewer::createAdvancedSettings(ng::ref parent)
setShaderInterfaceType(interfaceType);
});
+ ng::ref optimizeBsdfMixBox = new ng::CheckBox(settingsGroup, "Optimize BSDF Mix");
+ optimizeBsdfMixBox->set_checked(_genContext.getOptions().optReplaceBsdfMixWithLinearCombination);
+ optimizeBsdfMixBox->set_callback([this](bool enable)
+ {
+ _genContext.getOptions().optReplaceBsdfMixWithLinearCombination = enable;
+ });
+
ng::ref albedoGroup = new Widget(settingsGroup);
albedoGroup->set_layout(new ng::BoxLayout(ng::Orientation::Horizontal));
new ng::Label(albedoGroup, "Albedo Method:");