feat: Support Starcoder2 (#20)
* feat: Support Starcoder2
maiqingqiang committed Mar 8, 2024
1 parent e876e18 commit a94bf79
Showing 4 changed files with 281 additions and 2 deletions.
5 changes: 5 additions & 0 deletions Libraries/LLM/Configuration.swift
@@ -32,6 +32,7 @@ public enum ModelType: String, Codable {
case phi
case gemma
case qwen2
case starcoder2

func createModel(configuration: URL) throws -> LLMModel {
switch self {
@@ -51,6 +52,10 @@ public enum ModelType: String, Codable {
let configuration = try JSONDecoder().decode(
Qwen2Configuration.self, from: Data(contentsOf: configuration))
return Qwen2Model(configuration)
case .starcoder2:
let configuration = try JSONDecoder().decode(
Starcoder2Configuration.self, from: Data(contentsOf: configuration))
return Starcoder2Model(configuration)
}
}
}
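With starcoder2 registered in ModelType, a checkpoint is built the same way as the other architectures: point the factory at the repo's config.json. A minimal usage sketch (the local path is assumed and the surrounding download/weight-loading code is omitted; not part of this diff):

import Foundation

// Assumed path to a downloaded model's config.json (hypothetical).
let configURL = URL(fileURLWithPath: "starcoder2-3b-4bit/config.json")

// ModelType is presumably decoded from the config's "model_type" field elsewhere in
// Configuration.swift; it is hard-coded here only for illustration.
let model = try ModelType.starcoder2.createModel(configuration: configURL)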
266 changes: 266 additions & 0 deletions Libraries/LLM/Starcoder2.swift
@@ -0,0 +1,266 @@
//
// Starcoder2.swift
// LLM
//
// Created by John Mai on 2024/3/7.
//

import Foundation
import MLX
import MLXNN

// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/starcoder2.py

private class Attention: Module {
let args: Starcoder2Configuration
let repeats: Int
let scale: Float

@ModuleInfo(key: "q_proj") var wq: Linear
@ModuleInfo(key: "k_proj") var wk: Linear
@ModuleInfo(key: "v_proj") var wv: Linear
@ModuleInfo(key: "o_proj") var wo: Linear

let rope: RoPE

public init(_ args: Starcoder2Configuration) {
self.args = args

let dim = args.hiddenSize
let heads = args.attentionHeads
let kvHeads = args.kvHeads

self.repeats = heads / kvHeads

let headDim = args.hiddenSize / heads
self.scale = pow(Float(headDim), -0.5)

_wq.wrappedValue = Linear(dim, heads * headDim, bias: true)
_wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: true)
_wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: true)
_wo.wrappedValue = Linear(heads * headDim, dim, bias: true)

self.rope = RoPE(dimensions: headDim, traditional: false, base: args.ropeTheta)
}

public func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
) -> (MLXArray, (MLXArray, MLXArray)) {
let (B, L) = (x.dim(0), x.dim(1))

var queries = wq(x)
var keys = wk(x)
var values = wv(x)

// prepare the queries, keys and values for the attention computation
queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)

if repeats > 1 {
keys = MLXArray.repeat(keys, count: repeats, axis: 1)
values = MLXArray.repeat(values, count: repeats, axis: 1)
}

if let (keyCache, valueCache) = cache {
queries = rope(queries, offset: keyCache.dim(2))
keys = rope(keys, offset: keyCache.dim(2))
keys = concatenated([keyCache, keys], axis: 2)
values = concatenated([valueCache, values], axis: 2)
} else {
queries = rope(queries)
keys = rope(keys)
}

var scores = (queries * scale).matmul(keys.transposed(0, 1, 3, 2))
if let mask {
scores = scores + mask
}

scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)

let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)

return (wo(output), (keys, values))
}
}
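The block above implements grouped-query attention: there are fewer key/value heads than query heads, and the keys/values are tiled with MLXArray.repeat so every query head gets a matching KV head before the scaled dot-product. A small arithmetic sketch of that bookkeeping, using assumed, roughly Starcoder2-3B-sized values (not taken from this diff):

// Assumed example sizes; substitute the real configuration values.
let hiddenSize = 3072
let heads = 24
let kvHeads = 2

let headDim = hiddenSize / heads               // 128: per-head width, also the RoPE dimension
let repeats = heads / kvHeads                  // 12: each KV head serves 12 query heads
let scale = 1.0 / Float(headDim).squareRoot()  // queries are scaled by 1/sqrt(headDim)

// For batch B and sequence length L the projections are laid out as:
//   queries: (B, heads,   L, headDim)
//   keys:    (B, kvHeads, L, headDim) -> repeated along axis 1 to (B, heads, L, headDim)
//   values:  same shape handling as keys
print(headDim, repeats, scale)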

private class MLP: Module, UnaryLayer {
@ModuleInfo(key: "c_fc") var cFc: Linear
@ModuleInfo(key: "c_proj") var cProj: Linear

public init(dimensions: Int, hiddenDimensions: Int) {
_cFc.wrappedValue = Linear(dimensions, hiddenDimensions, bias: true)
_cProj.wrappedValue = Linear(hiddenDimensions, dimensions, bias: true)
}

public func callAsFunction(_ x: MLXArray) -> MLXArray {
cProj(gelu(cFc(x)))
}
}

private class TransformerBlock: Module {
@ModuleInfo(key: "self_attn") var attention: Attention
let mlp: MLP

@ModuleInfo(key: "input_layernorm") var inputLayerNorm: LayerNorm
@ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: LayerNorm

public init(_ args: Starcoder2Configuration) {
_attention.wrappedValue = Attention(args)
self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
_inputLayerNorm.wrappedValue = LayerNorm(
dimensions: args.hiddenSize, eps: args.normEpsilon)
_postAttentionLayerNorm.wrappedValue = LayerNorm(
dimensions: args.hiddenSize, eps: args.normEpsilon)
}

public func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
) -> (MLXArray, (MLXArray, MLXArray)) {
var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
let h = x + r
r = mlp(postAttentionLayerNorm(h))
let out = h + r
return (out, cache)
}
}

public class Starcoder2ModelInner: Module {
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding

fileprivate let layers: [TransformerBlock]
let norm: LayerNorm

public init(_ args: Starcoder2Configuration) {
precondition(args.vocabularySize > 0)

_embedTokens.wrappedValue = Embedding(
embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)

self.layers = (0 ..< args.hiddenLayers)
.map { _ in
TransformerBlock(args)
}
self.norm = LayerNorm(dimensions: args.hiddenSize, eps: args.normEpsilon)
}

public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
MLXArray, [(MLXArray, MLXArray)]
) {
var h = embedTokens(inputs)

var mask: MLXArray? = nil
if h.dim(1) > 1 {
mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
mask = mask?.asType(h.dtype)
}

var newCache = [(MLXArray, MLXArray)]()

for (i, layer) in layers.enumerated() {
var cacheUpdate: (MLXArray, MLXArray)
(h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
newCache.append(cacheUpdate)
}

return (norm(h), newCache)
}
}

public class Starcoder2Model: Module, LLMModel {
public var vocabularySize: Int

public let tieWordEmbeddings: Bool
let model: Starcoder2ModelInner

@ModuleInfo(key: "lm_head") var lmHead: Linear

public init(_ args: Starcoder2Configuration) {
self.vocabularySize = args.vocabularySize
self.model = Starcoder2ModelInner(args)
self.tieWordEmbeddings = args.tieWordEmbeddings
if !self.tieWordEmbeddings {
_lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
}
}

public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
MLXArray, [(MLXArray, MLXArray)]
) {
var (out, cache) = model(inputs, cache: cache)

if !tieWordEmbeddings {
return (lmHead(out), cache)
} else {
out = matmul(out, model.embedTokens.weight.T)
return (out, cache)
}
}
}
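A hedged end-to-end sketch of how the model and its per-layer key/value cache could be driven (the configuration JSON and token ids below are made-up illustration values; real use goes through the loading and tokenizer utilities elsewhere in this library):

import Foundation
import MLX

// Illustrative config: only the five required keys are supplied, everything else
// falls back to the defaults in Starcoder2Configuration (assumed values, not a real model's).
let json = """
{"hidden_size": 3072, "num_hidden_layers": 30, "intermediate_size": 12288,
 "num_attention_heads": 24, "num_key_value_heads": 2}
""".data(using: .utf8)!
let config = try JSONDecoder().decode(Starcoder2Configuration.self, from: json)
let model = Starcoder2Model(config)

// Prompt pass: cache is nil and L > 1, so a causal mask is built inside the model.
let prompt = MLXArray([1, 2, 3, 4] as [Int32]).reshaped(1, 4)   // (B, L) token ids, made up
var (logits, cache) = model(prompt, cache: nil)

// Incremental decode step: feed the cache back in; with L == 1 no mask is needed and
// attention extends over the cached keys/values from the prompt pass.
let next = MLXArray([5] as [Int32]).reshaped(1, 1)
(logits, cache) = model(next, cache: cache)
print(logits.shape)   // (1, 1, vocabularySize)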

public struct Starcoder2Configuration: Codable {
var hiddenSize: Int
var hiddenLayers: Int
var intermediateSize: Int
var attentionHeads: Int
var kvHeads: Int
var maxPositionEmbeddings: Int = 16384
var normEpsilon: Float = 1e-5
var normType: String = "layer_norm"
var vocabularySize: Int = 49152
var ropeTheta: Float = 100000
var tieWordEmbeddings: Bool = true

enum CodingKeys: String, CodingKey {
case hiddenSize = "hidden_size"
case hiddenLayers = "num_hidden_layers"
case intermediateSize = "intermediate_size"
case attentionHeads = "num_attention_heads"
case kvHeads = "num_key_value_heads"
case maxPositionEmbeddings = "max_position_embeddings"
case normEpsilon = "norm_epsilon"
case normType = "norm_type"
case vocabularySize = "vocab_size"
case ropeTheta = "rope_theta"
case tieWordEmbeddings = "tie_word_embeddings"
}

public init(from decoder: Decoder) throws {
// custom implementation to handle optional keys with required values
let container: KeyedDecodingContainer<Starcoder2Configuration.CodingKeys> =
try decoder.container(
keyedBy: Starcoder2Configuration.CodingKeys.self)

self.hiddenSize = try container.decode(
Int.self, forKey: Starcoder2Configuration.CodingKeys.hiddenSize)
self.hiddenLayers = try container.decode(
Int.self, forKey: Starcoder2Configuration.CodingKeys.hiddenLayers)
self.intermediateSize = try container.decode(
Int.self, forKey: Starcoder2Configuration.CodingKeys.intermediateSize)
self.attentionHeads = try container.decode(
Int.self, forKey: Starcoder2Configuration.CodingKeys.attentionHeads)
self.kvHeads = try container.decode(
Int.self, forKey: Starcoder2Configuration.CodingKeys.kvHeads)
self.maxPositionEmbeddings =
try container.decodeIfPresent(
Int.self, forKey: Starcoder2Configuration.CodingKeys.maxPositionEmbeddings) ?? 16384
self.normEpsilon =
try container.decodeIfPresent(
Float.self, forKey: Starcoder2Configuration.CodingKeys.normEpsilon) ?? 1e-5
self.normType =
try container.decodeIfPresent(
String.self, forKey: Starcoder2Configuration.CodingKeys.normType) ?? "layer_norm"
self.vocabularySize =
try container.decodeIfPresent(
Int.self, forKey: Starcoder2Configuration.CodingKeys.vocabularySize) ?? 49152
self.ropeTheta =
try container.decodeIfPresent(
Float.self, forKey: Starcoder2Configuration.CodingKeys.ropeTheta)
?? 100000
self.tieWordEmbeddings =
try container.decodeIfPresent(
Bool.self, forKey: Starcoder2Configuration.CodingKeys.tieWordEmbeddings)
?? true
}
}
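Because the custom init(from:) uses decodeIfPresent with fallbacks, only the five structural keys are mandatory in config.json; the rest take the defaults above. A quick check with a hypothetical minimal configuration (assumed values):

import Foundation

let minimal = """
{"hidden_size": 2048, "num_hidden_layers": 24, "intermediate_size": 8192,
 "num_attention_heads": 16, "num_key_value_heads": 4}
""".data(using: .utf8)!
let cfg = try JSONDecoder().decode(Starcoder2Configuration.self, from: minimal)
// cfg.normEpsilon == 1e-5, cfg.ropeTheta == 100000, cfg.tieWordEmbeddings == true,
// cfg.vocabularySize == 49152, cfg.maxPositionEmbeddings == 16384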
4 changes: 4 additions & 0 deletions mlx-swift-examples.xcodeproj/project.pbxproj
@@ -7,6 +7,7 @@
objects = {

/* Begin PBXBuildFile section */
525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */; };
52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 52A776172B94B5EE00AA6E80 /* Qwen2.swift */; };
C3288D762B6D9313009FF608 /* LinearModelTraining.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3288D752B6D9313009FF608 /* LinearModelTraining.swift */; };
C3288D7B2B6D9339009FF608 /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C3288D7A2B6D9339009FF608 /* ArgumentParser */; };
@@ -181,6 +182,7 @@
/* End PBXCopyFilesBuildPhase section */

/* Begin PBXFileReference section */
525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Starcoder2.swift; sourceTree = "<group>"; };
52A776172B94B5EE00AA6E80 /* Qwen2.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Qwen2.swift; sourceTree = "<group>"; };
C325DE3F2B648CDB00628871 /* README.md */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
C3288D732B6D9313009FF608 /* LinearModelTraining */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = LinearModelTraining; sourceTree = BUILT_PRODUCTS_DIR; };
@@ -354,6 +356,7 @@
C38935C62B869C7A0037B833 /* LLM */ = {
isa = PBXGroup;
children = (
525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */,
C34E48EF2B696E6500FCB841 /* Configuration.swift */,
C3A8B3AB2B9283150002EFB8 /* Models.swift */,
C34E48EE2B696E6500FCB841 /* Llama.swift */,
@@ -826,6 +829,7 @@
C38935E12B869F420037B833 /* LLMModel.swift in Sources */,
C38935E32B86C0FE0037B833 /* Gemma.swift in Sources */,
C38935CD2B869C870037B833 /* Configuration.swift in Sources */,
525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */,
C38935DF2B869DD00037B833 /* Phi.swift in Sources */,
C38935CE2B869C870037B833 /* Load.swift in Sources */,
C3E786AD2B8D4AF50004D037 /* Tokenizer.swift in Sources */,
8 changes: 6 additions & 2 deletions
@@ -56,13 +56,17 @@
isEnabled = "NO">
</CommandLineArgument>
<CommandLineArgument
- argument = "--model mlx-community/Qwen1.5-0.5B-Chat-4bit"
+ argument = "--model mlx-community/starcoder2-3b-4bit"
isEnabled = "YES">
</CommandLineArgument>
<CommandLineArgument
- argument = "--prompt &apos;func sortArray(_ array: [Int]) -&gt; String { &lt;FILL_ME&gt; }&apos;"
+ argument = "--model mlx-community/Qwen1.5-0.5B-Chat-4bit"
isEnabled = "NO">
</CommandLineArgument>
+ <CommandLineArgument
+ argument = "--prompt &apos;def quick_sort(arr, left=None, right=None):&apos;"
+ isEnabled = "YES">
+ </CommandLineArgument>
<CommandLineArgument
argument = "--model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
isEnabled = "NO">
