diff --git a/formatters/html/html.go b/formatters/html/html.go
index 52577da2f..9d7ddfed6 100644
--- a/formatters/html/html.go
+++ b/formatters/html/html.go
@@ -443,13 +443,37 @@ func (f *Formatter) styleToCSS(style *chroma.Style) map[chroma.TokenType]string
if t != chroma.Background {
entry = entry.Sub(bg)
}
- if !f.allClasses && entry.IsZero() && f.customCSS[t] == `` {
+
+ // Inherit from custom CSS provided by user
+ tokenCategory := t.Category()
+ tokenSubCategory := t.SubCategory()
+ if t != tokenCategory {
+ if css, ok := f.customCSS[tokenCategory]; ok {
+ classes[t] = css
+ }
+ }
+ if tokenCategory != tokenSubCategory {
+ if css, ok := f.customCSS[tokenSubCategory]; ok {
+ classes[t] += css
+ }
+ }
+ // Add custom CSS provided by user
+ if css, ok := f.customCSS[t]; ok {
+ classes[t] += css
+ }
+
+ if !f.allClasses && entry.IsZero() && classes[t] == `` {
continue
}
- classes[t] = f.customCSS[t] + StyleEntryToCSS(entry)
+
+ styleEntryCSS := StyleEntryToCSS(entry)
+ if styleEntryCSS != `` {
+ styleEntryCSS += `;`
+ }
+ classes[t] = styleEntryCSS + classes[t]
}
classes[chroma.Background] += f.tabWidthStyle()
- classes[chroma.PreWrapper] += classes[chroma.Background] + `;`
+ classes[chroma.PreWrapper] += classes[chroma.Background]
// Make PreWrapper a grid to show highlight style with full width.
if len(f.highlightRanges) > 0 && f.customCSS[chroma.PreWrapper] == `` {
classes[chroma.PreWrapper] += `display: grid;`
diff --git a/formatters/html/html_test.go b/formatters/html/html_test.go
index 0f133474f..d815e99e8 100644
--- a/formatters/html/html_test.go
+++ b/formatters/html/html_test.go
@@ -120,6 +120,21 @@ func TestWithCustomCSS(t *testing.T) {
assert.Regexp(t, `echo FOO`, buf.String())
}
+func TestWithCustomCSSStyleInheritance(t *testing.T) {
+ f := New(WithClasses(false), WithCustomCSS(map[chroma.TokenType]string{
+ chroma.String: `background: blue;`,
+ chroma.LiteralStringDouble: `color: tomato;`,
+ }))
+ it, err := lexers.Get("bash").Tokenise(nil, `echo "FOO"`)
+ assert.NoError(t, err)
+
+ var buf bytes.Buffer
+ err = f.Format(&buf, styles.Fallback, it)
+ assert.NoError(t, err)
+
+ assert.Regexp(t, ` "FOO"`, buf.String())
+}
+
func TestWrapLongLines(t *testing.T) {
f := New(WithClasses(false), WrapLongLines(true))
it, err := lexers.Get("go").Tokenise(nil, "package main\nfunc main()\n{\nprintln(\"hello world\")\n}\n")