add guesslang tf model

This commit is contained in:
Chris W 2023-10-30 19:25:30 -06:00
parent 1ca24c58e5
commit 4ee081df66
12 changed files with 708 additions and 170 deletions

1
.gitignore vendored Normal file → Executable file
View File

@ -1 +1,2 @@
/target
venv

733
Cargo.lock generated

File diff suppressed because it is too large Load Diff

3
Cargo.toml Normal file → Executable file
View File

@ -10,6 +10,7 @@ actix-web = "4"
silicon = { git = "https://github.com/watzon/silicon.git" }
lazy_static = "1.4.0"
serde = { version = "1.0.130", features = ["derive"] }
serde_json = "1.0.107"
structopt = "0.3.26"
image = "0.24.7"
anyhow = "1.0.75"
@ -19,3 +20,5 @@ font-kit = "0.11.0"
reqwest = "0.11.22"
hyperpolyglot = "0.1.7"
tempfile = "3.8.0"
tensorflow = "0.17.0"
clap = { version = "4.4.7", features = ["derive"] }

52
Dockerfile Normal file → Executable file
View File

@ -1,7 +1,44 @@
# Build stage: compiles the TensorFlow C shared library (libtensorflow.so) from
# source so that later stages can copy it out of bazel-bin.
FROM debian:buster-slim as tensorflow
WORKDIR /usr/src/build
# Install the toolchain and Python packages used by the TensorFlow build
RUN apt-get update && apt-get install -y \
    git \
    wget \
    gnupg \
    python3 \
    python3-dev \
    python3-pip \
    python3-numpy \
    llvm \
    clang
RUN pip3 install wheel packaging requests opt_einsum
RUN pip3 install keras_preprocessing --no-deps
# Install bazelisk as `bazel`: download straight to its final path and mark it
# executable in a single layer.
RUN wget -O /usr/local/bin/bazel https://github.com/bazelbuild/bazelisk/releases/download/v1.18.0/bazelisk-linux-amd64 \
    && chmod +x /usr/local/bin/bazel
# Fetch only the v2.5.0 tag; a shallow clone avoids downloading the full
# repository history, which the build does not need.
RUN git clone --depth 1 --branch v2.5.0 https://github.com/tensorflow/tensorflow
# NOTE(review): ./configure runs without a terminal here, so it presumably
# falls back to its default answers - confirm.
RUN cd tensorflow && ./configure
# Number of parallel bazel jobs; override with `--build-arg BAZEL_JOBS=<n>`.
# Declared right before its first use so changing it only invalidates this layer.
ARG BAZEL_JOBS=12
# NOTE(review): -march=native targets the CPU of the build host; the resulting
# library may fail with illegal-instruction errors on older CPUs - confirm the
# image is only run on comparable hardware.
RUN cd tensorflow && bazel build --compilation_mode=opt --copt=-march=native --jobs=${BAZEL_JOBS} tensorflow:libtensorflow.so
FROM rust:1.73.0-buster as builder
WORKDIR /usr/src/app
# Copy tensorflow shared libraries from tensorflow image
COPY --from=tensorflow /usr/src/build/tensorflow/bazel-bin/tensorflow/libtensorflow.so* /usr/local/lib/
COPY --from=tensorflow /usr/src/build/tensorflow/bazel-bin/tensorflow/libtensorflow_framework.so* /usr/local/lib/
RUN ldconfig
# Install dependencies
RUN apt-get update && apt-get install -y \
libssl-dev \
@ -30,6 +67,8 @@ RUN bash ./download_nerd_fonts.sh
FROM debian:buster-slim
WORKDIR /usr/src/app
# Install dependencies
RUN apt-get update && apt-get install -y \
libssl-dev \
@ -41,8 +80,15 @@ RUN apt-get update && apt-get install -y \
COPY --from=fonts /data/fonts/nerd_fonts/* /usr/share/fonts/truetype/
RUN fc-cache -fv
# Copy binary
COPY --from=builder /usr/src/app/target/release/inkify /usr/local/bin/inkify
# Copy binary and tensorflow model files
COPY --from=builder /usr/src/app/target/release/inkify /usr/src/app/
COPY --from=builder /usr/src/app/src/tensorflow /usr/src/app/tensorflow/
# Copy tensorflow shared libraries from tensorflow image
COPY --from=tensorflow /usr/src/build/tensorflow/bazel-bin/tensorflow/libtensorflow.so* /usr/local/lib/
COPY --from=tensorflow /usr/src/build/tensorflow/bazel-bin/tensorflow/libtensorflow_framework.so* /usr/local/lib/
RUN ldconfig
ARG PORT=8080
ARG HOST=0.0.0.0
@ -53,4 +99,4 @@ ENV HOST=$HOST
EXPOSE $PORT
# Run
ENTRYPOINT ["/usr/local/bin/inkify"]
CMD ["/usr/src/app/inkify", "--tensorflow-model-dir", "/usr/src/app/tensorflow"]

0
LICENSE Normal file → Executable file
View File

0
README.md Normal file → Executable file
View File

73
src/config.rs Normal file → Executable file
View File

@ -1,17 +1,17 @@
use anyhow::Error;
use silicon::formatter::{ImageFormatter, ImageFormatterBuilder};
use silicon::utils::{Background, ShadowAdder};
use std::io::Write;
use std::path::PathBuf;
use syntect::highlighting::{Theme, ThemeSet};
use syntect::parsing::{SyntaxReference, SyntaxSet};
use tensorflow::{Graph, SavedModelBundle, SessionOptions, Tensor};
use crate::rgba::{ImageRgba, Rgba};
type FontList = Vec<(String, f32)>;
type Lines = Vec<u32>;
#[derive(Debug, Clone, serde::Deserialize)]
#[derive(Debug, serde::Deserialize)]
pub struct Config {
/// Background image URL
pub background_image: Option<Vec<u8>>,
@ -72,6 +72,12 @@ pub struct Config {
/// The syntax highlight theme. It can be a theme name or path to a .tmTheme file.
pub theme: String,
#[serde(skip_deserializing)]
pub tf_model_graph: Option<Graph>,
#[serde(skip_deserializing)]
pub tf_model: Option<SavedModelBundle>,
}
impl Config {
@ -97,9 +103,25 @@ impl Config {
shadow_offset_x: 0,
tab_width: 4,
theme: "Dracula".to_owned(),
tf_model_graph: None,
tf_model: None,
}
}
/// Loads the TensorFlow SavedModel found in `export_dir` (tag `serve`) and
/// stores both the session bundle and its graph on this config.
///
/// On failure the error is written to stderr and the config is left
/// untouched, so `tf_model` and `tf_model_graph` keep their previous values.
pub fn load_tensorflow_model(&mut self, export_dir: &str) {
    let mut graph = Graph::new();
    match SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir) {
        Ok(bundle) => {
            // Only commit the graph together with a successfully loaded bundle.
            self.tf_model = Some(bundle);
            self.tf_model_graph = Some(graph);
        }
        Err(err) => eprintln!("Failed to load TensorFlow model: {}", err),
    }
}
pub fn language<'a>(&self, ps: &'a SyntaxSet) -> Result<&'a SyntaxReference, Error> {
let language = match &self.language {
Some(language) => ps
@ -108,21 +130,52 @@ impl Config {
None => {
let first_line = self.code.lines().next().unwrap_or_default();
ps.find_syntax_by_first_line(first_line).unwrap_or_else(|| {
// hyperpolyglot requires a file, so we need to create a temp file
let mut temp_file = tempfile::NamedTempFile::new().unwrap();
write!(temp_file, "{}", self.code).unwrap();
let language = hyperpolyglot::detect(temp_file.path()).unwrap();
match language {
Some(language) => ps.find_syntax_by_token(language.language()).unwrap(),
None => ps.find_syntax_by_token("log").unwrap(),
}
// Try using tensorflow to detect the language
let input_data = Tensor::new(&[1]).with_values(&[self.code.clone()]).unwrap();
self.predict_language_with_tensorflow(ps, input_data)
.unwrap_or_else(|_| ps.find_syntax_by_token("log").unwrap())
})
},
};
Ok(language)
}
pub fn predict_language_with_tensorflow<'a>(&self, ps: &'a SyntaxSet, input_data: Tensor<String>) -> Result<&'a SyntaxReference, Error> {
if self.tf_model_graph.is_none() || self.tf_model.is_none() {
return Err(Error::msg("TensorFlow model not loaded"));
}
let graph = self.tf_model_graph.as_ref().unwrap();
let model = self.tf_model.as_ref().unwrap();
let mut args = tensorflow::SessionRunArgs::new();
let input_tensor = graph.operation_by_name_required("Placeholder")?;
let output_tensor_scores = graph.operation_by_name_required("head/predictions/probabilities")?;
let output_tensor_classes = graph.operation_by_name_required("head/Tile")?;
args.add_feed(&input_tensor, 0, &input_data);
let output_token_scores = args.request_fetch(&output_tensor_scores, 0);
let output_token_classes = args.request_fetch(&output_tensor_classes, 0);
model.session.run(&mut args)?;
let scores: Tensor<f32> = args.fetch(output_token_scores)?;
let classes: Tensor<String> = args.fetch(output_token_classes)?;
// Find the index of the highest score
let max_index = scores.iter().enumerate().max_by(|a, b| a.1.partial_cmp(b.1).unwrap()).unwrap().0;
let language = &classes[max_index];
let language = ps.find_syntax_by_token(language).unwrap();
Ok(language)
}
pub fn theme(&self, ts: &ThemeSet) -> Result<Theme, Error> {
if let Some(theme) = ts.themes.get(&self.theme) {
Ok(theme.clone())

14
src/main.rs Normal file → Executable file
View File

@ -1,6 +1,7 @@
#[macro_use]
extern crate anyhow;
use clap::Parser;
use actix_web::{get, web, App, HttpResponse, HttpServer, Responder};
use anyhow::Error;
use lazy_static::lazy_static;
@ -28,6 +29,13 @@ macro_rules! unwrap_or_return {
};
}
// Command-line arguments of the server, parsed through clap's derive API.
// NOTE(review): plain `//` comments are used on purpose - clap's derive turns
// `///` doc comments into help text, which would change the program's output.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct CliArgs {
// Optional directory containing the TensorFlow model, given on the command
// line as `--tensorflow-model-dir` (clap also derives a short flag). When
// present, `generate` passes it to `load_tensorflow_model`.
#[arg(short, long)]
tensorflow_model_dir: Option<String>,
}
fn parse_font_str(s: &str) -> Vec<(String, f32)> {
let mut result = vec![];
for font in s.split(';') {
@ -144,6 +152,7 @@ async fn fonts() -> impl Responder {
#[get("/generate")]
async fn generate(info: web::Query<config::ConfigQuery>) -> impl Responder {
let args = CliArgs::parse();
let ha = &*HIGHLIGHTING_ASSETS;
let (ps, ts) = (&ha.syntax_set, &ha.theme_set);
@ -156,6 +165,10 @@ async fn generate(info: web::Query<config::ConfigQuery>) -> impl Responder {
.body(r#"{"error": "code parameter is required"}"#);
}
if args.tensorflow_model_dir.is_some() {
conf.load_tensorflow_model(args.tensorflow_model_dir.unwrap().as_str());
}
conf.language = info.language.clone();
if let Some(theme) = info.theme.clone() {
conf.theme = theme;
@ -227,6 +240,7 @@ async fn generate(info: web::Query<config::ConfigQuery>) -> impl Responder {
.append_header(("Content-Type", "application/json"))
.body(r#"{"error": "Unable to determine language, please provide one explicitly"}"#)
);
let theme = unwrap_or_return!(
conf.theme(ts),
HttpResponse::BadRequest()

0
src/rgba.rs Normal file → Executable file
View File

Binary file not shown.

Binary file not shown.

Binary file not shown.