Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 20 additions & 40 deletions source/MaterialXGenOsl/OslShaderGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,25 +111,9 @@ ShaderPtr OslShaderGenerator::generate(const string& name, ElementPtr element, G
emitShaderInputs(stage.getInputBlock(OSL::INPUTS), stage);
emitShaderInputs(stage.getUniformBlock(OSL::UNIFORMS), stage);

// Emit shader output
// Emit shader outputs
const VariableBlock& outputs = stage.getOutputBlock(OSL::OUTPUTS);
const ShaderPort* singleOutput = outputs.size() == 1 ? outputs[0] : NULL;

const bool isSurfaceShaderOutput = singleOutput && singleOutput->getType() == Type::SURFACESHADER;

if (isSurfaceShaderOutput)
{
// Special case for having 'surfaceshader' as final output type.
// This type is a struct internally (BSDF, EDF, opacity) so we must
// declare this as a single closure color type in order for renderers
// to understand this output.
emitLine("output closure color " + singleOutput->getVariable() + " = 0", stage, false);
}
else
{
// Just emit all outputs the way they are declared.
emitShaderOutputs(outputs, stage);
}
emitShaderOutputs(outputs, stage);

// End shader signature
emitScopeEnd(stage);
Expand Down Expand Up @@ -178,29 +162,11 @@ ShaderPtr OslShaderGenerator::generate(const string& name, ElementPtr element, G
}
}

// Emit final outputs
if (isSurfaceShaderOutput)
{
// Special case for having 'surfaceshader' as final output type.
// This type is a struct internally (BSDF, EDF, opacity) so we must
// convert this to a single closure color type in order for renderers
// to understand this output.
const ShaderGraphOutputSocket* socket = graph.getOutputSocket(0);
const string result = getUpstreamResult(socket, context);
emitScopeBegin(stage);
emitLine("float opacity_weight = clamp(" + result + ".opacity, 0.0, 1.0)", stage);
emitLine(singleOutput->getVariable() + " = (" + result + ".bsdf + " + result + ".edf) * opacity_weight + transparent() * (1.0 - opacity_weight)", stage);
emitScopeEnd(stage);
}
else
// Assign results to final outputs.
for (size_t i = 0; i < outputs.size(); ++i)
{
// Assign results to final outputs.
for (size_t i = 0; i < outputs.size(); ++i)
{
const ShaderGraphOutputSocket* outputSocket = graph.getOutputSocket(i);
const string result = getUpstreamResult(outputSocket, context);
emitLine(outputSocket->getVariable() + " = " + result, stage);
}
const ShaderGraphOutputSocket* outputSocket = graph.getOutputSocket(i);
emitLine(outputSocket->getVariable() + " = " + getUpstreamResult(outputSocket, context), stage);
}

// End shader body
Expand Down Expand Up @@ -249,6 +215,20 @@ ShaderPtr OslShaderGenerator::createShader(const string& name, ElementPtr elemen
{
// Create the root shader graph
ShaderGraphPtr graph = ShaderGraph::create(nullptr, name, element, context);

// Special handling for surfaceshader type output - if we have a material
// that outputs a single surfaceshader then we will implicitly add a surfacematerial
// node to create the final closure color - the surfaceshader type is a struct and needs
// flattening to a single closure in the surfacematerial node.
const auto& outputSockets = graph->getOutputSockets();
const auto* singleOutput = outputSockets.size() == 1 ? outputSockets[0] : NULL;

const bool isSurfaceShaderOutput = singleOutput && singleOutput->getType() == Type::SURFACESHADER;
if (isSurfaceShaderOutput)
{
graph->inlineNodeBeforeOutput(outputSockets[0], "_surfacematerial_", "ND_surfacematerial", "surfaceshader", "out", context);
}

ShaderPtr shader = std::make_shared<Shader>(name, graph);

// Create our stage.
Expand Down
98 changes: 89 additions & 9 deletions source/MaterialXGenShader/ShaderGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ ShaderGraphPtr ShaderGraph::create(const ShaderGraph* parent, const string& name
}

// Apply color and unit transforms to each input.
graph->applyInputTransforms(node, newNode, context);
graph->applyInputTransforms(node, newNode.get(), context);

// Set root for upstream dependency traversal
root = node;
Expand All @@ -651,7 +651,7 @@ ShaderGraphPtr ShaderGraph::create(const ShaderGraph* parent, const string& name
return graph;
}

void ShaderGraph::applyInputTransforms(ConstNodePtr node, ShaderNodePtr shaderNode, GenContext& context)
void ShaderGraph::applyInputTransforms(ConstNodePtr node, ShaderNode* shaderNode, GenContext& context)
{
ColorManagementSystemPtr colorManagementSystem = context.getShaderGenerator().getColorManagementSystem();
UnitSystemPtr unitSystem = context.getShaderGenerator().getUnitSystem();
Expand Down Expand Up @@ -701,20 +701,29 @@ void ShaderGraph::applyInputTransforms(ConstNodePtr node, ShaderNodePtr shaderNo
}
}

ShaderNode* ShaderGraph::createNode(ConstNodePtr node, GenContext& context)
ShaderNode* ShaderGraph::createNode(const string& name, ConstNodeDefPtr nodeDef, GenContext& context)
{
NodeDefPtr nodeDef = node->getNodeDef();
if (!nodeDef)
{
throw ExceptionShaderGenError("Could not find a nodedef for node '" + node->getName() + "'");
throw ExceptionShaderGenError("Could not find a nodedef for node '" + name + "'");
}

// Create this node in the graph.
ShaderNodePtr newNode = ShaderNode::create(this, name, *nodeDef, context);
_nodeMap[name] = newNode;
_nodeOrder.push_back(newNode.get());

return newNode.get();
}

ShaderNode* ShaderGraph::createNode(ConstNodePtr node, GenContext& context)
{
ConstNodeDefPtr nodeDef = node->getNodeDef();

// Create this node in the graph.
context.pushParentNode(node);
ShaderNodePtr newNode = ShaderNode::create(this, node->getName(), *nodeDef, context);
ShaderNode* newNode = createNode(node->getName(), nodeDef, context);
newNode->initialize(*node, *nodeDef, context);
_nodeMap[node->getName()] = newNode;
_nodeOrder.push_back(newNode.get());
context.popParentNode();

// Check if any of the node inputs should be connected to the graph interface
Expand Down Expand Up @@ -757,7 +766,78 @@ ShaderNode* ShaderGraph::createNode(ConstNodePtr node, GenContext& context)
// Apply color and unit transforms to each input.
applyInputTransforms(node, newNode, context);

return newNode.get();
return newNode;
}

// Insert a new node between the output of the graph and its upstream connection, reconnecting the upstream to the specified input on
// the new node, if present.
ShaderNode* ShaderGraph::inlineNodeBeforeOutput(ShaderGraphOutputSocket* output, const std::string& newNodeName, const std::string& nodeDefName, const std::string& inputName, const std::string& outputName, GenContext& context)
{
auto nodeDef = _document->getNodeDef(nodeDefName);
if (!nodeDef)
{
throw ExceptionShaderGenError("Cannot find NodeDef '"+nodeDefName+"' when inserting node '"+newNodeName+"'");
}

// Check to see if the nodedef has the specified input/output ports
OutputPtr nodeDefOutput = nodeDef->getOutput(outputName);
if (!nodeDefOutput)
{
throw ExceptionShaderGenError("Output '"+outputName+"' not found on NodeDef '"+nodeDefName+"'");
}

InputPtr nodeDefInput = nullptr;
if (!inputName.empty())
{
// Only look for the input if we are given an inputName. It's valid to insert the node
// without any upstream connection, and so an input name is not required.
nodeDefInput = nodeDef->getInput(inputName);
if (!nodeDefInput)
{
throw ExceptionShaderGenError("Input '"+inputName+"' not found on NodeDef '"+nodeDefName+"'");
}
}

// record the previously connected upstream
auto originalUpstream = output->getConnection();

if (nodeDefInput && originalUpstream)
{
// if we're going to attempt to connect these - we need to check the types match
// we do this before creating any new data
if (nodeDefInput->getType() != originalUpstream->getType().getName())
{
throw ExceptionShaderGenError("Cannot connect ports of mismatched types '"+nodeDefInput->getType()+"' and '"+originalUpstream->getType().getName()+"' when inserting node");
}
}

// create the new node, and connect its output to the provided graph output
auto newNode = createNode(newNodeName, nodeDef, context);
if (!newNode)
{
throw ExceptionShaderGenError("Error while creating node '"+newNodeName+"' of type '"+nodeDefName+"'");
}
auto newNodeOutput = newNode->getOutput(outputName);
newNodeOutput->setVariable(newNodeOutput->getFullName());
output->makeConnection(newNodeOutput);

// update the type of the graph output port to match the new node output
output->setType(newNodeOutput->getType());

// if there was an original upstream node connected to graph output
// and we found the named input port - which means we were given a inputName (it was not empty string)
// connect this to the new node at the provided input name.
if (originalUpstream && nodeDefInput)
{
// update the variable name for the input and connect the original upstream
// we already validated the types match above.
auto newNodeInput = newNode->getInput(inputName);
newNodeInput->setVariable(newNodeInput->getFullName());

originalUpstream->makeConnection(newNodeInput);
}

return newNode;
}

ShaderGraphInputSocket* ShaderGraph::addInputSocket(const string& name, TypeDesc type)
Expand Down
14 changes: 13 additions & 1 deletion source/MaterialXGenShader/ShaderGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,18 @@ class MX_GENSHADER_API ShaderGraph : public ShaderNode
const vector<ShaderGraphOutputSocket*>& getOutputSockets() const { return _inputOrder; }

/// Apply color and unit transforms to each input of a node.
void applyInputTransforms(ConstNodePtr node, ShaderNodePtr shaderNode, GenContext& context);
void applyInputTransforms(ConstNodePtr node, ShaderNode* shaderNode, GenContext& context);

/// Create a new node in the graph
ShaderNode* createNode(ConstNodePtr node, GenContext& context);

ShaderNode* inlineNodeBeforeOutput(ShaderGraphOutputSocket* output,
const std::string& newNodeName,
const std::string& nodeDefName,
const std::string& inputName,
const std::string& outputName,
GenContext& context);

/// Add input sockets
ShaderGraphInputSocket* addInputSocket(const string& name, TypeDesc type);
[[deprecated]] ShaderGraphInputSocket* addInputSocket(const string& name, const TypeDesc* type) { return addInputSocket(name, *type); }
Expand Down Expand Up @@ -129,6 +136,11 @@ class MX_GENSHADER_API ShaderGraph : public ShaderNode
ElementPtr connectingElement,
GenContext& context);

/// Create a new node in a graph from a node definition.
/// Note - this does not initialize the node instance with any concrete values, but
/// instead creates an empty instance of the provided node definition
ShaderNode* createNode(const string& name, ConstNodeDefPtr nodeDef, GenContext& context);

/// Add a node to the graph
void addNode(ShaderNodePtr node);

Expand Down