4 changes: 2 additions & 2 deletions Sources/Generation/Decoders.swift
@@ -3,7 +3,7 @@ import CoreML

// MARK: Greedy Decoding

-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
func selectNextTokenUsingGreedyDecoding(from scores: MLTensor) -> MLTensor {
let indices = scores.argmax(alongAxis: -1).reshaped(to: [1, 1])
// Ensure indices are Int32 for concatenation with input tokens
@@ -19,7 +19,7 @@ func selectNextTokenUsingGreedyDecoding(from scores: MLTensor) -> MLTensor {
///
/// - Parameter scores: Processed logits tensor [batch_size, vocab_size]
/// - Returns: Sampled token ID tensor [batch_size, 1]
-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
func selectNextTokenUsingSampling(from scores: MLTensor) -> MLTensor {
// Convert logits to probabilities
let probs = scores.softmax(alongAxis: -1)
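The only change in Decoders.swift is the wider platform list, which exposes the greedy and sampling decoders to tvOS, visionOS, and watchOS. A minimal sketch of what adopting that annotation looks like from code inside the Generation module (the `Decoding` wrapper is hypothetical and not part of this PR):

```swift
import CoreML

// Hypothetical wrapper assumed to live inside the Generation module, where the
// internal decoder is visible. To call it on the newly supported platforms,
// the wrapper's own availability must be widened to the same platform list.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
enum Decoding {
    static func greedy(_ scores: MLTensor) -> MLTensor {
        // Delegates to the internal helper touched in this diff.
        selectNextTokenUsingGreedyDecoding(from: scores)
    }
}
```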
8 changes: 4 additions & 4 deletions Sources/Generation/Generation.swift
@@ -38,7 +38,7 @@ public typealias GenerationOutput = [Int]
/// - Parameter tokens: Input token sequence
/// - Parameter config: Generation configuration
/// - Returns: Logits array for next token prediction
-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public typealias NextTokenModel = (MLTensor, GenerationConfig) async -> MLTensor

/// Callback for receiving generated tokens during streaming.
@@ -48,7 +48,7 @@ public typealias PredictionTokensCallback = (GenerationOutput) -> Void
public typealias PredictionStringCallback = (String) -> Void

/// Protocol for text generation implementations.
-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public protocol Generation {
/// Generates text from a prompt string.
///
@@ -62,7 +62,7 @@ public protocol Generation {
func generate(config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer, callback: PredictionStringCallback?) async -> String
}

-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension Generation {
public func generate(
config: GenerationConfig,
@@ -162,7 +162,7 @@ extension Generation {
}
}

-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public extension Generation {
/// Performs greedy or sampling-based text generation based on generation configuration.
///
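Generation.swift widens the same annotation on `NextTokenModel`, the `Generation` protocol, and its default implementations. A hedged sketch of wiring a `NextTokenModel` closure into the default `generate`: the empty conforming type, the identity "model", and the tokenizer id are illustrative, and the `GenerationConfig` initializer and `AutoTokenizer` entry point are assumed from the wider swift-transformers API rather than shown in this diff.

```swift
import CoreML
import Generation
import Tokenizers

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
struct SketchGenerator: Generation {} // relies entirely on the protocol's default generate

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
func runGenerationSketch() async throws {
    // A NextTokenModel is just an async closure from (tokens, config) to logits.
    let model: NextTokenModel = { tokens, _ in tokens }                // placeholder "model"
    let config = GenerationConfig(maxNewTokens: 8)                     // assumed initializer
    let tokenizer = try await AutoTokenizer.from(pretrained: "org/model-id") // hypothetical id
    let text = await SketchGenerator().generate(
        config: config,
        prompt: "Hello",
        model: model,
        tokenizer: tokenizer
    ) { partial in
        print(partial) // streaming callback, invoked as the string grows
    }
    print(text)
}
```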
4 changes: 2 additions & 2 deletions Sources/Generation/LogitsWarper/LogitsProcessor.swift
@@ -8,7 +8,7 @@ import CoreML
/// such as temperature scaling, top-k/top-p filtering, and repetition penalties.
///
/// Based on: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py
-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public protocol LogitsProcessor {
/// Processes logits for next token prediction.
///
@@ -28,7 +28,7 @@ public protocol LogitsProcessor {
/// This class provides a convenient way to chain multiple logits processors together.
/// Each processor is applied in order to the logits tensor, with the output of one
/// processor becoming the input to the next.
-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public struct LogitsProcessorList {
public var processors: [any LogitsProcessor]

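The documented behavior of `LogitsProcessorList`, each processor's output feeding the next, can be mirrored with plain arrays and closures. This is a conceptual sketch only; the real list stores `any LogitsProcessor` values and operates on `MLTensor`:

```swift
// Conceptual chaining on plain Float arrays.
typealias ArrayWarper = ([Float]) -> [Float]

func applyAll(_ processors: [ArrayWarper], to logits: [Float]) -> [Float] {
    // Each processor's output becomes the next processor's input.
    processors.reduce(logits) { current, process in process(current) }
}

// Example: scale by a temperature of 0.7, then mask everything below zero.
let warped = applyAll(
    [
        { logits in logits.map { $0 / 0.7 } },
        { logits in logits.map { $0 < 0 ? -.infinity : $0 } },
    ],
    to: [1.0, -2.0, 3.0]
)
```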
2 changes: 1 addition & 1 deletion Sources/Generation/LogitsWarper/MinPLogitsWarper.swift
@@ -15,7 +15,7 @@ import CoreML
///
/// Based on:
/// - https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L460
-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public struct MinPLogitsWarper: LogitsProcessor {
public let minP: Float
public let minTokensToKeep: Int
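For reference, a conceptual sketch of the min-p rule on an ordinary logits array. The `minTokensToKeep` guarantee is omitted here; the shipped warper enforces it and works on `MLTensor` values:

```swift
import Foundation

// Min-p filtering: the cutoff scales with the most likely token's probability.
func minPFilter(_ logits: [Float], minP: Float, filterValue: Float = -.infinity) -> [Float] {
    let maxLogit = logits.max() ?? 0
    let exps = logits.map { expf($0 - maxLogit) }   // numerically stable softmax
    let total = exps.reduce(0, +)
    let probs = exps.map { $0 / total }
    let threshold = minP * (probs.max() ?? 0)
    return zip(logits, probs).map { logit, p in p < threshold ? filterValue : logit }
}
```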
@@ -22,7 +22,7 @@ public enum LogitsProcessorError: Error {
/// Based on:
/// - https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L297
/// - Paper: https://arxiv.org/abs/1909.05858
-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public struct RepetitionPenaltyLogitsProcessor: LogitsProcessor {
public let penalty: Float

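The CTRL-style rule referenced in the doc comment above, sketched in plain Swift; `previousTokens` is an illustrative stand-in for the generated sequence that the real processor reads from a tensor:

```swift
// Repetition penalty: tokens that already appeared are pushed down,
// positive scores shrink and negative scores grow more negative.
func applyRepetitionPenalty(_ logits: [Float], previousTokens: [Int], penalty: Float) -> [Float] {
    var result = logits
    for token in Set(previousTokens) where token >= 0 && token < result.count {
        let score = result[token]
        result[token] = score < 0 ? score * penalty : score / penalty
    }
    return result
}
```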
@@ -11,7 +11,7 @@ import CoreML
/// Often used together with `TopPLogitsWarper` and `TopKLogitsWarper`.
///
/// Based on: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L231
-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public struct TemperatureLogitsWarper: LogitsProcessor {
public let temperature: Float

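The scaling rule itself is a one-liner; a plain-Swift version for reference, with the warper in this file applying the same idea to an `MLTensor`. Temperatures below 1 sharpen the distribution, temperatures above 1 flatten it:

```swift
// Temperature scaling: divide every logit by the temperature.
func applyTemperature(_ logits: [Float], temperature: Float) -> [Float] {
    precondition(temperature > 0, "temperature must be positive")
    return logits.map { $0 / temperature }
}
```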
2 changes: 1 addition & 1 deletion Sources/Generation/LogitsWarper/TopKLogitsWarper.swift
@@ -10,7 +10,7 @@ import CoreML
/// Pro tip: In practice, LLMs use top_k in the 5-50 range.
///
/// Based on: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L532
-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public struct TopKLogitsWarper: LogitsProcessor {
public let topK: Int
public let filterValue: Float
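A plain-Swift sketch of the masking rule: keep the k highest logits and replace the rest with `filterValue`, which plays the role of negative infinity so masked tokens get zero probability after softmax. Ties at the threshold may keep slightly more than k entries:

```swift
// Top-k filtering on an ordinary logits array.
func topKFilter(_ logits: [Float], k: Int, filterValue: Float = -.infinity) -> [Float] {
    guard k > 0, k < logits.count else { return logits }
    let threshold = logits.sorted(by: >)[k - 1]   // k-th largest logit
    return logits.map { $0 < threshold ? filterValue : $0 }
}
```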
2 changes: 1 addition & 1 deletion Sources/Generation/LogitsWarper/TopPLogitsWarper.swift
@@ -14,7 +14,7 @@ import CoreML
/// Based on:
/// - https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L465
/// - Paper: https://arxiv.org/abs/1904.09751 (Nucleus Sampling)
-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public struct TopPLogitsWarper: LogitsProcessor {
public let topP: Float
public let filterValue: Float
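Nucleus (top-p) filtering, sketched conceptually in plain Swift: sort tokens by probability, keep the smallest prefix whose cumulative mass reaches `topP`, and mask the rest. The `MLTensor`-based warper in this file implements the same idea:

```swift
import Foundation

// Top-p / nucleus filtering on an ordinary logits array.
func topPFilter(_ logits: [Float], topP: Float, filterValue: Float = -.infinity) -> [Float] {
    let maxLogit = logits.max() ?? 0
    let exps = logits.map { expf($0 - maxLogit) }
    let total = exps.reduce(0, +)
    let probs = exps.map { $0 / total }

    var kept = Set<Int>()
    var cumulative: Float = 0
    for index in probs.indices.sorted(by: { probs[$0] > probs[$1] }) {
        kept.insert(index)
        cumulative += probs[index]
        if cumulative >= topP { break }   // nucleus is complete
    }
    return logits.enumerated().map { index, logit in kept.contains(index) ? logit : filterValue }
}
```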
20 changes: 10 additions & 10 deletions Sources/Models/LanguageModel.swift
@@ -12,7 +12,7 @@ import Generation
import Hub
import Tokenizers

-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
/// A high-level interface for language model inference using CoreML.
///
/// `LanguageModel` provides a convenient way to load and interact with pre-trained
@@ -72,7 +72,7 @@ public class LanguageModel {
}
}

-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
private extension LanguageModel {
static func contextRange(from model: MLModel) -> (min: Int, max: Int) {
contextRange(from: model, inputKey: Keys.inputIds)
@@ -109,7 +109,7 @@ private extension LanguageModel {
}
}

-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension LanguageModel {
struct Configurations {
var modelConfig: Config
@@ -118,7 +118,7 @@ extension LanguageModel {
}
}

-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension LanguageModel {
enum Keys {
// Input keys
@@ -135,7 +135,7 @@ extension LanguageModel {
}
}

-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public extension LanguageModel {
/// Loads a compiled CoreML model from disk.
///
@@ -155,7 +155,7 @@ public extension LanguageModel {
}
}

-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension LanguageModel {
enum KVCacheAvailability {
/// Language models that support KV cache via state. Implementation details for handling state
@@ -167,7 +167,7 @@ extension LanguageModel {
}
}

-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public extension LanguageModel {
/// Metadata fields associated to the Core ML model.
var metadata: [MLModelMetadataKey: Any] {
@@ -296,7 +296,7 @@ public extension LanguageModel {
// MARK: - Configuration Properties

/// Asynchronous properties that are downloaded from the Hugging Face Hub configuration.
-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public extension LanguageModel {
/// The model configuration dictionary.
///
@@ -402,7 +402,7 @@ public extension LanguageModel {

// MARK: - TextGenerationModel Conformance

-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension LanguageModel: TextGenerationModel {
/// The default generation configuration for this model.
///
@@ -424,7 +424,7 @@ extension LanguageModel: TextGenerationModel {
///
/// Maintains a KV Cache as sequence generation progresses,
/// using stateful Core ML buffers to minimize latency.
-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public class LanguageModelWithStatefulKVCache: LanguageModel {
private enum Mode {
case prefilling
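LanguageModel.swift carries the same annotation change across the class, its extensions, and the stateful KV-cache subclass. A hedged end-to-end usage sketch: the file path and prompt are placeholders, and the exact `loadCompiled` and `generate` signatures are assumed from the surrounding library rather than shown in this diff.

```swift
import Foundation
import Generation
import Models

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
func generateSample() async throws {
    let url = URL(fileURLWithPath: "/path/to/Model.mlmodelc")   // placeholder path
    let model = try LanguageModel.loadCompiled(url: url)        // assumed entry point
    var config = model.defaultGenerationConfig
    config.maxNewTokens = 32
    let text = try await model.generate(config: config, prompt: "Hello", callback: nil)
    print(text)
}
```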
8 changes: 4 additions & 4 deletions Sources/Models/LanguageModelTypes.swift
@@ -15,7 +15,7 @@ import Tokenizers
///
/// This protocol establishes the fundamental requirements for any language model
/// that can perform next-token prediction and text generation tasks.
-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public protocol LanguageModelProtocol {
/// The name or path of the model.
///
@@ -50,7 +50,7 @@ public protocol LanguageModelProtocol {
func predictNextTokenScores(_ input: MLTensor, config: GenerationConfig) async -> MLTensor
}

-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public extension LanguageModelProtocol {
/// Function call syntax for next token prediction.
///
@@ -69,7 +69,7 @@ public extension LanguageModelProtocol {
///
/// This protocol extends `LanguageModelProtocol` and `Generation` to provide
/// high-level text generation functionality with configurable parameters.
-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public protocol TextGenerationModel: Generation, LanguageModelProtocol {
/// The default generation configuration for this model.
///
@@ -92,7 +92,7 @@ public protocol TextGenerationModel: Generation, LanguageModelProtocol {
) async throws -> String
}

-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public extension TextGenerationModel {
/// Default implementation of text generation that uses the underlying generation framework.
///
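The "function call syntax" extension in this file relies on Swift's `callAsFunction`. A self-contained toy showing the same sugar; the type and method below are purely illustrative and unrelated to the real protocol:

```swift
// callAsFunction lets an instance be invoked like a function.
struct Doubler {
    func predict(_ x: Int) -> Int { x * 2 }
    // Lets callers write `doubler(3)` instead of `doubler.predict(3)`.
    func callAsFunction(_ x: Int) -> Int { predict(x) }
}

let doubler = Doubler()
print(doubler(3)) // prints 6
```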
4 changes: 2 additions & 2 deletions Tests/GenerationTests/GenerationIntegrationTests.swift
@@ -4,7 +4,7 @@ import XCTest

@testable import Generation

-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
final class GenerationIntegrationTests: XCTestCase {

// MARK: - Mock Model for Testing
@@ -343,7 +343,7 @@ final class GenerationIntegrationTests: XCTestCase {

// MARK: - Test Helper

-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
struct TestGeneration: Generation {
func generate(
config: GenerationConfig,
4 changes: 2 additions & 2 deletions Tests/GenerationTests/LogitsProcessorTests.swift
@@ -3,7 +3,7 @@ import XCTest

@testable import Generation

-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
final class LogitsProcessorTests: XCTestCase {
private let accuracy: Float = 0.0001

@@ -319,7 +319,7 @@ final class LogitsProcessorTests: XCTestCase {

// MARK: - Test Helpers

-@available(macOS 15.0, iOS 18.0, *)
+@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
func assertMLTensorEqual(
_ tensor: MLTensor,
expected: [Float],